Skip to content

Commit

Permalink
Backfill eth getproof tests (#7092)
Browse files Browse the repository at this point in the history
This PR adds missing tests for eth_getProof and does some mild
refactoring with switching from strings to more strict types. It's
likely best/most easily reviewed commit by commit.

Note, the tests include quite a number of helper types and functions for
doing the proof validation. This is largely because unlike Geth,
Erigon's approach to trie computations only requires serializing the
trie nodes, not deserializing them. Consequently, it wasn't obvious how
to leverage the existing trie code for doing deserialization and proof
checks. I checked on Discord, but, there were no suggestions. Of course,
any feedback is welcome and I'd be happy to remove this code if it can
be avoided.

Additionally, I've opted to change the interface type for `GetProof` to
use a `common.Hash` for the storage keys instead of a `string`. I
_think_ this should be fairly safe, as until very recently it was
unimplemented. That being said, since it's an interface, it has the
potential to break other consumers, anyone who was generating mocks
against it etc. There's additionally a `GetStorageAt` that follows the
same parameter. I'd be happy to submit a PR modifying this one as well
if desired.

Also, as a small note, there is test code for checking storage proofs,
but, storage proofs aren't currently supported by the implementation. My
hope is to add storage proofs and historic proofs in a followup PR.

---------

Co-authored-by: Jason Yellick <jason@enya.ai>
  • Loading branch information
jyellick and jyellick committed Mar 14, 2023
1 parent 4f6d769 commit b21569c
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cmd/rpcdaemon/commands/eth_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ type EthAPI interface {
SendTransaction(_ context.Context, txObject interface{}) (common.Hash, error)
Sign(ctx context.Context, _ common.Address, _ hexutil.Bytes) (hexutil.Bytes, error)
SignTransaction(_ context.Context, txObject interface{}) (common.Hash, error)
GetProof(ctx context.Context, address common.Address, storageKeys []string, blockNr rpc.BlockNumberOrHash) (*accounts.AccProofResult, error)
GetProof(ctx context.Context, address common.Address, storageKeys []common.Hash, blockNr rpc.BlockNumberOrHash) (*accounts.AccProofResult, error)
CreateAccessList(ctx context.Context, args ethapi2.CallArgs, blockNrOrHash *rpc.BlockNumberOrHash, optimizeGas *bool) (*accessListResult, error)

// Mining related (see ./eth_mining.go)
Expand Down
5 changes: 2 additions & 3 deletions cmd/rpcdaemon/commands/eth_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ func (api *APIImpl) EstimateGas(ctx context.Context, argsOrNil *ethapi2.CallArgs
}

// GetProof is partially implemented; no Storage proofs; only for the latest block
func (api *APIImpl) GetProof(ctx context.Context, address libcommon.Address, storageKeys []string, blockNrOrHash rpc.BlockNumberOrHash) (*accounts.AccProofResult, error) {
func (api *APIImpl) GetProof(ctx context.Context, address libcommon.Address, storageKeys []libcommon.Hash, blockNrOrHash rpc.BlockNumberOrHash) (*accounts.AccProofResult, error) {

tx, err := api.db.BeginRo(ctx)
if err != nil {
Expand Down Expand Up @@ -334,8 +334,7 @@ func (api *APIImpl) GetProof(ctx context.Context, address libcommon.Address, sto
rl.AddKey(addrHash[:])

loader := trie.NewFlatDBTrieLoader("getProof")
trace := true
if err := loader.Reset(rl, nil, nil, trace); err != nil {
if err := loader.Reset(rl, nil, nil, false); err != nil {
return nil, err
}

Expand Down
227 changes: 227 additions & 0 deletions cmd/rpcdaemon/commands/eth_call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ import (
"time"

libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/common/length"
"github.com/ledgerwatch/erigon/turbo/trie"

"github.com/ledgerwatch/erigon/rlp"
"github.com/ledgerwatch/erigon/rpc/rpccfg"
"github.com/ledgerwatch/erigon/turbo/adapter/ethapi"

"github.com/holiman/uint256"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ledgerwatch/erigon-lib/gointerfaces/txpool"
"github.com/ledgerwatch/erigon-lib/kv"
Expand All @@ -26,6 +30,7 @@ import (
"github.com/ledgerwatch/erigon/core/rawdb"
"github.com/ledgerwatch/erigon/core/state"
"github.com/ledgerwatch/erigon/core/types"
"github.com/ledgerwatch/erigon/core/types/accounts"
"github.com/ledgerwatch/erigon/crypto"
"github.com/ledgerwatch/erigon/params"
"github.com/ledgerwatch/erigon/rpc"
Expand Down Expand Up @@ -97,6 +102,228 @@ func TestEthCallToPrunedBlock(t *testing.T) {
}
}

type valueNode []byte
type hashNode libcommon.Hash

type shortNode struct {
Key trie.Keybytes
Val any
}

type fullNode struct {
Children [17]any
}

func decodeRef(t *testing.T, buf []byte) (any, []byte) {
t.Helper()
kind, val, rest, err := rlp.Split(buf)
require.NoError(t, err)
switch {
case kind == rlp.List:
require.Less(t, len(buf)-len(rest), length.Hash, "embedded nodes must be less than hash size")
return decodeNode(t, buf), rest
case kind == rlp.String && len(val) == 0:
return nil, rest
case kind == rlp.String && len(val) == 32:
return hashNode(libcommon.CastToHash(val)), rest
default:
t.Fatalf("invalid RLP string size %d (want 0 through 32)", len(val))
return nil, rest
}
}

func decodeFull(t *testing.T, elems []byte) fullNode {
t.Helper()
n := fullNode{}
for i := 0; i < 16; i++ {
n.Children[i], elems = decodeRef(t, elems)
}
val, _, err := rlp.SplitString(elems)
require.NoError(t, err)
if len(val) > 0 {
n.Children[16] = valueNode(val)
}
return n
}

func decodeShort(t *testing.T, elems []byte) shortNode {
t.Helper()
kbuf, rest, err := rlp.SplitString(elems)
require.NoError(t, err)
kb := trie.CompactToKeybytes(kbuf)
if kb.Terminating {
val, _, err := rlp.SplitString(rest)
require.NoError(t, err)
return shortNode{
Key: kb,
Val: valueNode(val),
}
}

val, _ := decodeRef(t, rest)
return shortNode{
Key: kb,
Val: val,
}
}

func decodeNode(t *testing.T, encoded []byte) any {
t.Helper()
require.NotEmpty(t, encoded)
elems, _, err := rlp.SplitList(encoded)
require.NoError(t, err)
switch c, _ := rlp.CountValues(elems); c {
case 2:
return decodeShort(t, elems)
case 17:
return decodeFull(t, elems)
default:
t.Fatalf("invalid number of list elements: %v", c)
return nil // unreachable
}
}

// proofMap creates a map from hash to proof node
func proofMap(t *testing.T, proof []hexutil.Bytes) map[libcommon.Hash]any {
res := map[libcommon.Hash]any{}
for _, proofB := range proof {
res[crypto.Keccak256Hash(proofB)] = decodeNode(t, proofB)
}
return res
}

func verifyProof(t *testing.T, root libcommon.Hash, key []byte, proofs map[libcommon.Hash]any) []byte {
t.Helper()
key = (&trie.Keybytes{Data: key}).ToHex()
var node any = hashNode(root)
for {
switch nt := node.(type) {
case fullNode:
require.NotEmpty(t, key, "full nodes should not have values")
node, key = nt.Children[key[0]], key[1:]
case shortNode:
shortHex := nt.Key.ToHex()[:nt.Key.Nibbles()] // There is a trailing 0 on odd otherwise
require.LessOrEqual(t, len(shortHex), len(key))
require.Equal(t, shortHex, key[:len(shortHex)])
node, key = nt.Val, key[len(shortHex):]
case hashNode:
var ok bool
node, ok = proofs[libcommon.Hash(nt)]
require.True(t, ok, "missing hash %x", nt)
case valueNode:
require.Len(t, key, 0)
return nt
default:
t.Fatalf("unexpected type: %T", node)
}
}
}

func verifyAccountProof(t *testing.T, stateRoot libcommon.Hash, proof *accounts.AccProofResult) {
t.Helper()
accountKey := crypto.Keccak256(proof.Address[:])
pm := proofMap(t, proof.AccountProof)
value := verifyProof(t, stateRoot, accountKey, pm)

expected, err := rlp.EncodeToBytes([]any{
uint64(proof.Nonce),
proof.Balance.ToInt().Bytes(),
proof.StorageHash,
proof.CodeHash,
})
require.NoError(t, err)

require.Equal(t, expected, value)
}

func verifyStorageProof(t *testing.T, storageRoot libcommon.Hash, proof accounts.StorProofResult) {
t.Helper()

storageKey := crypto.Keccak256(proof.Key[:])
pm := proofMap(t, proof.Proof)
value := verifyProof(t, storageRoot, storageKey, pm)

expected, err := rlp.EncodeToBytes(proof.Value.ToInt().Bytes())
require.NoError(t, err)

require.Equal(t, expected, value)
}

func TestGetProof(t *testing.T) {
pruneTo := uint64(3)

m, bankAddress, _ := chainWithDeployedContract(t)
br := snapshotsync.NewBlockReaderWithSnapshots(m.BlockSnapshots, m.TransactionsV3)

doPrune(t, m.DB, pruneTo)

agg := m.HistoryV3Components()

stateCache := kvcache.New(kvcache.DefaultCoherentConfig)
api := NewEthAPI(NewBaseApi(nil, stateCache, br, agg, false, rpccfg.DefaultEvmCallTimeout, m.Engine), m.DB, nil, nil, nil, 5000000, 100_000)

tests := []struct {
name string
blockNum uint64
storageKeys []libcommon.Hash
expectedErr string
}{
{
name: "currentBlock",
blockNum: 2,
},
{
name: "withState",
blockNum: 2,
storageKeys: []libcommon.Hash{{1}},
expectedErr: "the method is currently not implemented: eth_getProof with storageKeys",
},
{
name: "olderBlock",
blockNum: 1,
expectedErr: "the method is currently not implemented: eth_getProof for block != latest",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proof, err := api.GetProof(
context.Background(),
bankAddress,
tt.storageKeys,
rpc.BlockNumberOrHashWithNumber(rpc.BlockNumber(tt.blockNum)),
)
if tt.expectedErr != "" {
require.EqualError(t, err, tt.expectedErr)
require.Nil(t, proof)
return
}

tx, err := m.DB.BeginRo(context.Background())
assert.NoError(t, err)
defer tx.Rollback()
header, err := api.headerByRPCNumber(rpc.BlockNumber(tt.blockNum), tx)
require.NoError(t, err)

require.NoError(t, err)
require.NotNil(t, proof)

require.Equal(t, bankAddress, proof.Address)
verifyAccountProof(t, header.Root, proof)

require.Equal(t, len(tt.storageKeys), len(proof.StorageProof))
for _, storageKey := range tt.storageKeys {
for _, storageProof := range proof.StorageProof {
if storageProof.Key != storageKey {
continue
}
verifyStorageProof(t, proof.StorageHash, storageProof)
}
}
})
}
}

func TestGetBlockByTimestampLatestTime(t *testing.T) {
ctx := context.Background()
m, _, _ := rpcdaemontest.CreateTestSentry(t)
Expand Down
8 changes: 4 additions & 4 deletions core/types/accounts/account_proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import (
// Result structs for GetProof
type AccProofResult struct {
Address libcommon.Address `json:"address"`
AccountProof []string `json:"accountProof"`
AccountProof []hexutil.Bytes `json:"accountProof"`
Balance *hexutil.Big `json:"balance"`
CodeHash libcommon.Hash `json:"codeHash"`
Nonce hexutil.Uint64 `json:"nonce"`
StorageHash libcommon.Hash `json:"storageHash"`
StorageProof []StorProofResult `json:"storageProof"`
}
type StorProofResult struct {
Key string `json:"key"`
Value *hexutil.Big `json:"value"`
Proof []string `json:"proof"`
Key libcommon.Hash `json:"key"`
Value *hexutil.Big `json:"value"`
Proof []hexutil.Bytes `json:"proof"`
}
8 changes: 4 additions & 4 deletions turbo/trie/hashbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (hb *HashBuilder) Reset() {
}

func (hb *HashBuilder) SetProofReturn(accProofResult *accounts.AccProofResult) {
accProofResult.AccountProof = make([]string, 0)
accProofResult.AccountProof = make([]hexutil.Bytes, 0)
accProofResult.StorageProof = make([]accounts.StorProofResult, 0)
hb.accProofResult = accProofResult
}
Expand Down Expand Up @@ -202,7 +202,7 @@ func (hb *HashBuilder) completeLeafHash(kp, kl, compactLen int, key []byte, comp

if hb.collectNode {
nodeBytes := hexutil.Bytes(proofBuf.Bytes())
hb.accProofResult.AccountProof = append([]string{nodeBytes.String()}, hb.accProofResult.AccountProof...)
hb.accProofResult.AccountProof = append(hb.accProofResult.AccountProof, nodeBytes)
hb.collectNode = false
}

Expand Down Expand Up @@ -465,7 +465,7 @@ func (hb *HashBuilder) extensionHash(key []byte) error {

if hb.accProofResult != nil {
nodeBytes := hexutil.Bytes(proofBuf.Bytes())
hb.accProofResult.AccountProof = append([]string{nodeBytes.String()}, hb.accProofResult.AccountProof...)
hb.accProofResult.AccountProof = append(hb.accProofResult.AccountProof, nodeBytes)
}

hb.hashStack[len(hb.hashStack)-hashStackStride] = 0x80 + length2.Hash
Expand Down Expand Up @@ -581,7 +581,7 @@ func (hb *HashBuilder) branchHash(set uint16) error {

if hb.collectNode {
nodeBytes := hexutil.Bytes(proofBuf.Bytes())
hb.accProofResult.AccountProof = append([]string{nodeBytes.String()}, hb.accProofResult.AccountProof...)
hb.accProofResult.AccountProof = append(hb.accProofResult.AccountProof, nodeBytes)
hb.collectNode = false
}

Expand Down

0 comments on commit b21569c

Please sign in to comment.