From 98513a9cdf27fd24df6e8e7815afdc9b96cec552 Mon Sep 17 00:00:00 2001 From: Austin Larson Date: Wed, 9 Jul 2025 10:10:43 -0400 Subject: [PATCH 1/5] refactor: decouple work enqueuing from proofs --- x/sync/manager.go | 63 ++++++++++++++++++++++++++------------------- x/sync/sync_test.go | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 26 deletions(-) diff --git a/x/sync/manager.go b/x/sync/manager.go index d4dafa030501..e0aa9e7d6dc6 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -327,7 +327,7 @@ func (m *Manager) requestChangeProof(ctx context.Context, work *workItem) { if work.localRootID == targetRootID { // Start root is the same as the end root, so we're done. - m.completeWorkItem(ctx, work, work.end, targetRootID, nil) + m.completeWorkItem(work, work.end, targetRootID) m.finishWorkItem() return } @@ -342,7 +342,7 @@ func (m *Manager) requestChangeProof(ctx context.Context, work *workItem) { return } work.start = maybe.Nothing[[]byte]() - m.completeWorkItem(ctx, work, maybe.Nothing[[]byte](), targetRootID, nil) + m.completeWorkItem(work, maybe.Nothing[[]byte](), targetRootID) return } @@ -401,7 +401,7 @@ func (m *Manager) requestRangeProof(ctx context.Context, work *workItem) { return } work.start = maybe.Nothing[[]byte]() - m.completeWorkItem(ctx, work, maybe.Nothing[[]byte](), targetRootID, nil) + m.completeWorkItem(work, maybe.Nothing[[]byte](), targetRootID) return } @@ -545,7 +545,17 @@ func (m *Manager) handleRangeProofResponse( largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) } - m.completeWorkItem(ctx, work, largestHandledKey, targetRootID, rangeProof.EndProof) + // Find the next key to fetch. + // If this is empty, then we have no more keys to fetch. 
+ if !largestHandledKey.IsNothing() { + nextKey, err := m.findNextKey(ctx, largestHandledKey.Value(), work.end, rangeProof.EndProof) + if err != nil { + m.setError(err) + return nil + } + largestHandledKey = nextKey + } + m.completeWorkItem(work, largestHandledKey, targetRootID) return nil } @@ -569,6 +579,10 @@ func (m *Manager) handleChangeProofResponse( startKey := maybeBytesToMaybe(request.StartKey) endKey := maybeBytesToMaybe(request.EndKey) + var ( + largestHandledKey maybe.Maybe[[]byte] + endProof []merkledb.ProofNode + ) switch changeProofResp := changeProofResp.Response.(type) { case *pb.SyncGetChangeProofResponse_ChangeProof: // The server had enough history to send us a change proof @@ -601,7 +615,7 @@ func (m *Manager) handleChangeProofResponse( return fmt.Errorf("%w due to %w", errInvalidChangeProof, err) } - largestHandledKey := work.end + largestHandledKey = work.end // if the proof wasn't empty, apply changes to the sync DB if len(changeProof.KeyChanges) > 0 { if err := m.config.DB.CommitChangeProof(ctx, &changeProof); err != nil { @@ -610,8 +624,8 @@ func (m *Manager) handleChangeProofResponse( } largestHandledKey = maybe.Some(changeProof.KeyChanges[len(changeProof.KeyChanges)-1].Key) } + endProof = changeProof.EndProof - m.completeWorkItem(ctx, work, largestHandledKey, targetRootID, changeProof.EndProof) case *pb.SyncGetChangeProofResponse_RangeProof: var rangeProof merkledb.RangeProof if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof); err != nil { @@ -633,7 +647,7 @@ func (m *Manager) handleChangeProofResponse( return err } - largestHandledKey := work.end + largestHandledKey = work.end if len(rangeProof.KeyChanges) > 0 { // Add all the key-value pairs we got to the database. 
if err := m.config.DB.CommitRangeProof(ctx, work.start, work.end, &rangeProof); err != nil { @@ -642,14 +656,25 @@ func (m *Manager) handleChangeProofResponse( } largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) } + endProof = rangeProof.EndProof - m.completeWorkItem(ctx, work, largestHandledKey, targetRootID, rangeProof.EndProof) default: return fmt.Errorf( "%w: %T", errUnexpectedChangeProofResponse, changeProofResp, ) } + // Find the next key to fetch. + // If this is empty, then we have no more keys to fetch. + if !largestHandledKey.IsNothing() { + nextKey, err := m.findNextKey(ctx, largestHandledKey.Value(), work.end, endProof) + if err != nil { + m.setError(err) + return nil + } + largestHandledKey = nextKey + } + m.completeWorkItem(work, largestHandledKey, targetRootID) return nil } @@ -933,28 +958,14 @@ func (m *Manager) setError(err error) { // that gave us the range up to and including [largestHandledKey]. // // Assumes [m.workLock] is not held. -func (m *Manager) completeWorkItem(ctx context.Context, work *workItem, largestHandledKey maybe.Maybe[[]byte], rootID ids.ID, proofOfLargestKey []merkledb.ProofNode) { +func (m *Manager) completeWorkItem(work *workItem, largestHandledKey maybe.Maybe[[]byte], rootID ids.ID) { if !maybe.Equal(largestHandledKey, work.end, bytes.Equal) { - // The largest handled key isn't equal to the end of the work item. - // Find the start of the next key range to fetch. - // Note that [largestHandledKey] can't be Nothing. - // Proof: Suppose it is. That means that we got a range/change proof that proved up to the - // greatest key-value pair in the database. That means we requested a proof with no upper - // bound. That is, [workItem.end] is Nothing. Since we're here, [bothNothing] is false, - // which means [workItem.end] isn't Nothing. Contradiction. 
- nextStartKey, err := m.findNextKey(ctx, largestHandledKey.Value(), work.end, proofOfLargestKey) - if err != nil { - m.setError(err) - return - } - - // nextStartKey being Nothing indicates that the entire range has been completed - if nextStartKey.IsNothing() { + // largestHandledKey being Nothing indicates that the entire range has been completed + if largestHandledKey.IsNothing() { largestHandledKey = work.end } else { // the full range wasn't completed, so enqueue a new work item for the range [nextStartKey, workItem.end] - m.enqueueWork(newWorkItem(work.localRootID, nextStartKey, work.end, work.priority, time.Now())) - largestHandledKey = nextStartKey + m.enqueueWork(newWorkItem(work.localRootID, largestHandledKey, work.end, work.priority, time.Now())) } } diff --git a/x/sync/sync_test.go b/x/sync/sync_test.go index c96594cafefe..79e1fe235723 100644 --- a/x/sync/sync_test.go +++ b/x/sync/sync_test.go @@ -422,6 +422,55 @@ func Test_Sync_FindNextKey_ExtraValues(t *testing.T) { require.True(isPrefix(midPointVal, nextKey.Value())) } +func Test_Sync_FindNextKey_IdenticalKeys(t *testing.T) { + require := require.New(t) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + + testKeys := [][]byte{ + {0x10}, + {0x11, 0x11}, + {0x12, 0x34}, + {0x15}, + } + + for i, key := range testKeys { + value := []byte{byte(i + 1)} + require.NoError(db.Put(key, value)) + } + + targetRoot, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + + // Get the proof for the test key + testKey := []byte{0x11, 0x11} + proof, err := db.GetRangeProof(context.Background(), maybe.Some(testKey), maybe.Some(testKey), 1) + require.NoError(err) + + ctx := context.Background() + syncer, err := NewManager(ManagerConfig{ + DB: db, + RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, 
NewGetChangeProofHandler(db)), + TargetRoot: targetRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + BranchFactor: merkledb.BranchFactor16, + }, prometheus.NewRegistry()) + require.NoError(err) + + // Since both keys are identical, the next key should be nothing, since the range is complete + nextKey, err := syncer.findNextKey(context.Background(), testKey, maybe.Some([]byte{0x11, 0x11}), proof.EndProof) + require.NoError(err) + + require.Equal(maybe.Nothing[[]byte](), nextKey) +} + func TestFindNextKeyEmptyEndProof(t *testing.T) { require := require.New(t) now := time.Now().UnixNano() From 79712412829138e90957473a374f12e2a08343ae Mon Sep 17 00:00:00 2001 From: Austin Larson Date: Tue, 15 Jul 2025 14:50:56 -0400 Subject: [PATCH 2/5] refactor: abstract client-side proof logic --- x/sync/db.go | 22 +- x/sync/manager.go | 441 ++---------------------- x/sync/merkledb_client.go | 478 +++++++++++++++++++++++++ x/sync/proof_test.go | 658 +++++++++++++++++++++++++++++++++++ x/sync/sync_test.go | 708 +++----------------------------------- 5 files changed, 1222 insertions(+), 1085 deletions(-) create mode 100644 x/sync/merkledb_client.go create mode 100644 x/sync/proof_test.go diff --git a/x/sync/db.go b/x/sync/db.go index 5ed9061b5889..1d30c5944d5b 100644 --- a/x/sync/db.go +++ b/x/sync/db.go @@ -3,7 +3,14 @@ package sync -import "github.com/ava-labs/avalanchego/x/merkledb" +import ( + "context" + + "github.com/ava-labs/avalanchego/utils/maybe" + "github.com/ava-labs/avalanchego/x/merkledb" + + pb "github.com/ava-labs/avalanchego/proto/pb/sync" +) type DB interface { merkledb.Clearer @@ -12,3 +19,16 @@ type DB interface { merkledb.ChangeProofer merkledb.RangeProofer } + +type ProofClient interface { + merkledb.Clearer + merkledb.MerkleRootGetter + HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte, onFinish func(maybe.Maybe[[]byte])) error + HandleChangeProofResponse( + ctx context.Context, + request 
*pb.SyncGetChangeProofRequest, + responseBytes []byte, + onFinish func(maybe.Maybe[[]byte]), + ) error + RegisterErrorHandler(handler func(error)) +} diff --git a/x/sync/manager.go b/x/sync/manager.go index e0aa9e7d6dc6..6ddee2578bfa 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -9,14 +9,12 @@ import ( "errors" "fmt" "math" - "slices" "sync" "sync/atomic" "time" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" - "golang.org/x/exp/maps" "google.golang.org/protobuf/proto" "github.com/ava-labs/avalanchego/ids" @@ -24,7 +22,6 @@ import ( "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/maybe" "github.com/ava-labs/avalanchego/utils/set" - "github.com/ava-labs/avalanchego/x/merkledb" pb "github.com/ava-labs/avalanchego/proto/pb/sync" ) @@ -134,7 +131,6 @@ type Manager struct { // Set to true when StartSyncing is called. syncing bool closeOnce sync.Once - tokenSize int stateSyncNodeIdx uint32 metrics SyncMetrics @@ -142,16 +138,13 @@ type Manager struct { // TODO remove non-config values out of this struct type ManagerConfig struct { - DB DB + ProofClient ProofClient RangeProofClient *p2p.Client ChangeProofClient *p2p.Client SimultaneousWorkLimit int Log logging.Logger TargetRoot ids.ID - BranchFactor merkledb.BranchFactor StateSyncNodes []ids.NodeID - // If not specified, [merkledb.DefaultHasher] will be used. 
- Hasher merkledb.Hasher } func NewManager(config ManagerConfig, registerer prometheus.Registerer) (*Manager, error) { @@ -160,20 +153,13 @@ func NewManager(config ManagerConfig, registerer prometheus.Registerer) (*Manage return nil, ErrNoRangeProofClientProvided case config.ChangeProofClient == nil: return nil, ErrNoChangeProofClientProvided - case config.DB == nil: + case config.ProofClient == nil: return nil, ErrNoDatabaseProvided case config.Log == nil: return nil, ErrNoLogProvided case config.SimultaneousWorkLimit == 0: return nil, ErrZeroWorkLimit } - if err := config.BranchFactor.Valid(); err != nil { - return nil, err - } - - if config.Hasher == nil { - config.Hasher = merkledb.DefaultHasher - } metrics, err := NewMetrics("sync", registerer) if err != nil { @@ -185,10 +171,12 @@ func NewManager(config ManagerConfig, registerer prometheus.Registerer) (*Manage doneChan: make(chan struct{}), unprocessedWork: newWorkHeap(), processedWork: newWorkHeap(), - tokenSize: merkledb.BranchFactorToTokenSize[config.BranchFactor], metrics: metrics, } m.unprocessedWorkCond.L = &m.workLock + m.config.ProofClient.RegisterErrorHandler(func(err error) { + m.setError(err) + }) return m, nil } @@ -337,7 +325,7 @@ func (m *Manager) requestChangeProof(ctx context.Context, work *workItem) { // The trie is empty after this change. // Delete all the key-value pairs in the range. 
- if err := m.config.DB.Clear(); err != nil { + if err := m.config.ProofClient.Clear(); err != nil { m.setError(err) return } @@ -396,7 +384,7 @@ func (m *Manager) requestRangeProof(ctx context.Context, work *workItem) { if targetRootID == ids.Empty { defer m.finishWorkItem() - if err := m.config.DB.Clear(); err != nil { + if err := m.config.ProofClient.Clear(); err != nil { m.setError(err) return } @@ -510,53 +498,14 @@ func (m *Manager) handleRangeProofResponse( return err } - var rangeProofProto pb.RangeProof - if err := proto.Unmarshal(responseBytes, &rangeProofProto); err != nil { - return err - } - - var rangeProof merkledb.RangeProof - if err := rangeProof.UnmarshalProto(&rangeProofProto); err != nil { - return err - } - - if err := verifyRangeProof( + return m.config.ProofClient.HandleRangeProofResponse( ctx, - &rangeProof, - int(request.KeyLimit), - maybeBytesToMaybe(request.StartKey), - maybeBytesToMaybe(request.EndKey), - request.RootHash, - m.tokenSize, - m.config.Hasher, - ); err != nil { - return err - } - - largestHandledKey := work.end - - // Replace all the key-value pairs in the DB from start to end with values from the response. - if err := m.config.DB.CommitRangeProof(ctx, work.start, work.end, &rangeProof); err != nil { - m.setError(err) - return nil - } - - if len(rangeProof.KeyChanges) > 0 { - largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) - } - - // Find the next key to fetch. - // If this is empty, then we have no more keys to fetch. 
- if !largestHandledKey.IsNothing() { - nextKey, err := m.findNextKey(ctx, largestHandledKey.Value(), work.end, rangeProof.EndProof) - if err != nil { - m.setError(err) - return nil - } - largestHandledKey = nextKey - } - m.completeWorkItem(work, largestHandledKey, targetRootID) - return nil + request, + responseBytes, + func(largestHandledKey maybe.Maybe[[]byte]) { + m.completeWorkItem(work, largestHandledKey, targetRootID) + }, + ) } func (m *Manager) handleChangeProofResponse( @@ -571,282 +520,14 @@ func (m *Manager) handleChangeProofResponse( return err } - var changeProofResp pb.SyncGetChangeProofResponse - if err := proto.Unmarshal(responseBytes, &changeProofResp); err != nil { - return err - } - - startKey := maybeBytesToMaybe(request.StartKey) - endKey := maybeBytesToMaybe(request.EndKey) - - var ( - largestHandledKey maybe.Maybe[[]byte] - endProof []merkledb.ProofNode + return m.config.ProofClient.HandleChangeProofResponse( + ctx, + request, + responseBytes, + func(largestHandledKey maybe.Maybe[[]byte]) { + m.completeWorkItem(work, largestHandledKey, targetRootID) + }, ) - switch changeProofResp := changeProofResp.Response.(type) { - case *pb.SyncGetChangeProofResponse_ChangeProof: - // The server had enough history to send us a change proof - var changeProof merkledb.ChangeProof - if err := changeProof.UnmarshalProto(changeProofResp.ChangeProof); err != nil { - return err - } - - // Ensure the response does not contain more than the requested number of leaves - // and the start and end roots match the requested roots. 
- if len(changeProof.KeyChanges) > int(request.KeyLimit) { - return fmt.Errorf( - "%w: (%d) > %d)", - errTooManyKeys, len(changeProof.KeyChanges), request.KeyLimit, - ) - } - - endRoot, err := ids.ToID(request.EndRootHash) - if err != nil { - return err - } - - if err := m.config.DB.VerifyChangeProof( - ctx, - &changeProof, - startKey, - endKey, - endRoot, - ); err != nil { - return fmt.Errorf("%w due to %w", errInvalidChangeProof, err) - } - - largestHandledKey = work.end - // if the proof wasn't empty, apply changes to the sync DB - if len(changeProof.KeyChanges) > 0 { - if err := m.config.DB.CommitChangeProof(ctx, &changeProof); err != nil { - m.setError(err) - return nil - } - largestHandledKey = maybe.Some(changeProof.KeyChanges[len(changeProof.KeyChanges)-1].Key) - } - endProof = changeProof.EndProof - - case *pb.SyncGetChangeProofResponse_RangeProof: - var rangeProof merkledb.RangeProof - if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof); err != nil { - return err - } - - // The server did not have enough history to send us a change proof - // so they sent a range proof instead. - if err := verifyRangeProof( - ctx, - &rangeProof, - int(request.KeyLimit), - startKey, - endKey, - request.EndRootHash, - m.tokenSize, - m.config.Hasher, - ); err != nil { - return err - } - - largestHandledKey = work.end - if len(rangeProof.KeyChanges) > 0 { - // Add all the key-value pairs we got to the database. - if err := m.config.DB.CommitRangeProof(ctx, work.start, work.end, &rangeProof); err != nil { - m.setError(err) - return nil - } - largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) - } - endProof = rangeProof.EndProof - - default: - return fmt.Errorf( - "%w: %T", - errUnexpectedChangeProofResponse, changeProofResp, - ) - } - // Find the next key to fetch. - // If this is empty, then we have no more keys to fetch. 
- if !largestHandledKey.IsNothing() { - nextKey, err := m.findNextKey(ctx, largestHandledKey.Value(), work.end, endProof) - if err != nil { - m.setError(err) - return nil - } - largestHandledKey = nextKey - } - m.completeWorkItem(work, largestHandledKey, targetRootID) - - return nil -} - -// findNextKey returns the start of the key range that should be fetched next -// given that we just received a range/change proof that proved a range of -// key-value pairs ending at [lastReceivedKey]. -// -// [rangeEnd] is the end of the range that we want to fetch. -// -// Returns Nothing if there are no more keys to fetch in [lastReceivedKey, rangeEnd]. -// -// [endProof] is the end proof of the last proof received. -// -// Invariant: [lastReceivedKey] < [rangeEnd]. -// If [rangeEnd] is Nothing it's considered > [lastReceivedKey]. -func (m *Manager) findNextKey( - ctx context.Context, - lastReceivedKey []byte, - rangeEnd maybe.Maybe[[]byte], - endProof []merkledb.ProofNode, -) (maybe.Maybe[[]byte], error) { - if len(endProof) == 0 { - // We try to find the next key to fetch by looking at the end proof. - // If the end proof is empty, we have no information to use. - // Start fetching from the next key after [lastReceivedKey]. - nextKey := lastReceivedKey - nextKey = append(nextKey, 0) - return maybe.Some(nextKey), nil - } - - // We want the first key larger than the [lastReceivedKey]. - // This is done by taking two proofs for the same key - // (one that was just received as part of a proof, and one from the local db) - // and traversing them from the longest key to the shortest key. - // For each node in these proofs, compare if the children of that node exist - // or have the same ID in the other proof. - proofKeyPath := merkledb.ToKey(lastReceivedKey) - - // If the received proof is an exclusion proof, the last node may be for a - // key that is after the [lastReceivedKey]. 
- // If the last received node's key is after the [lastReceivedKey], it can - // be removed to obtain a valid proof for a prefix of the [lastReceivedKey]. - if !proofKeyPath.HasPrefix(endProof[len(endProof)-1].Key) { - endProof = endProof[:len(endProof)-1] - // update the proofKeyPath to be for the prefix - proofKeyPath = endProof[len(endProof)-1].Key - } - - // get a proof for the same key as the received proof from the local db - localProofOfKey, err := m.config.DB.GetProof(ctx, proofKeyPath.Bytes()) - if err != nil { - return maybe.Nothing[[]byte](), err - } - localProofNodes := localProofOfKey.Path - - // The local proof may also be an exclusion proof with an extra node. - // Remove this extra node if it exists to get a proof of the same key as the received proof - if !proofKeyPath.HasPrefix(localProofNodes[len(localProofNodes)-1].Key) { - localProofNodes = localProofNodes[:len(localProofNodes)-1] - } - - nextKey := maybe.Nothing[[]byte]() - - // Add sentinel node back into the localProofNodes, if it is missing. - // Required to ensure that a common node exists in both proofs - if len(localProofNodes) > 0 && localProofNodes[0].Key.Length() != 0 { - sentinel := merkledb.ProofNode{ - Children: map[byte]ids.ID{ - localProofNodes[0].Key.Token(0, m.tokenSize): ids.Empty, - }, - } - localProofNodes = append([]merkledb.ProofNode{sentinel}, localProofNodes...) - } - - // Add sentinel node back into the endProof, if it is missing. - // Required to ensure that a common node exists in both proofs - if len(endProof) > 0 && endProof[0].Key.Length() != 0 { - sentinel := merkledb.ProofNode{ - Children: map[byte]ids.ID{ - endProof[0].Key.Token(0, m.tokenSize): ids.Empty, - }, - } - endProof = append([]merkledb.ProofNode{sentinel}, endProof...) 
- } - - localProofNodeIndex := len(localProofNodes) - 1 - receivedProofNodeIndex := len(endProof) - 1 - - // traverse the two proofs from the deepest nodes up to the sentinel node until a difference is found - for localProofNodeIndex >= 0 && receivedProofNodeIndex >= 0 && nextKey.IsNothing() { - localProofNode := localProofNodes[localProofNodeIndex] - receivedProofNode := endProof[receivedProofNodeIndex] - - // [deepestNode] is the proof node with the longest key (deepest in the trie) in the - // two proofs that hasn't been handled yet. - // [deepestNodeFromOtherProof] is the proof node from the other proof with - // the same key/depth if it exists, nil otherwise. - var deepestNode, deepestNodeFromOtherProof *merkledb.ProofNode - - // select the deepest proof node from the two proofs - switch { - case receivedProofNode.Key.Length() > localProofNode.Key.Length(): - // there was a branch node in the received proof that isn't in the local proof - // see if the received proof node has children not present in the local proof - deepestNode = &receivedProofNode - - // we have dealt with this received node, so move on to the next received node - receivedProofNodeIndex-- - - case localProofNode.Key.Length() > receivedProofNode.Key.Length(): - // there was a branch node in the local proof that isn't in the received proof - // see if the local proof node has children not present in the received proof - deepestNode = &localProofNode - - // we have dealt with this local node, so move on to the next local node - localProofNodeIndex-- - - default: - // the two nodes are at the same depth - // see if any of the children present in the local proof node are different - // from the children in the received proof node - deepestNode = &localProofNode - deepestNodeFromOtherProof = &receivedProofNode - - // we have dealt with this local node and received node, so move on to the next nodes - localProofNodeIndex-- - receivedProofNodeIndex-- - } - - // We only want to look at the children 
with keys greater than the proofKey. - // The proof key has the deepest node's key as a prefix, - // so only the next token of the proof key needs to be considered. - - // If the deepest node has the same key as [proofKeyPath], - // then all of its children have keys greater than the proof key, - // so we can start at the 0 token. - startingChildToken := 0 - - // If the deepest node has a key shorter than the key being proven, - // we can look at the next token index of the proof key to determine which of that - // node's children have keys larger than [proofKeyPath]. - // Any child with a token greater than the [proofKeyPath]'s token at that - // index will have a larger key. - if deepestNode.Key.Length() < proofKeyPath.Length() { - startingChildToken = int(proofKeyPath.Token(deepestNode.Key.Length(), m.tokenSize)) + 1 - } - - // determine if there are any differences in the children for the deepest unhandled node of the two proofs - if childIndex, hasDifference := findChildDifference(deepestNode, deepestNodeFromOtherProof, startingChildToken); hasDifference { - nextKey = maybe.Some(deepestNode.Key.Extend(merkledb.ToToken(childIndex, m.tokenSize)).Bytes()) - break - } - } - - // If the nextKey is before or equal to the [lastReceivedKey] - // then we couldn't find a better answer than the [lastReceivedKey]. - // Set the nextKey to [lastReceivedKey] + 0, which is the first key in - // the open range (lastReceivedKey, rangeEnd). 
- if nextKey.HasValue() && bytes.Compare(nextKey.Value(), lastReceivedKey) <= 0 { - nextKeyVal := slices.Clone(lastReceivedKey) - nextKeyVal = append(nextKeyVal, 0) - nextKey = maybe.Some(nextKeyVal) - } - - // If the [nextKey] is larger than the end of the range, return Nothing to signal that there is no next key in range - if rangeEnd.HasValue() && bytes.Compare(nextKey.Value(), rangeEnd.Value()) >= 0 { - return maybe.Nothing[[]byte](), nil - } - - // the nextKey is within the open range (lastReceivedKey, rangeEnd), so return it - return nextKey, nil } func (m *Manager) Error() error { @@ -873,7 +554,7 @@ func (m *Manager) Wait(ctx context.Context) error { return err } - root, err := m.config.DB.GetMerkleRoot(ctx) + root, err := m.config.ProofClient.GetMerkleRoot(ctx) if err != nil { return err } @@ -1105,84 +786,6 @@ func midPoint(startMaybe, endMaybe maybe.Maybe[[]byte]) maybe.Maybe[[]byte] { return maybe.Some(midpoint) } -// findChildDifference returns the first child index that is different between node 1 and node 2 if one exists and -// a bool indicating if any difference was found -func findChildDifference(node1, node2 *merkledb.ProofNode, startIndex int) (byte, bool) { - // Children indices >= [startIndex] present in at least one of the nodes. 
- childIndices := set.Set[byte]{} - for _, node := range []*merkledb.ProofNode{node1, node2} { - if node == nil { - continue - } - for key := range node.Children { - if int(key) >= startIndex { - childIndices.Add(key) - } - } - } - - sortedChildIndices := maps.Keys(childIndices) - slices.Sort(sortedChildIndices) - var ( - child1, child2 ids.ID - ok1, ok2 bool - ) - for _, childIndex := range sortedChildIndices { - if node1 != nil { - child1, ok1 = node1.Children[childIndex] - } - if node2 != nil { - child2, ok2 = node2.Children[childIndex] - } - // if one node has a child and the other doesn't or the children ids don't match, - // return the current child index as the first difference - if (ok1 || ok2) && child1 != child2 { - return childIndex, true - } - } - // there were no differences found - return 0, false -} - -// Verify [rangeProof] is a valid range proof for keys in [start, end] for -// root [rootBytes]. Returns [errTooManyKeys] if the response contains more -// than [keyLimit] keys. -func verifyRangeProof( - ctx context.Context, - rangeProof *merkledb.RangeProof, - keyLimit int, - start maybe.Maybe[[]byte], - end maybe.Maybe[[]byte], - rootBytes []byte, - tokenSize int, - hasher merkledb.Hasher, -) error { - root, err := ids.ToID(rootBytes) - if err != nil { - return err - } - - // Ensure the response does not contain more than the maximum requested number of leaves. 
- if len(rangeProof.KeyChanges) > keyLimit { - return fmt.Errorf( - "%w: (%d) > %d)", - errTooManyKeys, len(rangeProof.KeyChanges), keyLimit, - ) - } - - if err := rangeProof.Verify( - ctx, - start, - end, - root, - tokenSize, - hasher, - ); err != nil { - return fmt.Errorf("%w due to %w", errInvalidRangeProof, err) - } - return nil -} - func calculateBackoff(attempt int) time.Duration { if attempt == 0 { return 0 diff --git a/x/sync/merkledb_client.go b/x/sync/merkledb_client.go new file mode 100644 index 000000000000..76bddd72d1a2 --- /dev/null +++ b/x/sync/merkledb_client.go @@ -0,0 +1,478 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package sync + +import ( + "bytes" + "context" + "fmt" + "slices" + + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/maybe" + "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/x/merkledb" + + pb "github.com/ava-labs/avalanchego/proto/pb/sync" +) + +var _ ProofClient = (*xSyncClient)(nil) + +type ClientConfig struct { + Hasher merkledb.Hasher + BranchFactor merkledb.BranchFactor +} + +type xSyncClient struct { + db DB + config *ClientConfig + tokenSize int + setError func(error) +} + +func NewClient(db DB, config *ClientConfig) (*xSyncClient, error) { + if err := config.BranchFactor.Valid(); err != nil { + return nil, err + } + + if config.Hasher == nil { + config.Hasher = merkledb.DefaultHasher + } + + return &xSyncClient{ + db: db, + config: config, + tokenSize: merkledb.BranchFactorToTokenSize[config.BranchFactor], + }, nil +} + +func (c *xSyncClient) RegisterErrorHandler(handler func(error)) { + c.setError = handler +} + +func (c *xSyncClient) Clear() error { + return c.db.Clear() +} + +func (c *xSyncClient) GetMerkleRoot(ctx context.Context) (ids.ID, error) { + return c.db.GetMerkleRoot(ctx) +} + +func (c *xSyncClient) 
HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte, onFinish func(maybe.Maybe[[]byte])) error { + var rangeProofProto pb.RangeProof + if err := proto.Unmarshal(responseBytes, &rangeProofProto); err != nil { + return err + } + + var rangeProof merkledb.RangeProof + if err := rangeProof.UnmarshalProto(&rangeProofProto); err != nil { + return err + } + + start := maybeBytesToMaybe(request.StartKey) + end := maybeBytesToMaybe(request.EndKey) + + if err := verifyRangeProof( + ctx, + &rangeProof, + int(request.KeyLimit), + start, + end, + request.RootHash, + c.tokenSize, + c.config.Hasher, + ); err != nil { + return err + } + + largestHandledKey := end + + // Replace all the key-value pairs in the DB from start to end with values from the response. + if err := c.db.CommitRangeProof(ctx, start, end, &rangeProof); err != nil { + c.setError(err) + return nil + } + + if len(rangeProof.KeyChanges) > 0 { + largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) + } + + // Find the next key to fetch. + // If this is empty, then we have no more keys to fetch. 
+ if !largestHandledKey.IsNothing() { + nextKey, err := c.findNextKey(ctx, largestHandledKey.Value(), end, rangeProof.EndProof) + if err != nil { + c.setError(err) + return nil + } + largestHandledKey = nextKey + } + + onFinish(largestHandledKey) + return nil +} + +func (c *xSyncClient) HandleChangeProofResponse( + ctx context.Context, + request *pb.SyncGetChangeProofRequest, + responseBytes []byte, + onFinish func(maybe.Maybe[[]byte]), +) error { + var changeProofResp pb.SyncGetChangeProofResponse + if err := proto.Unmarshal(responseBytes, &changeProofResp); err != nil { + return err + } + + startKey := maybeBytesToMaybe(request.StartKey) + endKey := maybeBytesToMaybe(request.EndKey) + + var ( + largestHandledKey maybe.Maybe[[]byte] + endProof []merkledb.ProofNode + ) + switch changeProofResp := changeProofResp.Response.(type) { + case *pb.SyncGetChangeProofResponse_ChangeProof: + // The server had enough history to send us a change proof + var changeProof merkledb.ChangeProof + if err := changeProof.UnmarshalProto(changeProofResp.ChangeProof); err != nil { + return err + } + + // Ensure the response does not contain more than the requested number of leaves + // and the start and end roots match the requested roots. 
+ if len(changeProof.KeyChanges) > int(request.KeyLimit) { + return fmt.Errorf( + "%w: (%d) > %d)", + errTooManyKeys, len(changeProof.KeyChanges), request.KeyLimit, + ) + } + + endRoot, err := ids.ToID(request.EndRootHash) + if err != nil { + return err + } + + if err := c.db.VerifyChangeProof( + ctx, + &changeProof, + startKey, + endKey, + endRoot, + ); err != nil { + return fmt.Errorf("%w due to %w", errInvalidChangeProof, err) + } + + largestHandledKey = endKey + // if the proof wasn't empty, apply changes to the sync DB + if len(changeProof.KeyChanges) > 0 { + if err := c.db.CommitChangeProof(ctx, &changeProof); err != nil { + c.setError(err) + return nil + } + largestHandledKey = maybe.Some(changeProof.KeyChanges[len(changeProof.KeyChanges)-1].Key) + } + endProof = changeProof.EndProof + + case *pb.SyncGetChangeProofResponse_RangeProof: + var rangeProof merkledb.RangeProof + if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof); err != nil { + return err + } + + // The server did not have enough history to send us a change proof + // so they sent a range proof instead. + if err := verifyRangeProof( + ctx, + &rangeProof, + int(request.KeyLimit), + startKey, + endKey, + request.EndRootHash, + c.tokenSize, + c.config.Hasher, + ); err != nil { + return err + } + + largestHandledKey = endKey + if len(rangeProof.KeyChanges) > 0 { + // Add all the key-value pairs we got to the database. + if err := c.db.CommitRangeProof(ctx, startKey, endKey, &rangeProof); err != nil { + c.setError(err) + return nil + } + largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) + } + endProof = rangeProof.EndProof + + default: + return fmt.Errorf( + "%w: %T", + errUnexpectedChangeProofResponse, changeProofResp, + ) + } + // Find the next key to fetch. + // If this is empty, then we have no more keys to fetch. 
+ if !largestHandledKey.IsNothing() { + nextKey, err := c.findNextKey(ctx, largestHandledKey.Value(), endKey, endProof) + if err != nil { + c.setError(err) + return nil + } + largestHandledKey = nextKey + } + onFinish(largestHandledKey) + return nil +} + +// findNextKey returns the start of the key range that should be fetched next +// given that we just received a range/change proof that proved a range of +// key-value pairs ending at [lastReceivedKey]. +// +// [rangeEnd] is the end of the range that we want to fetch. +// +// Returns Nothing if there are no more keys to fetch in [lastReceivedKey, rangeEnd]. +// +// [endProof] is the end proof of the last proof received. +// +// Invariant: [lastReceivedKey] < [rangeEnd]. +// If [rangeEnd] is Nothing it's considered > [lastReceivedKey]. +func (c *xSyncClient) findNextKey( + ctx context.Context, + lastReceivedKey []byte, + rangeEnd maybe.Maybe[[]byte], + endProof []merkledb.ProofNode, +) (maybe.Maybe[[]byte], error) { + if len(endProof) == 0 { + // We try to find the next key to fetch by looking at the end proof. + // If the end proof is empty, we have no information to use. + // Start fetching from the next key after [lastReceivedKey]. + nextKey := lastReceivedKey + nextKey = append(nextKey, 0) + return maybe.Some(nextKey), nil + } + + // We want the first key larger than the [lastReceivedKey]. + // This is done by taking two proofs for the same key + // (one that was just received as part of a proof, and one from the local db) + // and traversing them from the longest key to the shortest key. + // For each node in these proofs, compare if the children of that node exist + // or have the same ID in the other proof. + proofKeyPath := merkledb.ToKey(lastReceivedKey) + + // If the received proof is an exclusion proof, the last node may be for a + // key that is after the [lastReceivedKey]. 
+ // If the last received node's key is after the [lastReceivedKey], it can + // be removed to obtain a valid proof for a prefix of the [lastReceivedKey]. + if !proofKeyPath.HasPrefix(endProof[len(endProof)-1].Key) { + endProof = endProof[:len(endProof)-1] + // update the proofKeyPath to be for the prefix + proofKeyPath = endProof[len(endProof)-1].Key + } + + // get a proof for the same key as the received proof from the local db + localProofOfKey, err := c.db.GetProof(ctx, proofKeyPath.Bytes()) + if err != nil { + return maybe.Nothing[[]byte](), err + } + localProofNodes := localProofOfKey.Path + + // The local proof may also be an exclusion proof with an extra node. + // Remove this extra node if it exists to get a proof of the same key as the received proof + if !proofKeyPath.HasPrefix(localProofNodes[len(localProofNodes)-1].Key) { + localProofNodes = localProofNodes[:len(localProofNodes)-1] + } + + nextKey := maybe.Nothing[[]byte]() + + // Add sentinel node back into the localProofNodes, if it is missing. + // Required to ensure that a common node exists in both proofs + if len(localProofNodes) > 0 && localProofNodes[0].Key.Length() != 0 { + sentinel := merkledb.ProofNode{ + Children: map[byte]ids.ID{ + localProofNodes[0].Key.Token(0, c.tokenSize): ids.Empty, + }, + } + localProofNodes = append([]merkledb.ProofNode{sentinel}, localProofNodes...) + } + + // Add sentinel node back into the endProof, if it is missing. + // Required to ensure that a common node exists in both proofs + if len(endProof) > 0 && endProof[0].Key.Length() != 0 { + sentinel := merkledb.ProofNode{ + Children: map[byte]ids.ID{ + endProof[0].Key.Token(0, c.tokenSize): ids.Empty, + }, + } + endProof = append([]merkledb.ProofNode{sentinel}, endProof...) 
+ } + + localProofNodeIndex := len(localProofNodes) - 1 + receivedProofNodeIndex := len(endProof) - 1 + + // traverse the two proofs from the deepest nodes up to the sentinel node until a difference is found + for localProofNodeIndex >= 0 && receivedProofNodeIndex >= 0 && nextKey.IsNothing() { + localProofNode := localProofNodes[localProofNodeIndex] + receivedProofNode := endProof[receivedProofNodeIndex] + + // [deepestNode] is the proof node with the longest key (deepest in the trie) in the + // two proofs that hasn't been handled yet. + // [deepestNodeFromOtherProof] is the proof node from the other proof with + // the same key/depth if it exists, nil otherwise. + var deepestNode, deepestNodeFromOtherProof *merkledb.ProofNode + + // select the deepest proof node from the two proofs + switch { + case receivedProofNode.Key.Length() > localProofNode.Key.Length(): + // there was a branch node in the received proof that isn't in the local proof + // see if the received proof node has children not present in the local proof + deepestNode = &receivedProofNode + + // we have dealt with this received node, so move on to the next received node + receivedProofNodeIndex-- + + case localProofNode.Key.Length() > receivedProofNode.Key.Length(): + // there was a branch node in the local proof that isn't in the received proof + // see if the local proof node has children not present in the received proof + deepestNode = &localProofNode + + // we have dealt with this local node, so move on to the next local node + localProofNodeIndex-- + + default: + // the two nodes are at the same depth + // see if any of the children present in the local proof node are different + // from the children in the received proof node + deepestNode = &localProofNode + deepestNodeFromOtherProof = &receivedProofNode + + // we have dealt with this local node and received node, so move on to the next nodes + localProofNodeIndex-- + receivedProofNodeIndex-- + } + + // We only want to look at the children 
with keys greater than the proofKey. + // The proof key has the deepest node's key as a prefix, + // so only the next token of the proof key needs to be considered. + + // If the deepest node has the same key as [proofKeyPath], + // then all of its children have keys greater than the proof key, + // so we can start at the 0 token. + startingChildToken := 0 + + // If the deepest node has a key shorter than the key being proven, + // we can look at the next token index of the proof key to determine which of that + // node's children have keys larger than [proofKeyPath]. + // Any child with a token greater than the [proofKeyPath]'s token at that + // index will have a larger key. + if deepestNode.Key.Length() < proofKeyPath.Length() { + startingChildToken = int(proofKeyPath.Token(deepestNode.Key.Length(), c.tokenSize)) + 1 + } + + // determine if there are any differences in the children for the deepest unhandled node of the two proofs + if childIndex, hasDifference := findChildDifference(deepestNode, deepestNodeFromOtherProof, startingChildToken); hasDifference { + nextKey = maybe.Some(deepestNode.Key.Extend(merkledb.ToToken(childIndex, c.tokenSize)).Bytes()) + break + } + } + + // If the nextKey is before or equal to the [lastReceivedKey] + // then we couldn't find a better answer than the [lastReceivedKey]. + // Set the nextKey to [lastReceivedKey] + 0, which is the first key in + // the open range (lastReceivedKey, rangeEnd). 
+ if nextKey.HasValue() && bytes.Compare(nextKey.Value(), lastReceivedKey) <= 0 { + nextKeyVal := slices.Clone(lastReceivedKey) + nextKeyVal = append(nextKeyVal, 0) + nextKey = maybe.Some(nextKeyVal) + } + + // If the [nextKey] is larger than the end of the range, return Nothing to signal that there is no next key in range + if rangeEnd.HasValue() && bytes.Compare(nextKey.Value(), rangeEnd.Value()) >= 0 { + return maybe.Nothing[[]byte](), nil + } + + // the nextKey is within the open range (lastReceivedKey, rangeEnd), so return it + return nextKey, nil +} + +// findChildDifference returns the first child index that is different between node 1 and node 2 if one exists and +// a bool indicating if any difference was found +func findChildDifference(node1, node2 *merkledb.ProofNode, startIndex int) (byte, bool) { + // Children indices >= [startIndex] present in at least one of the nodes. + childIndices := set.Set[byte]{} + for _, node := range []*merkledb.ProofNode{node1, node2} { + if node == nil { + continue + } + for key := range node.Children { + if int(key) >= startIndex { + childIndices.Add(key) + } + } + } + + sortedChildIndices := maps.Keys(childIndices) + slices.Sort(sortedChildIndices) + var ( + child1, child2 ids.ID + ok1, ok2 bool + ) + for _, childIndex := range sortedChildIndices { + if node1 != nil { + child1, ok1 = node1.Children[childIndex] + } + if node2 != nil { + child2, ok2 = node2.Children[childIndex] + } + // if one node has a child and the other doesn't or the children ids don't match, + // return the current child index as the first difference + if (ok1 || ok2) && child1 != child2 { + return childIndex, true + } + } + // there were no differences found + return 0, false +} + +// Verify [rangeProof] is a valid range proof for keys in [start, end] for +// root [rootBytes]. Returns [errTooManyKeys] if the response contains more +// than [keyLimit] keys. 
+func verifyRangeProof( + ctx context.Context, + rangeProof *merkledb.RangeProof, + keyLimit int, + start maybe.Maybe[[]byte], + end maybe.Maybe[[]byte], + rootBytes []byte, + tokenSize int, + hasher merkledb.Hasher, +) error { + root, err := ids.ToID(rootBytes) + if err != nil { + return err + } + + // Ensure the response does not contain more than the maximum requested number of leaves. + if len(rangeProof.KeyChanges) > keyLimit { + return fmt.Errorf( + "%w: (%d) > %d)", + errTooManyKeys, len(rangeProof.KeyChanges), keyLimit, + ) + } + + if err := rangeProof.Verify( + ctx, + start, + end, + root, + tokenSize, + hasher, + ); err != nil { + return fmt.Errorf("%w due to %w", errInvalidRangeProof, err) + } + return nil +} diff --git a/x/sync/proof_test.go b/x/sync/proof_test.go new file mode 100644 index 000000000000..0d325d56f292 --- /dev/null +++ b/x/sync/proof_test.go @@ -0,0 +1,658 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package sync + +import ( + "bytes" + "context" + "math/rand" + "slices" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/network/p2p/p2ptest" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/maybe" + "github.com/ava-labs/avalanchego/x/merkledb" +) + +func Test_Sync_FindNextKey_InSync(t *testing.T) { + require := require.New(t) + + now := time.Now().UnixNano() + t.Logf("seed: %d", now) + r := rand.New(rand.NewSource(now)) // #nosec G404 + dbToSync, err := generateTrie(t, r, 1000) + require.NoError(err) + syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) + require.NoError(err) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + + ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + syncer, err := NewManager(ManagerConfig{ + ProofClient: proofClient, + RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), + TargetRoot: syncRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + require.NotNil(syncer) + + require.NoError(syncer.Start(context.Background())) + require.NoError(syncer.Wait(context.Background())) + + proof, err := dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 500) + require.NoError(err) + + // the two dbs should be in sync, so next key should be nil + lastKey := proof.KeyChanges[len(proof.KeyChanges)-1].Key + nextKey, err := 
proofClient.findNextKey(context.Background(), lastKey, maybe.Nothing[[]byte](), proof.EndProof) + require.NoError(err) + require.True(nextKey.IsNothing()) + + // add an extra value to sync db past the last key returned + newKey := midPoint(maybe.Some(lastKey), maybe.Nothing[[]byte]()) + newKeyVal := newKey.Value() + require.NoError(db.Put(newKeyVal, []byte{1})) + + // create a range endpoint that is before the newly added key, but after the last key + endPointBeforeNewKey := make([]byte, 0, 2) + for i := 0; i < len(newKeyVal); i++ { + endPointBeforeNewKey = append(endPointBeforeNewKey, newKeyVal[i]) + + // we need the new key to be after the last key + // don't subtract anything from the current byte if newkey and lastkey are equal + if lastKey[i] == newKeyVal[i] { + continue + } + + // if the first nibble is > 0, subtract "1" from it + if endPointBeforeNewKey[i] >= 16 { + endPointBeforeNewKey[i] -= 16 + break + } + // if the second nibble > 0, subtract 1 from it + if endPointBeforeNewKey[i] > 0 { + endPointBeforeNewKey[i] -= 1 + break + } + // both nibbles were 0, so move onto the next byte + } + + nextKey, err = proofClient.findNextKey(context.Background(), lastKey, maybe.Some(endPointBeforeNewKey), proof.EndProof) + require.NoError(err) + + // next key would be after the end of the range, so it returns Nothing instead + require.True(nextKey.IsNothing()) +} + +func Test_Sync_FindNextKey_Deleted(t *testing.T) { + require := require.New(t) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + require.NoError(db.Put([]byte{0x10}, []byte{1})) + require.NoError(db.Put([]byte{0x11, 0x11}, []byte{2})) + + syncRoot, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + + ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + _, err = NewManager(ManagerConfig{ + ProofClient: proofClient, + 
RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), + TargetRoot: syncRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + + // 0x12 was "deleted" and there should be no extra node in the proof since there was nothing with a common prefix + noExtraNodeProof, err := db.GetProof(context.Background(), []byte{0x12}) + require.NoError(err) + + // 0x11 was "deleted" and 0x11.0x11 should be in the exclusion proof + extraNodeProof, err := db.GetProof(context.Background(), []byte{0x11}) + require.NoError(err) + + // there is now another value in the range that needs to be sync'ed + require.NoError(db.Put([]byte{0x13}, []byte{3})) + + nextKey, err := proofClient.findNextKey(context.Background(), []byte{0x12}, maybe.Some([]byte{0x20}), noExtraNodeProof.Path) + require.NoError(err) + require.Equal(maybe.Some([]byte{0x13}), nextKey) + + nextKey, err = proofClient.findNextKey(context.Background(), []byte{0x11}, maybe.Some([]byte{0x20}), extraNodeProof.Path) + require.NoError(err) + require.Equal(maybe.Some([]byte{0x13}), nextKey) +} + +func Test_Sync_FindNextKey_BranchInLocal(t *testing.T) { + require := require.New(t) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + require.NoError(db.Put([]byte{0x11}, []byte{1})) + require.NoError(db.Put([]byte{0x11, 0x11}, []byte{2})) + + targetRoot, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + + proof, err := db.GetProof(context.Background(), []byte{0x11, 0x11}) + require.NoError(err) + + ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + _, err = NewManager(ManagerConfig{ + ProofClient: proofClient, + RangeProofClient: p2ptest.NewSelfClient(t, 
ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), + TargetRoot: targetRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + require.NoError(db.Put([]byte{0x11, 0x15}, []byte{4})) + + nextKey, err := proofClient.findNextKey(context.Background(), []byte{0x11, 0x11}, maybe.Some([]byte{0x20}), proof.Path) + require.NoError(err) + require.Equal(maybe.Some([]byte{0x11, 0x15}), nextKey) +} + +func Test_Sync_FindNextKey_BranchInReceived(t *testing.T) { + require := require.New(t) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + require.NoError(db.Put([]byte{0x11}, []byte{1})) + require.NoError(db.Put([]byte{0x12}, []byte{2})) + require.NoError(db.Put([]byte{0x12, 0xA0}, []byte{4})) + + targetRoot, err := db.GetMerkleRoot(context.Background()) + require.NoError(err) + + proof, err := db.GetProof(context.Background(), []byte{0x12}) + require.NoError(err) + + ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + _, err = NewManager(ManagerConfig{ + ProofClient: proofClient, + RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), + TargetRoot: targetRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + require.NoError(db.Delete([]byte{0x12, 0xA0})) + + nextKey, err := proofClient.findNextKey(context.Background(), []byte{0x12}, maybe.Some([]byte{0x20}), proof.Path) + require.NoError(err) + require.Equal(maybe.Some([]byte{0x12, 0xA0}), nextKey) +} + +func Test_Sync_FindNextKey_ExtraValues(t *testing.T) { + require := require.New(t) + + now := 
time.Now().UnixNano() + t.Logf("seed: %d", now) + r := rand.New(rand.NewSource(now)) // #nosec G404 + dbToSync, err := generateTrie(t, r, 1000) + require.NoError(err) + syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) + require.NoError(err) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + + ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + syncer, err := NewManager(ManagerConfig{ + ProofClient: proofClient, + RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), + TargetRoot: syncRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + require.NotNil(syncer) + + require.NoError(syncer.Start(context.Background())) + require.NoError(syncer.Wait(context.Background())) + + proof, err := dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 500) + require.NoError(err) + + // add an extra value to local db + lastKey := proof.KeyChanges[len(proof.KeyChanges)-1].Key + midpoint := midPoint(maybe.Some(lastKey), maybe.Nothing[[]byte]()) + midPointVal := midpoint.Value() + + require.NoError(db.Put(midPointVal, []byte{1})) + + // next key at prefix of newly added point + nextKey, err := proofClient.findNextKey(context.Background(), lastKey, maybe.Nothing[[]byte](), proof.EndProof) + require.NoError(err) + require.True(nextKey.HasValue()) + + require.True(isPrefix(midPointVal, nextKey.Value())) + + require.NoError(db.Delete(midPointVal)) + + require.NoError(dbToSync.Put(midPointVal, []byte{1})) + + proof, err = dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Some(lastKey), 500) + require.NoError(err) + + // next key at 
// isPrefix reports whether [prefix] is a prefix of [data], allowing for
// prefixes that end on a half byte.
//
// If the low nibble of the last byte of [prefix] is 0, only the high nibble
// of that byte is compared; all earlier bytes must match exactly. Otherwise
// this is a plain byte-wise prefix check.
//
// An empty [prefix] is a prefix of everything. A [data] shorter than
// [prefix] never matches.
func isPrefix(data []byte, prefix []byte) bool {
	if len(prefix) == 0 {
		return true
	}
	// Guard the indexing below: data must cover every byte of prefix.
	if len(data) < len(prefix) {
		return false
	}

	last := len(prefix) - 1
	if prefix[last]%16 != 0 {
		return bytes.HasPrefix(data, prefix)
	}
	// Odd-length key: match whole bytes, then only the high nibble of the last.
	return bytes.HasPrefix(data, prefix[:last]) && data[last]>>4 == prefix[last]>>4
}
// Test findNextKey by computing the expected result in a naive, inefficient
// way and comparing it to the actual result
//
// Each iteration requests a range proof from the remote db, commits it to the
// local db, and collects the (child key prefix, child ID) pairs listed in the
// remote end proof and in a local proof of the last received key. The
// expected next key is the smallest pair, after the last received key, at
// which the two sorted lists differ.
func TestFindNextKeyRandom(t *testing.T) {
	now := time.Now().UnixNano()
	t.Logf("seed: %d", now)
	// Shadows the math/rand package for the rest of this function.
	rand := rand.New(rand.NewSource(now)) // #nosec G404
	require := require.New(t)

	// Create a "remote" database and "local" database
	remoteDB, err := merkledb.New(
		context.Background(),
		memdb.New(),
		newDefaultDBConfig(),
	)
	require.NoError(err)

	// Kept in a variable because config.BranchFactor is read below to derive
	// the token size.
	config := newDefaultDBConfig()
	localDB, err := merkledb.New(
		context.Background(),
		memdb.New(),
		config,
	)
	require.NoError(err)

	var (
		numProofsToTest  = 250
		numKeyValues     = 250
		maxKeyLen        = 256
		maxValLen        = 256
		maxRangeStartLen = 8
		maxRangeEndLen   = 8
		maxProofLen      = 128
	)

	// Put random keys into the databases
	// Each database gets its own independent random key/value pairs.
	for _, db := range []database.Database{remoteDB, localDB} {
		for i := 0; i < numKeyValues; i++ {
			key := make([]byte, rand.Intn(maxKeyLen))
			_, _ = rand.Read(key)
			val := make([]byte, rand.Intn(maxValLen))
			_, _ = rand.Read(val)
			require.NoError(db.Put(key, val))
		}
	}

	// Repeatedly generate end proofs from the remote database and compare
	// the result of findNextKey to the expected result.
	for proofIndex := 0; proofIndex < numProofsToTest; proofIndex++ {
		// Generate a proof for a random key
		var (
			rangeStart []byte
			rangeEnd   []byte
		)
		// Generate a valid range start and end
		// Retry until rangeStart <= rangeEnd. Both are at least one byte long.
		for rangeStart == nil || bytes.Compare(rangeStart, rangeEnd) == 1 {
			rangeStart = make([]byte, rand.Intn(maxRangeStartLen)+1)
			_, _ = rand.Read(rangeStart)
			rangeEnd = make([]byte, rand.Intn(maxRangeEndLen)+1)
			_, _ = rand.Read(rangeEnd)
		}

		// Both lengths are >= 1 here, so startKey and endKey are always Some.
		startKey := maybe.Nothing[[]byte]()
		if len(rangeStart) > 0 {
			startKey = maybe.Some(rangeStart)
		}
		endKey := maybe.Nothing[[]byte]()
		if len(rangeEnd) > 0 {
			endKey = maybe.Some(rangeEnd)
		}

		remoteProof, err := remoteDB.GetRangeProof(
			context.Background(),
			startKey,
			endKey,
			rand.Intn(maxProofLen)+1,
		)
		require.NoError(err)

		// Without a received key there is nothing to call findNextKey with.
		if len(remoteProof.KeyChanges) == 0 {
			continue
		}
		lastReceivedKey := remoteProof.KeyChanges[len(remoteProof.KeyChanges)-1].Key

		// Commit the proof to the local database as we do
		// in the actual syncer.
		require.NoError(localDB.CommitRangeProof(
			context.Background(),
			startKey,
			endKey,
			remoteProof,
		))

		localProof, err := localDB.GetProof(
			context.Background(),
			lastReceivedKey,
		)
		require.NoError(err)

		// keyAndID pairs a child's key prefix with the child ID a proof node
		// lists for it.
		type keyAndID struct {
			key merkledb.Key
			id  ids.ID
		}

		// Set of key prefix/ID pairs proven by the remote database's end proof.
		remoteKeyIDs := []keyAndID{}
		for _, node := range remoteProof.EndProof {
			for childIdx, childID := range node.Children {
				remoteKeyIDs = append(remoteKeyIDs, keyAndID{
					key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])),
					id:  childID,
				})
			}
		}

		// Set of key prefix/ID pairs proven by the local database's proof.
		localKeyIDs := []keyAndID{}
		for _, node := range localProof.Path {
			for childIdx, childID := range node.Children {
				localKeyIDs = append(localKeyIDs, keyAndID{
					key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])),
					id:  childID,
				})
			}
		}

		// Sort in ascending order by key prefix.
		serializedPathCompare := func(i, j keyAndID) int {
			return i.key.Compare(j.key)
		}
		slices.SortFunc(remoteKeyIDs, serializedPathCompare)
		slices.SortFunc(localKeyIDs, serializedPathCompare)

		// Filter out keys that are before the last received key
		// findBounds returns the index of the first entry greater than
		// lastReceivedKey and the index of the first entry greater than
		// rangeEnd. Either result is len(keyIDs) if no such entry exists.
		// The entry that starts the range is not itself tested against
		// rangeEnd, because of the continue.
		findBounds := func(keyIDs []keyAndID) (int, int) {
			var (
				firstIdxInRange      = len(keyIDs)
				firstIdxInRangeFound = false
				firstIdxOutOfRange   = len(keyIDs)
			)
			for i, keyID := range keyIDs {
				if !firstIdxInRangeFound && bytes.Compare(keyID.key.Bytes(), lastReceivedKey) > 0 {
					firstIdxInRange = i
					firstIdxInRangeFound = true
					continue
				}
				if bytes.Compare(keyID.key.Bytes(), rangeEnd) > 0 {
					firstIdxOutOfRange = i
					break
				}
			}
			return firstIdxInRange, firstIdxOutOfRange
		}

		remoteFirstIdxAfterLastReceived, remoteFirstIdxAfterEnd := findBounds(remoteKeyIDs)
		remoteKeyIDs = remoteKeyIDs[remoteFirstIdxAfterLastReceived:remoteFirstIdxAfterEnd]

		localFirstIdxAfterLastReceived, localFirstIdxAfterEnd := findBounds(localKeyIDs)
		localKeyIDs = localKeyIDs[localFirstIdxAfterLastReceived:localFirstIdxAfterEnd]

		// Find smallest difference between the set of key/ID pairs proven by
		// the remote/local proofs for key/ID pairs after the last received key.
		var (
			smallestDiffKey merkledb.Key
			foundDiff       bool
		)
		for i := 0; i < len(remoteKeyIDs) && i < len(localKeyIDs); i++ {
			// See if the keys are different.
			// Order the pair so that [smaller] holds the lesser key.
			smaller, bigger := remoteKeyIDs[i], localKeyIDs[i]
			if serializedPathCompare(localKeyIDs[i], remoteKeyIDs[i]) == -1 {
				smaller, bigger = localKeyIDs[i], remoteKeyIDs[i]
			}

			if smaller.key != bigger.key || smaller.id != bigger.id {
				smallestDiffKey = smaller.key
				foundDiff = true
				break
			}
		}
		if !foundDiff {
			// All the keys were equal. The smallest diff is the next key
			// in the longer of the lists (if they're not same length.)
			// If the lists are the same length, smallestDiffKey stays at
			// its zero value.
			if len(remoteKeyIDs) < len(localKeyIDs) {
				smallestDiffKey = localKeyIDs[len(remoteKeyIDs)].key
			} else if len(remoteKeyIDs) > len(localKeyIDs) {
				smallestDiffKey = remoteKeyIDs[len(localKeyIDs)].key
			}
		}

		// Get the actual value from the syncer
		ctx := context.Background()
		proofClient, err := NewClient(localDB, &ClientConfig{
			BranchFactor: merkledb.BranchFactor16,
		})
		require.NoError(err)

		// The manager is built but never started; only its construction
		// error is checked.
		_, err = NewManager(ManagerConfig{
			ProofClient:           proofClient,
			RangeProofClient:      p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(remoteDB)),
			ChangeProofClient:     p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(remoteDB)),
			TargetRoot:            ids.GenerateTestID(),
			SimultaneousWorkLimit: 5,
			Log:                   logging.NoLog{},
		}, prometheus.NewRegistry())
		require.NoError(err)

		gotFirstDiff, err := proofClient.findNextKey(
			context.Background(),
			lastReceivedKey,
			endKey,
			remoteProof.EndProof,
		)
		require.NoError(err)

		if bytes.Compare(smallestDiffKey.Bytes(), rangeEnd) >= 0 {
			// The smallest key which differs is after the range end so the
			// next key to get should be nil because we're done fetching the range.
			require.True(gotFirstDiff.IsNothing())
		} else {
			require.Equal(smallestDiffKey.Bytes(), gotFirstDiff.Value())
		}
	}
}
Test_Sync_FindNextKey_InSync(t *testing.T) { - require := require.New(t) - - now := time.Now().UnixNano() - t.Logf("seed: %d", now) - r := rand.New(rand.NewSource(now)) // #nosec G404 - dbToSync, err := generateTrie(t, r, 1000) - require.NoError(err) - syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) - require.NoError(err) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), - TargetRoot: syncRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NotNil(syncer) - - require.NoError(syncer.Start(context.Background())) - require.NoError(syncer.Wait(context.Background())) - - proof, err := dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 500) - require.NoError(err) - - // the two dbs should be in sync, so next key should be nil - lastKey := proof.KeyChanges[len(proof.KeyChanges)-1].Key - nextKey, err := syncer.findNextKey(context.Background(), lastKey, maybe.Nothing[[]byte](), proof.EndProof) - require.NoError(err) - require.True(nextKey.IsNothing()) - - // add an extra value to sync db past the last key returned - newKey := midPoint(maybe.Some(lastKey), maybe.Nothing[[]byte]()) - newKeyVal := newKey.Value() - require.NoError(db.Put(newKeyVal, []byte{1})) - - // create a range endpoint that is before the newly added key, but after the last key - endPointBeforeNewKey := make([]byte, 0, 2) - for i := 0; i < len(newKeyVal); i++ { - endPointBeforeNewKey = append(endPointBeforeNewKey, newKeyVal[i]) - - // we need the new key to be after the last key - // 
don't subtract anything from the current byte if newkey and lastkey are equal - if lastKey[i] == newKeyVal[i] { - continue - } - - // if the first nibble is > 0, subtract "1" from it - if endPointBeforeNewKey[i] >= 16 { - endPointBeforeNewKey[i] -= 16 - break - } - // if the second nibble > 0, subtract 1 from it - if endPointBeforeNewKey[i] > 0 { - endPointBeforeNewKey[i] -= 1 - break - } - // both nibbles were 0, so move onto the next byte - } - - nextKey, err = syncer.findNextKey(context.Background(), lastKey, maybe.Some(endPointBeforeNewKey), proof.EndProof) - require.NoError(err) - - // next key would be after the end of the range, so it returns Nothing instead - require.True(nextKey.IsNothing()) -} - -func Test_Sync_FindNextKey_Deleted(t *testing.T) { - require := require.New(t) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - require.NoError(db.Put([]byte{0x10}, []byte{1})) - require.NoError(db.Put([]byte{0x11, 0x11}, []byte{2})) - - syncRoot, err := db.GetMerkleRoot(context.Background()) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), - TargetRoot: syncRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - - // 0x12 was "deleted" and there should be no extra node in the proof since there was nothing with a common prefix - noExtraNodeProof, err := db.GetProof(context.Background(), []byte{0x12}) - require.NoError(err) - - // 0x11 was "deleted" and 0x11.0x11 should be in the exclusion proof - extraNodeProof, err := db.GetProof(context.Background(), []byte{0x11}) - require.NoError(err) - - // there is now another value in the range that needs 
to be sync'ed - require.NoError(db.Put([]byte{0x13}, []byte{3})) - - nextKey, err := syncer.findNextKey(context.Background(), []byte{0x12}, maybe.Some([]byte{0x20}), noExtraNodeProof.Path) - require.NoError(err) - require.Equal(maybe.Some([]byte{0x13}), nextKey) - - nextKey, err = syncer.findNextKey(context.Background(), []byte{0x11}, maybe.Some([]byte{0x20}), extraNodeProof.Path) - require.NoError(err) - require.Equal(maybe.Some([]byte{0x13}), nextKey) -} - -func Test_Sync_FindNextKey_BranchInLocal(t *testing.T) { - require := require.New(t) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - require.NoError(db.Put([]byte{0x11}, []byte{1})) - require.NoError(db.Put([]byte{0x11, 0x11}, []byte{2})) - - targetRoot, err := db.GetMerkleRoot(context.Background()) - require.NoError(err) - - proof, err := db.GetProof(context.Background(), []byte{0x11, 0x11}) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), - TargetRoot: targetRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NoError(db.Put([]byte{0x11, 0x15}, []byte{4})) - - nextKey, err := syncer.findNextKey(context.Background(), []byte{0x11, 0x11}, maybe.Some([]byte{0x20}), proof.Path) - require.NoError(err) - require.Equal(maybe.Some([]byte{0x11, 0x15}), nextKey) -} - -func Test_Sync_FindNextKey_BranchInReceived(t *testing.T) { - require := require.New(t) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - require.NoError(db.Put([]byte{0x11}, []byte{1})) - require.NoError(db.Put([]byte{0x12}, []byte{2})) - 
require.NoError(db.Put([]byte{0x12, 0xA0}, []byte{4})) - - targetRoot, err := db.GetMerkleRoot(context.Background()) - require.NoError(err) - - proof, err := db.GetProof(context.Background(), []byte{0x12}) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), - TargetRoot: targetRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NoError(db.Delete([]byte{0x12, 0xA0})) - - nextKey, err := syncer.findNextKey(context.Background(), []byte{0x12}, maybe.Some([]byte{0x20}), proof.Path) - require.NoError(err) - require.Equal(maybe.Some([]byte{0x12, 0xA0}), nextKey) -} - -func Test_Sync_FindNextKey_ExtraValues(t *testing.T) { - require := require.New(t) - - now := time.Now().UnixNano() - t.Logf("seed: %d", now) - r := rand.New(rand.NewSource(now)) // #nosec G404 - dbToSync, err := generateTrie(t, r, 1000) - require.NoError(err) - syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) - require.NoError(err) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), - TargetRoot: syncRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NotNil(syncer) - - require.NoError(syncer.Start(context.Background())) - require.NoError(syncer.Wait(context.Background())) - - 
proof, err := dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 500) - require.NoError(err) - - // add an extra value to local db - lastKey := proof.KeyChanges[len(proof.KeyChanges)-1].Key - midpoint := midPoint(maybe.Some(lastKey), maybe.Nothing[[]byte]()) - midPointVal := midpoint.Value() - - require.NoError(db.Put(midPointVal, []byte{1})) - - // next key at prefix of newly added point - nextKey, err := syncer.findNextKey(context.Background(), lastKey, maybe.Nothing[[]byte](), proof.EndProof) - require.NoError(err) - require.True(nextKey.HasValue()) - - require.True(isPrefix(midPointVal, nextKey.Value())) - - require.NoError(db.Delete(midPointVal)) - - require.NoError(dbToSync.Put(midPointVal, []byte{1})) - - proof, err = dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Some(lastKey), 500) - require.NoError(err) - - // next key at prefix of newly added point - nextKey, err = syncer.findNextKey(context.Background(), lastKey, maybe.Nothing[[]byte](), proof.EndProof) - require.NoError(err) - require.True(nextKey.HasValue()) - - // deal with odd length key - require.True(isPrefix(midPointVal, nextKey.Value())) -} - -func Test_Sync_FindNextKey_IdenticalKeys(t *testing.T) { - require := require.New(t) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - testKeys := [][]byte{ - {0x10}, - {0x11, 0x11}, - {0x12, 0x34}, - {0x15}, - } - - for i, key := range testKeys { - value := []byte{byte(i + 1)} - require.NoError(db.Put(key, value)) - } - - targetRoot, err := db.GetMerkleRoot(context.Background()) - require.NoError(err) - - // Get the proof for the test key - testKey := []byte{0x11, 0x11} - proof, err := db.GetRangeProof(context.Background(), maybe.Some(testKey), maybe.Some(testKey), 1) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: 
p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), - TargetRoot: targetRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - - // Since both keys are identical, the next key should be nothing, since the range is complete - nextKey, err := syncer.findNextKey(context.Background(), testKey, maybe.Some([]byte{0x11, 0x11}), proof.EndProof) - require.NoError(err) - - require.Equal(maybe.Nothing[[]byte](), nextKey) -} - -func TestFindNextKeyEmptyEndProof(t *testing.T) { - require := require.New(t) - now := time.Now().UnixNano() - t.Logf("seed: %d", now) - r := rand.New(rand.NewSource(now)) // #nosec G404 - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), - TargetRoot: ids.Empty, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NotNil(syncer) - - for i := 0; i < 100; i++ { - lastReceivedKeyLen := r.Intn(16) - lastReceivedKey := make([]byte, lastReceivedKeyLen) - _, _ = r.Read(lastReceivedKey) // #nosec G404 - - rangeEndLen := r.Intn(16) - rangeEndBytes := make([]byte, rangeEndLen) - _, _ = r.Read(rangeEndBytes) // #nosec G404 - - rangeEnd := maybe.Nothing[[]byte]() - if rangeEndLen > 0 { - rangeEnd = maybe.Some(rangeEndBytes) - } - - nextKey, err := syncer.findNextKey( - context.Background(), - lastReceivedKey, - rangeEnd, - nil, /* endProof */ - ) - require.NoError(err) - 
require.Equal(maybe.Some(append(lastReceivedKey, 0)), nextKey) - } -} - -func isPrefix(data []byte, prefix []byte) bool { - if prefix[len(prefix)-1]%16 == 0 { - index := 0 - for ; index < len(prefix)-1; index++ { - if data[index] != prefix[index] { - return false - } - } - return data[index]>>4 == prefix[index]>>4 - } - return bytes.HasPrefix(data, prefix) -} - -func Test_Sync_FindNextKey_DifferentChild(t *testing.T) { - require := require.New(t) - - now := time.Now().UnixNano() - t.Logf("seed: %d", now) - r := rand.New(rand.NewSource(now)) // #nosec G404 - dbToSync, err := generateTrie(t, r, 500) - require.NoError(err) - syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) - require.NoError(err) - - db, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: db, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), - TargetRoot: syncRoot, - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - require.NotNil(syncer) - require.NoError(syncer.Start(context.Background())) - require.NoError(syncer.Wait(context.Background())) - - proof, err := dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Nothing[[]byte](), 100) - require.NoError(err) - lastKey := proof.KeyChanges[len(proof.KeyChanges)-1].Key - - // local db has a different child than remote db - lastKey = append(lastKey, 16) - require.NoError(db.Put(lastKey, []byte{1})) - - require.NoError(dbToSync.Put(lastKey, []byte{2})) - - proof, err = dbToSync.GetRangeProof(context.Background(), maybe.Nothing[[]byte](), maybe.Some(proof.KeyChanges[len(proof.KeyChanges)-1].Key), 100) - require.NoError(err) - - 
nextKey, err := syncer.findNextKey(context.Background(), proof.KeyChanges[len(proof.KeyChanges)-1].Key, maybe.Nothing[[]byte](), proof.EndProof) - require.NoError(err) - require.True(nextKey.HasValue()) - require.Equal(lastKey, nextKey.Value()) -} - -// Test findNextKey by computing the expected result in a naive, inefficient -// way and comparing it to the actual result -func TestFindNextKeyRandom(t *testing.T) { - now := time.Now().UnixNano() - t.Logf("seed: %d", now) - rand := rand.New(rand.NewSource(now)) // #nosec G404 - require := require.New(t) - - // Create a "remote" database and "local" database - remoteDB, err := merkledb.New( - context.Background(), - memdb.New(), - newDefaultDBConfig(), - ) - require.NoError(err) - - config := newDefaultDBConfig() - localDB, err := merkledb.New( - context.Background(), - memdb.New(), - config, - ) - require.NoError(err) - - var ( - numProofsToTest = 250 - numKeyValues = 250 - maxKeyLen = 256 - maxValLen = 256 - maxRangeStartLen = 8 - maxRangeEndLen = 8 - maxProofLen = 128 - ) - - // Put random keys into the databases - for _, db := range []database.Database{remoteDB, localDB} { - for i := 0; i < numKeyValues; i++ { - key := make([]byte, rand.Intn(maxKeyLen)) - _, _ = rand.Read(key) - val := make([]byte, rand.Intn(maxValLen)) - _, _ = rand.Read(val) - require.NoError(db.Put(key, val)) - } - } - - // Repeatedly generate end proofs from the remote database and compare - // the result of findNextKey to the expected result. 
- for proofIndex := 0; proofIndex < numProofsToTest; proofIndex++ { - // Generate a proof for a random key - var ( - rangeStart []byte - rangeEnd []byte - ) - // Generate a valid range start and end - for rangeStart == nil || bytes.Compare(rangeStart, rangeEnd) == 1 { - rangeStart = make([]byte, rand.Intn(maxRangeStartLen)+1) - _, _ = rand.Read(rangeStart) - rangeEnd = make([]byte, rand.Intn(maxRangeEndLen)+1) - _, _ = rand.Read(rangeEnd) - } - - startKey := maybe.Nothing[[]byte]() - if len(rangeStart) > 0 { - startKey = maybe.Some(rangeStart) - } - endKey := maybe.Nothing[[]byte]() - if len(rangeEnd) > 0 { - endKey = maybe.Some(rangeEnd) - } - - remoteProof, err := remoteDB.GetRangeProof( - context.Background(), - startKey, - endKey, - rand.Intn(maxProofLen)+1, - ) - require.NoError(err) - - if len(remoteProof.KeyChanges) == 0 { - continue - } - lastReceivedKey := remoteProof.KeyChanges[len(remoteProof.KeyChanges)-1].Key - - // Commit the proof to the local database as we do - // in the actual syncer. - require.NoError(localDB.CommitRangeProof( - context.Background(), - startKey, - endKey, - remoteProof, - )) - - localProof, err := localDB.GetProof( - context.Background(), - lastReceivedKey, - ) - require.NoError(err) - - type keyAndID struct { - key merkledb.Key - id ids.ID - } - - // Set of key prefix/ID pairs proven by the remote database's end proof. - remoteKeyIDs := []keyAndID{} - for _, node := range remoteProof.EndProof { - for childIdx, childID := range node.Children { - remoteKeyIDs = append(remoteKeyIDs, keyAndID{ - key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])), - id: childID, - }) - } - } - - // Set of key prefix/ID pairs proven by the local database's proof. 
- localKeyIDs := []keyAndID{} - for _, node := range localProof.Path { - for childIdx, childID := range node.Children { - localKeyIDs = append(localKeyIDs, keyAndID{ - key: node.Key.Extend(merkledb.ToToken(childIdx, merkledb.BranchFactorToTokenSize[config.BranchFactor])), - id: childID, - }) - } - } - - // Sort in ascending order by key prefix. - serializedPathCompare := func(i, j keyAndID) int { - return i.key.Compare(j.key) - } - slices.SortFunc(remoteKeyIDs, serializedPathCompare) - slices.SortFunc(localKeyIDs, serializedPathCompare) - - // Filter out keys that are before the last received key - findBounds := func(keyIDs []keyAndID) (int, int) { - var ( - firstIdxInRange = len(keyIDs) - firstIdxInRangeFound = false - firstIdxOutOfRange = len(keyIDs) - ) - for i, keyID := range keyIDs { - if !firstIdxInRangeFound && bytes.Compare(keyID.key.Bytes(), lastReceivedKey) > 0 { - firstIdxInRange = i - firstIdxInRangeFound = true - continue - } - if bytes.Compare(keyID.key.Bytes(), rangeEnd) > 0 { - firstIdxOutOfRange = i - break - } - } - return firstIdxInRange, firstIdxOutOfRange - } - - remoteFirstIdxAfterLastReceived, remoteFirstIdxAfterEnd := findBounds(remoteKeyIDs) - remoteKeyIDs = remoteKeyIDs[remoteFirstIdxAfterLastReceived:remoteFirstIdxAfterEnd] - - localFirstIdxAfterLastReceived, localFirstIdxAfterEnd := findBounds(localKeyIDs) - localKeyIDs = localKeyIDs[localFirstIdxAfterLastReceived:localFirstIdxAfterEnd] - - // Find smallest difference between the set of key/ID pairs proven by - // the remote/local proofs for key/ID pairs after the last received key. - var ( - smallestDiffKey merkledb.Key - foundDiff bool - ) - for i := 0; i < len(remoteKeyIDs) && i < len(localKeyIDs); i++ { - // See if the keys are different. 
- smaller, bigger := remoteKeyIDs[i], localKeyIDs[i] - if serializedPathCompare(localKeyIDs[i], remoteKeyIDs[i]) == -1 { - smaller, bigger = localKeyIDs[i], remoteKeyIDs[i] - } - - if smaller.key != bigger.key || smaller.id != bigger.id { - smallestDiffKey = smaller.key - foundDiff = true - break - } - } - if !foundDiff { - // All the keys were equal. The smallest diff is the next key - // in the longer of the lists (if they're not same length.) - if len(remoteKeyIDs) < len(localKeyIDs) { - smallestDiffKey = localKeyIDs[len(remoteKeyIDs)].key - } else if len(remoteKeyIDs) > len(localKeyIDs) { - smallestDiffKey = remoteKeyIDs[len(localKeyIDs)].key - } - } - - // Get the actual value from the syncer - ctx := context.Background() - syncer, err := NewManager(ManagerConfig{ - DB: localDB, - RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(remoteDB)), - ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(remoteDB)), - TargetRoot: ids.GenerateTestID(), - SimultaneousWorkLimit: 5, - Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, - }, prometheus.NewRegistry()) - require.NoError(err) - - gotFirstDiff, err := syncer.findNextKey( - context.Background(), - lastReceivedKey, - endKey, - remoteProof.EndProof, - ) - require.NoError(err) - - if bytes.Compare(smallestDiffKey.Bytes(), rangeEnd) >= 0 { - // The smallest key which differs is after the range end so the - // next key to get should be nil because we're done fetching the range. 
- require.True(gotFirstDiff.IsNothing()) - } else { - require.Equal(smallestDiffKey.Bytes(), gotFirstDiff.Value()) - } - } -} - // Tests that we are able to sync to the correct root while the server is // updating func Test_Sync_Result_Correct_Root(t *testing.T) { @@ -1006,14 +363,18 @@ func Test_Sync_Result_Correct_Root(t *testing.T) { changeProofClient = tt.changeProofClient(dbToSync) } + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + syncer, err := NewManager(ManagerConfig{ - DB: db, + ProofClient: proofClient, RangeProofClient: rangeProofClient, ChangeProofClient: changeProofClient, TargetRoot: syncRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, }, prometheus.NewRegistry()) require.NoError(err) @@ -1078,14 +439,18 @@ func Test_Sync_Result_Correct_Root_With_Sync_Restart(t *testing.T) { require.NoError(err) ctx := context.Background() + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + syncer, err := NewManager(ManagerConfig{ - DB: db, + ProofClient: proofClient, RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), TargetRoot: syncRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, }, prometheus.NewRegistry()) require.NoError(err) require.NotNil(syncer) @@ -1105,14 +470,18 @@ func Test_Sync_Result_Correct_Root_With_Sync_Restart(t *testing.T) { ) syncer.Close() + newProofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + newSyncer, err := NewManager(ManagerConfig{ - DB: db, + ProofClient: newProofClient, RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), ChangeProofClient: 
p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), TargetRoot: syncRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, }, prometheus.NewRegistry()) require.NoError(err) require.NotNil(newSyncer) @@ -1189,14 +558,18 @@ func Test_Sync_Result_Correct_Root_Update_Root_During(t *testing.T) { updatedRootChan: updatedRootChan, }) + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + syncer, err := NewManager(ManagerConfig{ - DB: db, + ProofClient: proofClient, RangeProofClient: rangeProofClient, ChangeProofClient: changeProofClient, TargetRoot: firstSyncRoot, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, }, prometheus.NewRegistry()) require.NoError(err) require.NotNil(syncer) @@ -1236,14 +609,19 @@ func Test_Sync_UpdateSyncTarget(t *testing.T) { ) require.NoError(err) ctx := context.Background() + + proofClient, err := NewClient(db, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + m, err := NewManager(ManagerConfig{ - DB: db, + ProofClient: proofClient, RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(db)), ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(db)), TargetRoot: ids.Empty, SimultaneousWorkLimit: 5, Log: logging.NoLog{}, - BranchFactor: merkledb.BranchFactor16, }, prometheus.NewRegistry()) require.NoError(err) From 389b0bf6f2fc6387df6c1ecd441b2587b56ff790 Mon Sep 17 00:00:00 2001 From: Austin Larson Date: Thu, 17 Jul 2025 10:33:41 -0400 Subject: [PATCH 3/5] refactor: remove callbacks --- x/sync/db.go | 7 ++--- x/sync/manager.go | 29 ++++++++++++-------- x/sync/merkledb_client.go | 58 +++++++++++++++++++++------------------ 3 files changed, 53 insertions(+), 41 deletions(-) diff --git a/x/sync/db.go b/x/sync/db.go index 1d30c5944d5b..f9799694c82e 100644 --- 
a/x/sync/db.go +++ b/x/sync/db.go @@ -23,12 +23,11 @@ type DB interface { type ProofClient interface { merkledb.Clearer merkledb.MerkleRootGetter - HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte, onFinish func(maybe.Maybe[[]byte])) error + HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte) (maybe.Maybe[[]byte], error) HandleChangeProofResponse( ctx context.Context, request *pb.SyncGetChangeProofRequest, responseBytes []byte, - onFinish func(maybe.Maybe[[]byte]), - ) error - RegisterErrorHandler(handler func(error)) + ) (maybe.Maybe[[]byte], error) + Error() error } diff --git a/x/sync/manager.go b/x/sync/manager.go index 6ddee2578bfa..7e757924dfdd 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -174,9 +174,6 @@ func NewManager(config ManagerConfig, registerer prometheus.Registerer) (*Manage metrics: metrics, } m.unprocessedWorkCond.L = &m.workLock - m.config.ProofClient.RegisterErrorHandler(func(err error) { - m.setError(err) - }) return m, nil } @@ -219,6 +216,10 @@ func (m *Manager) sync(ctx context.Context) { switch { case ctx.Err() != nil: return // [m.workLock] released by defer. + case m.config.ProofClient.Error() != nil: + // If the proof client has an error, we can't continue. + m.setError(m.config.ProofClient.Error()) + return case m.processingWorkItems >= m.config.SimultaneousWorkLimit: // We're already processing the maximum number of work items. // Wait until one of them finishes. 
@@ -498,14 +499,17 @@ func (m *Manager) handleRangeProofResponse( return err } - return m.config.ProofClient.HandleRangeProofResponse( + nextKey, err := m.config.ProofClient.HandleRangeProofResponse( ctx, request, responseBytes, - func(largestHandledKey maybe.Maybe[[]byte]) { - m.completeWorkItem(work, largestHandledKey, targetRootID) - }, ) + if err != nil { + return err + } + + m.completeWorkItem(work, nextKey, targetRootID) + return nil } func (m *Manager) handleChangeProofResponse( @@ -520,14 +524,17 @@ func (m *Manager) handleChangeProofResponse( return err } - return m.config.ProofClient.HandleChangeProofResponse( + nextKey, err := m.config.ProofClient.HandleChangeProofResponse( ctx, request, responseBytes, - func(largestHandledKey maybe.Maybe[[]byte]) { - m.completeWorkItem(work, largestHandledKey, targetRootID) - }, ) + if err != nil { + return err + } + + m.completeWorkItem(work, nextKey, targetRootID) + return nil } func (m *Manager) Error() error { diff --git a/x/sync/merkledb_client.go b/x/sync/merkledb_client.go index 76bddd72d1a2..b3d23aa81fa1 100644 --- a/x/sync/merkledb_client.go +++ b/x/sync/merkledb_client.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "slices" + "sync" "golang.org/x/exp/maps" "google.golang.org/protobuf/proto" @@ -31,7 +32,8 @@ type xSyncClient struct { db DB config *ClientConfig tokenSize int - setError func(error) + err error + errorOnce sync.Once } func NewClient(db DB, config *ClientConfig) (*xSyncClient, error) { @@ -50,8 +52,14 @@ func NewClient(db DB, config *ClientConfig) (*xSyncClient, error) { }, nil } -func (c *xSyncClient) RegisterErrorHandler(handler func(error)) { - c.setError = handler +func (c *xSyncClient) Error() error { + return c.err +} + +func (c *xSyncClient) setError(err error) { + c.errorOnce.Do(func() { + c.err = err + }) } func (c *xSyncClient) Clear() error { @@ -62,15 +70,15 @@ func (c *xSyncClient) GetMerkleRoot(ctx context.Context) (ids.ID, error) { return c.db.GetMerkleRoot(ctx) } -func (c 
*xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte, onFinish func(maybe.Maybe[[]byte])) error { +func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte) (maybe.Maybe[[]byte], error) { var rangeProofProto pb.RangeProof if err := proto.Unmarshal(responseBytes, &rangeProofProto); err != nil { - return err + return maybe.Nothing[[]byte](), err } var rangeProof merkledb.RangeProof if err := rangeProof.UnmarshalProto(&rangeProofProto); err != nil { - return err + return maybe.Nothing[[]byte](), err } start := maybeBytesToMaybe(request.StartKey) @@ -86,7 +94,7 @@ func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb. c.tokenSize, c.config.Hasher, ); err != nil { - return err + return maybe.Nothing[[]byte](), err } largestHandledKey := end @@ -94,7 +102,7 @@ func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb. // Replace all the key-value pairs in the DB from start to end with values from the response. if err := c.db.CommitRangeProof(ctx, start, end, &rangeProof); err != nil { c.setError(err) - return nil + return maybe.Nothing[[]byte](), err } if len(rangeProof.KeyChanges) > 0 { @@ -107,24 +115,22 @@ func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb. 
nextKey, err := c.findNextKey(ctx, largestHandledKey.Value(), end, rangeProof.EndProof) if err != nil { c.setError(err) - return nil + return maybe.Nothing[[]byte](), err } largestHandledKey = nextKey } - onFinish(largestHandledKey) - return nil + return largestHandledKey, nil } func (c *xSyncClient) HandleChangeProofResponse( ctx context.Context, request *pb.SyncGetChangeProofRequest, responseBytes []byte, - onFinish func(maybe.Maybe[[]byte]), -) error { +) (maybe.Maybe[[]byte], error) { var changeProofResp pb.SyncGetChangeProofResponse if err := proto.Unmarshal(responseBytes, &changeProofResp); err != nil { - return err + return maybe.Nothing[[]byte](), err } startKey := maybeBytesToMaybe(request.StartKey) @@ -139,13 +145,13 @@ func (c *xSyncClient) HandleChangeProofResponse( // The server had enough history to send us a change proof var changeProof merkledb.ChangeProof if err := changeProof.UnmarshalProto(changeProofResp.ChangeProof); err != nil { - return err + return maybe.Nothing[[]byte](), err } // Ensure the response does not contain more than the requested number of leaves // and the start and end roots match the requested roots. 
if len(changeProof.KeyChanges) > int(request.KeyLimit) { - return fmt.Errorf( + return maybe.Nothing[[]byte](), fmt.Errorf( "%w: (%d) > %d)", errTooManyKeys, len(changeProof.KeyChanges), request.KeyLimit, ) @@ -153,7 +159,7 @@ func (c *xSyncClient) HandleChangeProofResponse( endRoot, err := ids.ToID(request.EndRootHash) if err != nil { - return err + return maybe.Nothing[[]byte](), err } if err := c.db.VerifyChangeProof( @@ -163,7 +169,7 @@ func (c *xSyncClient) HandleChangeProofResponse( endKey, endRoot, ); err != nil { - return fmt.Errorf("%w due to %w", errInvalidChangeProof, err) + return maybe.Nothing[[]byte](), fmt.Errorf("%w due to %w", errInvalidChangeProof, err) } largestHandledKey = endKey @@ -171,7 +177,7 @@ func (c *xSyncClient) HandleChangeProofResponse( if len(changeProof.KeyChanges) > 0 { if err := c.db.CommitChangeProof(ctx, &changeProof); err != nil { c.setError(err) - return nil + return maybe.Nothing[[]byte](), err } largestHandledKey = maybe.Some(changeProof.KeyChanges[len(changeProof.KeyChanges)-1].Key) } @@ -180,7 +186,7 @@ func (c *xSyncClient) HandleChangeProofResponse( case *pb.SyncGetChangeProofResponse_RangeProof: var rangeProof merkledb.RangeProof if err := rangeProof.UnmarshalProto(changeProofResp.RangeProof); err != nil { - return err + return maybe.Nothing[[]byte](), err } // The server did not have enough history to send us a change proof @@ -195,7 +201,7 @@ func (c *xSyncClient) HandleChangeProofResponse( c.tokenSize, c.config.Hasher, ); err != nil { - return err + return maybe.Nothing[[]byte](), err } largestHandledKey = endKey @@ -203,14 +209,14 @@ func (c *xSyncClient) HandleChangeProofResponse( // Add all the key-value pairs we got to the database. 
if err := c.db.CommitRangeProof(ctx, startKey, endKey, &rangeProof); err != nil { c.setError(err) - return nil + return maybe.Nothing[[]byte](), err } largestHandledKey = maybe.Some(rangeProof.KeyChanges[len(rangeProof.KeyChanges)-1].Key) } endProof = rangeProof.EndProof default: - return fmt.Errorf( + return maybe.Nothing[[]byte](), fmt.Errorf( "%w: %T", errUnexpectedChangeProofResponse, changeProofResp, ) @@ -221,12 +227,12 @@ func (c *xSyncClient) HandleChangeProofResponse( nextKey, err := c.findNextKey(ctx, largestHandledKey.Value(), endKey, endProof) if err != nil { c.setError(err) - return nil + return maybe.Nothing[[]byte](), err } largestHandledKey = nextKey } - onFinish(largestHandledKey) - return nil + + return largestHandledKey, nil } // findNextKey returns the start of the key range that should be fetched next From 2fb7fbd635fe97b5bde3be87333909982582a614 Mon Sep 17 00:00:00 2001 From: Austin Larson Date: Thu, 17 Jul 2025 11:16:48 -0400 Subject: [PATCH 4/5] refactor: naming --- x/sync/db.go | 16 ++++------------ x/sync/manager.go | 4 ++-- x/sync/merkledb_client.go | 30 +++++++++++++++++++----------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/x/sync/db.go b/x/sync/db.go index f9799694c82e..7a7c61d2109f 100644 --- a/x/sync/db.go +++ b/x/sync/db.go @@ -6,23 +6,14 @@ package sync import ( "context" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/maybe" - "github.com/ava-labs/avalanchego/x/merkledb" pb "github.com/ava-labs/avalanchego/proto/pb/sync" ) -type DB interface { - merkledb.Clearer - merkledb.MerkleRootGetter - merkledb.ProofGetter - merkledb.ChangeProofer - merkledb.RangeProofer -} - -type ProofClient interface { - merkledb.Clearer - merkledb.MerkleRootGetter +type DBSyncClient interface { + GetRootHash(ctx context.Context) (ids.ID, error) HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte) (maybe.Maybe[[]byte], error) 
HandleChangeProofResponse( ctx context.Context, @@ -30,4 +21,5 @@ type ProofClient interface { responseBytes []byte, ) (maybe.Maybe[[]byte], error) Error() error + Clear() error } diff --git a/x/sync/manager.go b/x/sync/manager.go index 7e757924dfdd..09fcf82e03e1 100644 --- a/x/sync/manager.go +++ b/x/sync/manager.go @@ -138,7 +138,7 @@ type Manager struct { // TODO remove non-config values out of this struct type ManagerConfig struct { - ProofClient ProofClient + ProofClient DBSyncClient RangeProofClient *p2p.Client ChangeProofClient *p2p.Client SimultaneousWorkLimit int @@ -561,7 +561,7 @@ func (m *Manager) Wait(ctx context.Context) error { return err } - root, err := m.config.ProofClient.GetMerkleRoot(ctx) + root, err := m.config.ProofClient.GetRootHash(ctx) if err != nil { return err } diff --git a/x/sync/merkledb_client.go b/x/sync/merkledb_client.go index b3d23aa81fa1..5784f89cfe12 100644 --- a/x/sync/merkledb_client.go +++ b/x/sync/merkledb_client.go @@ -21,14 +21,22 @@ import ( pb "github.com/ava-labs/avalanchego/proto/pb/sync" ) -var _ ProofClient = (*xSyncClient)(nil) +var _ DBSyncClient = (*merkleDBSyncClient)(nil) + +type DB interface { + merkledb.Clearer + merkledb.MerkleRootGetter + merkledb.ProofGetter + merkledb.ChangeProofer + merkledb.RangeProofer +} type ClientConfig struct { Hasher merkledb.Hasher BranchFactor merkledb.BranchFactor } -type xSyncClient struct { +type merkleDBSyncClient struct { db DB config *ClientConfig tokenSize int @@ -36,7 +44,7 @@ type xSyncClient struct { errorOnce sync.Once } -func NewClient(db DB, config *ClientConfig) (*xSyncClient, error) { +func NewClient(db DB, config *ClientConfig) (*merkleDBSyncClient, error) { if err := config.BranchFactor.Valid(); err != nil { return nil, err } @@ -45,32 +53,32 @@ func NewClient(db DB, config *ClientConfig) (*xSyncClient, error) { config.Hasher = merkledb.DefaultHasher } - return &xSyncClient{ + return &merkleDBSyncClient{ db: db, config: config, tokenSize: 
merkledb.BranchFactorToTokenSize[config.BranchFactor], }, nil } -func (c *xSyncClient) Error() error { +func (c *merkleDBSyncClient) Error() error { return c.err } -func (c *xSyncClient) setError(err error) { +func (c *merkleDBSyncClient) setError(err error) { c.errorOnce.Do(func() { c.err = err }) } -func (c *xSyncClient) Clear() error { +func (c *merkleDBSyncClient) Clear() error { return c.db.Clear() } -func (c *xSyncClient) GetMerkleRoot(ctx context.Context) (ids.ID, error) { +func (c *merkleDBSyncClient) GetRootHash(ctx context.Context) (ids.ID, error) { return c.db.GetMerkleRoot(ctx) } -func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte) (maybe.Maybe[[]byte], error) { +func (c *merkleDBSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb.SyncGetRangeProofRequest, responseBytes []byte) (maybe.Maybe[[]byte], error) { var rangeProofProto pb.RangeProof if err := proto.Unmarshal(responseBytes, &rangeProofProto); err != nil { return maybe.Nothing[[]byte](), err @@ -123,7 +131,7 @@ func (c *xSyncClient) HandleRangeProofResponse(ctx context.Context, request *pb. return largestHandledKey, nil } -func (c *xSyncClient) HandleChangeProofResponse( +func (c *merkleDBSyncClient) HandleChangeProofResponse( ctx context.Context, request *pb.SyncGetChangeProofRequest, responseBytes []byte, @@ -247,7 +255,7 @@ func (c *xSyncClient) HandleChangeProofResponse( // // Invariant: [lastReceivedKey] < [rangeEnd]. // If [rangeEnd] is Nothing it's considered > [lastReceivedKey]. 
-func (c *xSyncClient) findNextKey( +func (c *merkleDBSyncClient) findNextKey( ctx context.Context, lastReceivedKey []byte, rangeEnd maybe.Maybe[[]byte], From 412c512e711824cd260edfe177c9b786ffcd7b79 Mon Sep 17 00:00:00 2001 From: Austin Larson Date: Thu, 17 Jul 2025 12:21:53 -0400 Subject: [PATCH 5/5] test: ensure db errors work --- x/sync/merkledb_client.go | 8 +++-- x/sync/sync_test.go | 65 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/x/sync/merkledb_client.go b/x/sync/merkledb_client.go index 5784f89cfe12..8ee5fc7ee4aa 100644 --- a/x/sync/merkledb_client.go +++ b/x/sync/merkledb_client.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/proto" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/maybe" "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/x/merkledb" @@ -40,7 +41,7 @@ type merkleDBSyncClient struct { db DB config *ClientConfig tokenSize int - err error + err *utils.Atomic[error] errorOnce sync.Once } @@ -57,16 +58,17 @@ func NewClient(db DB, config *ClientConfig) (*merkleDBSyncClient, error) { db: db, config: config, tokenSize: merkledb.BranchFactorToTokenSize[config.BranchFactor], + err: utils.NewAtomic[error](nil), }, nil } func (c *merkleDBSyncClient) Error() error { - return c.err + return c.err.Get() } func (c *merkleDBSyncClient) setError(err error) { c.errorOnce.Do(func() { - c.err = err + c.err.Set(err) }) } diff --git a/x/sync/sync_test.go b/x/sync/sync_test.go index d22a47b4e641..039a32929444 100644 --- a/x/sync/sync_test.go +++ b/x/sync/sync_test.go @@ -6,6 +6,7 @@ package sync import ( "bytes" "context" + "errors" "math/rand" "slices" "testing" @@ -23,7 +24,12 @@ import ( "github.com/ava-labs/avalanchego/x/merkledb" ) -var _ p2p.Handler = (*waitingHandler)(nil) +var ( + _ p2p.Handler = (*waitingHandler)(nil) + + // Used to simulate a corrupted DB + errorFoo = errors.New("mock 
error") +) func Test_Creation(t *testing.T) { require := require.New(t) @@ -657,6 +663,63 @@ func Test_Sync_UpdateSyncTarget(t *testing.T) { require.Equal(1, m.unprocessedWork.Len()) } +func Test_Sync_DBError(t *testing.T) { + require := require.New(t) + + now := time.Now().UnixNano() + r := rand.New(rand.NewSource(now)) // #nosec G404 + dbToSync, err := generateTrie(t, r, maxKeyValuesLimit) + require.NoError(err) + syncRoot, err := dbToSync.GetMerkleRoot(context.Background()) + require.NoError(err) + + db, err := merkledb.New( + context.Background(), + memdb.New(), + newDefaultDBConfig(), + ) + require.NoError(err) + + // This client DB will return an error when it tries to commit a range proof. + badDB := &badMerkleDB{MerkleDB: db} + + proofClient, err := NewClient(badDB, &ClientConfig{ + BranchFactor: merkledb.BranchFactor16, + }) + require.NoError(err) + + ctx := context.Background() + syncer, err := NewManager(ManagerConfig{ + ProofClient: proofClient, + RangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetRangeProofHandler(dbToSync)), + ChangeProofClient: p2ptest.NewSelfClient(t, ctx, ids.EmptyNodeID, NewGetChangeProofHandler(dbToSync)), + TargetRoot: syncRoot, + SimultaneousWorkLimit: 5, + Log: logging.NoLog{}, + }, prometheus.NewRegistry()) + require.NoError(err) + require.NotNil(syncer) + require.NoError(syncer.Start(context.Background())) + err = syncer.Wait(context.Background()) + require.ErrorIs(err, errorFoo) +} + +var _ DB = (*badMerkleDB)(nil) + +type badMerkleDB struct { + merkledb.MerkleDB +} + +func (*badMerkleDB) CommitRangeProof( + _ context.Context, + _ maybe.Maybe[[]byte], + _ maybe.Maybe[[]byte], + _ *merkledb.RangeProof, +) error { + // Simulate a bad response by returning an error. + return errorFoo +} + func generateTrie(t *testing.T, r *rand.Rand, count int) (merkledb.MerkleDB, error) { return generateTrieWithMinKeyLen(t, r, count, 0) }