diff --git a/ethclient/gethclient/gethclient_test.go b/ethclient/gethclient/gethclient_test.go index 5a0f4d2534..de45b10695 100644 --- a/ethclient/gethclient/gethclient_test.go +++ b/ethclient/gethclient/gethclient_test.go @@ -39,11 +39,12 @@ import ( ) var ( - testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") - testAddr = crypto.PubkeyToAddress(testKey.PublicKey) - testSlot = common.HexToHash("0xdeadbeef") - testValue = crypto.Keccak256Hash(testSlot[:]) - testBalance = big.NewInt(2e15) + testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + testAddr = crypto.PubkeyToAddress(testKey.PublicKey) + testContract = common.HexToAddress("0xbeef") + testSlot = common.HexToHash("0xdeadbeef") + testValue = crypto.Keccak256Hash(testSlot[:]) + testBalance = big.NewInt(2e15) ) func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { @@ -78,8 +79,9 @@ func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { func generateTestChain() (*core.Genesis, []*types.Block) { genesis := &core.Genesis{ - Config: params.AllEthashProtocolChanges, - Alloc: core.GenesisAlloc{testAddr: {Balance: testBalance, Storage: map[common.Hash]common.Hash{testSlot: testValue}}}, + Config: params.AllEthashProtocolChanges, + Alloc: core.GenesisAlloc{testAddr: {Balance: testBalance, Storage: map[common.Hash]common.Hash{testSlot: testValue}}, + testContract: {Nonce: 1, Code: []byte{0x13, 0x37}}}, ExtraData: []byte("test genesis"), Timestamp: 9000, } @@ -103,8 +105,11 @@ func TestGethClient(t *testing.T) { test func(t *testing.T) }{ { - "TestGetProof", - func(t *testing.T) { testGetProof(t, client) }, + "TestGetProof1", + func(t *testing.T) { testGetProof(t, client, testAddr) }, + }, { + "TestGetProof2", + func(t *testing.T) { testGetProof(t, client, testContract) }, }, { "TestGetProofCanonicalizeKeys", func(t *testing.T) { testGetProofCanonicalizeKeys(t, client) }, @@ -201,38 +206,41 @@ func testAccessList(t *testing.T, client *rpc.Client) { } } -func testGetProof(t *testing.T, client *rpc.Client) { +func testGetProof(t *testing.T, client *rpc.Client, addr common.Address) { ec := New(client) ethcl := ethclient.NewClient(client) - result, err := ec.GetProof(context.Background(), testAddr, []string{testSlot.String()}, nil) + result, err := ec.GetProof(context.Background(), addr, []string{testSlot.String()}, nil) if err != nil { t.Fatal(err) } - if !bytes.Equal(result.Address[:], testAddr[:]) { - t.Fatalf("unexpected address, want: %v got: %v", testAddr, result.Address) + if result.Address != addr { + t.Fatalf("unexpected address, have: %v want: %v", result.Address, addr) } // test nonce - nonce, _ := ethcl.NonceAt(context.Background(), result.Address, nil) - if result.Nonce != nonce { + if nonce, _ := ethcl.NonceAt(context.Background(), addr, nil); result.Nonce != nonce { t.Fatalf("invalid nonce, want: %v got: %v", nonce, result.Nonce) } // test balance - balance, _ := ethcl.BalanceAt(context.Background(), result.Address, nil) - if result.Balance.Cmp(balance) != 0 { + if balance, _ := ethcl.BalanceAt(context.Background(), addr, nil); result.Balance.Cmp(balance) != 0 { t.Fatalf("invalid balance, want: %v got: %v", balance, result.Balance) } - // test storage if len(result.StorageProof) != 1 { t.Fatalf("invalid storage proof, want 1 proof, got %v proof(s)", len(result.StorageProof)) } - proof := result.StorageProof[0] - slotValue, _ := ethcl.StorageAt(context.Background(), testAddr, testSlot, nil) - if !bytes.Equal(slotValue, proof.Value.Bytes()) { - t.Fatalf("invalid storage proof value, want: %v, got: %v", slotValue, proof.Value.Bytes()) + for _, proof := range result.StorageProof { + if proof.Key != testSlot.String() { + t.Fatalf("invalid storage proof key, want: %q, got: %q", testSlot.String(), proof.Key) + } + slotValue, _ := ethcl.StorageAt(context.Background(), addr, common.HexToHash(proof.Key), nil) + if have, want := common.BigToHash(proof.Value), common.BytesToHash(slotValue); have != want { + t.Fatalf("addr %x, invalid storage proof value: have: %v, want: %v", addr, have, want) + } } - if proof.Key != testSlot.String() { - t.Fatalf("invalid storage proof key, want: %q, got: %q", testSlot.String(), proof.Key) + // test code + code, _ := ethcl.CodeAt(context.Background(), addr, nil) + if have, want := result.CodeHash, crypto.Keccak256Hash(code); have != want { + t.Fatalf("codehash wrong, have %v want %v ", have, want) } } diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index cf1960fcf6..640693132e 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -675,10 +675,6 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st keys = make([]common.Hash, len(storageKeys)) keyLengths = make([]int, len(storageKeys)) storageProof = make([]StorageResult, len(storageKeys)) - - storageTrie state.Trie - storageHash = types.EmptyRootHash - codeHash = types.EmptyCodeHash ) // Deserialize all keys. This prevents state access on invalid input. for i, hexKey := range storageKeys { @@ -688,51 +684,49 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st return nil, err } } - state, header, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) - if state == nil || err != nil { + statedb, header, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if statedb == nil || err != nil { return nil, err } - if storageRoot := state.GetStorageRoot(address); storageRoot != types.EmptyRootHash && storageRoot != (common.Hash{}) { - id := trie.StorageTrieID(header.Root, crypto.Keccak256Hash(address.Bytes()), storageRoot) - tr, err := trie.NewStateTrie(id, state.Database().TrieDB()) - if err != nil { - return nil, err - } - storageTrie = tr - } - // If we have a storageTrie, the account exists and we must update - // the storage root hash and the code hash. - if storageTrie != nil { - storageHash = storageTrie.Hash() - codeHash = state.GetCodeHash(address) - } - // Create the proofs for the storageKeys. - for i, key := range keys { - // Output key encoding is a bit special: if the input was a 32-byte hash, it is - // returned as such. Otherwise, we apply the QUANTITY encoding mandated by the - // JSON-RPC spec for getProof. This behavior exists to preserve backwards - // compatibility with older client versions. - var outputKey string - if keyLengths[i] != 32 { - outputKey = hexutil.EncodeBig(key.Big()) - } else { - outputKey = hexutil.Encode(key[:]) - } + codeHash := statedb.GetCodeHash(address) + storageRoot := statedb.GetStorageRoot(address) - if storageTrie == nil { - storageProof[i] = StorageResult{outputKey, &hexutil.Big{}, []string{}} - continue + if len(keys) > 0 { + var storageTrie state.Trie + if storageRoot != types.EmptyRootHash && storageRoot != (common.Hash{}) { + id := trie.StorageTrieID(header.Root, crypto.Keccak256Hash(address.Bytes()), storageRoot) + st, err := trie.NewStateTrie(id, statedb.Database().TrieDB()) + if err != nil { + return nil, err + } + storageTrie = st } - var proof proofList - if err := storageTrie.Prove(crypto.Keccak256(key.Bytes()), &proof); err != nil { - return nil, err + // Create the proofs for the storageKeys. + for i, key := range keys { + // Output key encoding is a bit special: if the input was a 32-byte hash, it is + // returned as such. Otherwise, we apply the QUANTITY encoding mandated by the + // JSON-RPC spec for getProof. This behavior exists to preserve backwards + // compatibility with older client versions. + var outputKey string + if keyLengths[i] != 32 { + outputKey = hexutil.EncodeBig(key.Big()) + } else { + outputKey = hexutil.Encode(key[:]) + } + if storageTrie == nil { + storageProof[i] = StorageResult{outputKey, &hexutil.Big{}, []string{}} + continue + } + var proof proofList + if err := storageTrie.Prove(crypto.Keccak256(key.Bytes()), &proof); err != nil { + return nil, err + } + value := (*hexutil.Big)(statedb.GetState(address, key).Big()) + storageProof[i] = StorageResult{outputKey, value, proof} } - value := (*hexutil.Big)(state.GetState(address, key).Big()) - storageProof[i] = StorageResult{outputKey, value, proof} } - // Create the accountProof. - tr, err := trie.NewStateTrie(trie.StateTrieID(header.Root), state.Database().TrieDB()) + tr, err := trie.NewStateTrie(trie.StateTrieID(header.Root), statedb.Database().TrieDB()) if err != nil { return nil, err } @@ -743,12 +737,12 @@ func (s *BlockChainAPI) GetProof(ctx context.Context, address common.Address, st return &AccountResult{ Address: address, AccountProof: accountProof, - Balance: (*hexutil.Big)(state.GetBalance(address)), + Balance: (*hexutil.Big)(statedb.GetBalance(address)), CodeHash: codeHash, - Nonce: hexutil.Uint64(state.GetNonce(address)), - StorageHash: storageHash, + Nonce: hexutil.Uint64(statedb.GetNonce(address)), + StorageHash: storageRoot, StorageProof: storageProof, - }, state.Error() + }, statedb.Error() } // decodeHash parses a hex-encoded 32-byte hash. The input may optionally