From e9f59787f226e590d0b474171b3eea1871dbff02 Mon Sep 17 00:00:00 2001 From: Thomas Tendyck Date: Tue, 1 Oct 2024 13:44:03 +0200 Subject: [PATCH 1/2] add monotonic counter interface for rollback protection --- db.go | 5 +- edg.go | 60 ++++++++++ internal/edg/edg_test.go | 233 ++++++++++++++++++++++++++++++++++++++- open.go | 12 ++ options.go | 11 ++ transaction.go | 35 ++++++ 6 files changed, 349 insertions(+), 7 deletions(-) create mode 100644 edg.go diff --git a/db.go b/db.go index 5a84bc11..8e812d30 100644 --- a/db.go +++ b/db.go @@ -501,8 +501,9 @@ type DB struct { // compaction concurrency openedAt time.Time - keyManager *edg.KeyManager - txLock sync.Mutex + keyManager *edg.KeyManager + txLock sync.Mutex + monotonicCounter uint64 } var _ Reader = (*DB)(nil) diff --git a/edg.go b/edg.go new file mode 100644 index 00000000..e2afddf6 --- /dev/null +++ b/edg.go @@ -0,0 +1,60 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package estore + +import ( + "encoding/binary" + + "github.com/cockroachdb/errors" +) + +var edgMonotonicCounterKey = []byte("!EDGELESS_MONOTONIC_COUNTER") + +func (d *DB) edgGetMonotonicCounterFromStore() (uint64, error) { + value, closer, err := d.Get(edgMonotonicCounterKey) + if errors.Is(err, ErrNotFound) { + return 0, nil + } + if err != nil { + return 0, err + } + defer closer.Close() + return binary.LittleEndian.Uint64(value), nil +} + +func (d *DB) edgSetMonotonicCounterOnStore(value uint64) error { + return d.Set(edgMonotonicCounterKey, binary.LittleEndian.AppendUint64(nil, value), nil) +} + +func (d *DB) edgVerifyFreshness() error { + if d.opts.SetMonotonicCounter == nil { + return nil + } + + // get counter from trusted source + sourceCount, err := d.opts.SetMonotonicCounter(0) + if err != nil { + return errors.Wrap(err, "getting monotonic counter from trusted source") + } + + // get counter from store + storeCount, err := d.edgGetMonotonicCounterFromStore() + if err != nil { + return 
errors.Wrap(err, "getting monotonic counter from store") + } + + if storeCount < sourceCount { + return errors.Newf("rollback detected: store counter: %v, trusted source counter: %v", storeCount, sourceCount) + } + if storeCount > sourceCount { + d.opts.Logger.Infof("WARNING: open: monotonic counter source lags behind: store counter: %v, source counter: %v", storeCount, sourceCount) + // will be synced on next tx commit + } + + d.monotonicCounter = storeCount + return nil +} diff --git a/internal/edg/edg_test.go b/internal/edg/edg_test.go index 7d9a014c..1485b8e6 100644 --- a/internal/edg/edg_test.go +++ b/internal/edg/edg_test.go @@ -12,6 +12,7 @@ import ( "math" "os" "strings" + "sync" "testing" "github.com/edgelesssys/estore" @@ -55,9 +56,10 @@ func TestConfidentiality(t *testing.T) { fs := vfs.NewMem() db, err := estore.Open("", &estore.Options{ - EncryptionKey: testKey(), - FS: fs, - Levels: []estore.LevelOptions{{Compression: estore.NoCompression}}, + EncryptionKey: testKey(), + SetMonotonicCounter: (&fakeCounter{}).set, + FS: fs, + Levels: []estore.LevelOptions{{Compression: estore.NoCompression}}, }) require.NoError(err) @@ -239,8 +241,9 @@ func TestOldDB(t *testing.T) { require := require.New(t) opts := &estore.Options{ - EncryptionKey: testKey(), - FS: vfs.NewMem(), + EncryptionKey: testKey(), + SetMonotonicCounter: (&fakeCounter{}).set, + FS: vfs.NewMem(), } ok, err := vfs.Clone(vfs.Default, opts.FS, "testdata/db-v1.0.0", "") @@ -273,6 +276,206 @@ func TestOldDB(t *testing.T) { require.NoError(db.Close()) } +func TestRollbackProtection(t *testing.T) { + require := require.New(t) + + const dbdir = "db" + const olddir = "old" + fs := vfs.NewMem() + var counter fakeCounter + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: fs, + } + + // create db + db, err := estore.Open(dbdir, opts) + require.NoError(err) + tx := db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val1"), nil)) + 
require.NoError(tx.Commit()) + require.NoError(db.Close()) + + // copy the db + ok, err := vfs.Clone(fs, fs, dbdir, olddir) + require.NoError(err) + require.True(ok) + + // advance db + db, err = estore.Open(dbdir, opts) + require.NoError(err) + tx = db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val2"), nil)) + require.NoError(tx.Commit()) + require.NoError(db.Close()) + + // try to roll back the db + _, err = estore.Open(olddir, opts) + require.ErrorContains(err, "rollback detected") +} + +func TestRollbackProtection_Open_CounterSourceFails(t *testing.T) { + require := require.New(t) + + counter := fakeCounter{preErr: assert.AnError} + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: vfs.NewMem(), + } + + _, err := estore.Open("", opts) + require.ErrorIs(err, counter.preErr) +} + +func TestRollbackProtection_Open_NewDBWithExistingCounter(t *testing.T) { + require := require.New(t) + + var counter fakeCounter + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: vfs.NewMem(), + } + + counter.value = 2 + _, err := estore.Open("", opts) + require.ErrorContains(err, "rollback detected") +} + +func TestRollbackProtection_Open_CounterSourceRollbackCanBeHandled(t *testing.T) { + require := require.New(t) + + var counter fakeCounter + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: vfs.NewMem(), + } + + // create db + db, err := estore.Open("", opts) + require.NoError(err) + tx := db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val1"), nil)) + require.NoError(tx.Commit()) + tx = db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val2"), nil)) + require.NoError(tx.Commit()) + require.NoError(db.Close()) + + // roll back counter source + counter.value-- + + // advance db + db, err = estore.Open("", opts) + require.NoError(err) + tx = db.NewTransaction(true) + 
require.NoError(tx.Set([]byte("key"), []byte("val3"), nil)) + require.NoError(tx.Commit()) + require.NoError(db.Close()) + + // counter is synced + require.EqualValues(3, counter.value) +} + +func TestRollbackProtection_CounterSourceReturnsError(t *testing.T) { + testCases := map[string]struct { + preErr error + postErr error + }{ + "error before increment": { + preErr: assert.AnError, + }, + "error after increment": { + postErr: assert.AnError, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + require := require.New(t) + + var counter fakeCounter + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: vfs.NewMem(), + } + + db, err := estore.Open("", opts) + require.NoError(err) + + // tx fails + counter.preErr = tc.preErr + counter.postErr = tc.postErr + tx := db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val1"), nil)) + require.Error(tx.Commit()) + + // value was not written + _, _, err = db.Get([]byte("key")) + require.ErrorIs(err, estore.ErrNotFound) + + // retry succeeds + counter.preErr = nil + counter.postErr = nil + tx = db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val1"), nil)) + require.NoError(tx.Commit()) + + require.NoError(db.Close()) + + // counter is synced + require.EqualValues(2, counter.value) + }) + } +} + +func TestRollbackProtection_AdvancingCounterSourceCausesFailure(t *testing.T) { + require := require.New(t) + + var counter fakeCounter + + opts := &estore.Options{ + EncryptionKey: testKey(), + SetMonotonicCounter: counter.set, + FS: vfs.NewMem(), + Logger: &base.InMemLogger{}, // calls runtime.Goexit on Fatalf + } + + // create db + db, err := estore.Open("", opts) + require.NoError(err) + tx := db.NewTransaction(true) + require.NoError(tx.Set([]byte("key"), []byte("val1"), nil)) + require.NoError(tx.Commit()) + + // advance counter source + counter.value++ + + tx = db.NewTransaction(true) + 
require.NoError(tx.Set([]byte("key"), []byte("val2"), nil)) + + // expect fatal exit on commit + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + tx.Commit() + panic("unreachable") + }() + wg.Wait() + + require.NoError(db.Close()) +} + func testKey() []byte { return bytes.Repeat([]byte{2}, 16) } @@ -312,3 +515,23 @@ func isCryptoError(err error) bool { return false } + +type fakeCounter struct { + value uint64 + preErr error + postErr error +} + +func (c *fakeCounter) set(value uint64) (uint64, error) { + if c.preErr != nil { + return 0, c.preErr + } + prev := c.value + if value > c.value { + c.value = value + } + if c.postErr != nil { + return 0, c.postErr + } + return prev, nil +} diff --git a/open.go b/open.go index 9adb0289..e95fa806 100644 --- a/open.go +++ b/open.go @@ -72,6 +72,18 @@ func TableCacheSize(maxOpenFiles int) int { // Open opens a DB whose files live in the given directory. func Open(dirname string, opts *Options) (db *DB, _ error) { + db, err := open(dirname, opts) + if err != nil { + return db, err + } + if err := db.edgVerifyFreshness(); err != nil { + db.Close() + return nil, err + } + return db, nil +} + +func open(dirname string, opts *Options) (db *DB, _ error) { // Make a copy of the options so that we don't mutate the passed in options. opts = opts.Clone() opts = opts.EnsureDefaults() diff --git a/options.go b/options.go index 0f0793ea..0f4e40d0 100644 --- a/options.go +++ b/options.go @@ -470,6 +470,17 @@ type Options struct { // EncryptionKey is the master key for encryption at rest. Must be 16, 24, or 32 bytes. EncryptionKey []byte + // SetMonotonicCounter is a callback that EStore invokes to provide rollback protection by using a trusted monotonic counter. + // + // The behavior of the counter must be the following: + // If the passed value is greater than the counter's value, it is set as the new value and the old value is returned. + // Otherwise, the value is not changed and the current value is returned. 
+ // + // If you use this feature, you should perform all write operations inside transactions because only these will be protected. + // + // If not set, rollback protection is disabled. + SetMonotonicCounter func(uint64) (uint64, error) + // Sync sstables periodically in order to smooth out writes to disk. This // option does not provide any persistency guarantee, but is used to avoid // latency spikes if the OS automatically decides to write out a large chunk diff --git a/transaction.go b/transaction.go index 46e4ff03..7ae87ca5 100644 --- a/transaction.go +++ b/transaction.go @@ -9,6 +9,8 @@ package estore import ( "context" "io" + + "github.com/cockroachdb/errors" ) // NewTransaction starts a new transaction. @@ -35,6 +37,39 @@ type Transaction struct { // Commit commits and closes the transaction. func (t *Transaction) Commit() error { defer t.Close() + if t.Batch == nil { + return nil + } + db := t.db + + if db.opts.SetMonotonicCounter != nil { + // We must increment the store counter and the source counter. The order is important so that errors don't make + // the store inaccessible: It's tolerable if the store counter is incremented but the source counter is not. + // We only commit the transaction if both counters are incremented. + + if err := db.edgSetMonotonicCounterOnStore(db.monotonicCounter + 1); err != nil { + return errors.Wrap(err, "setting monotonic counter on store") + } + prevStoreCount := db.monotonicCounter + db.monotonicCounter++ + + prevSourceCount, err := db.opts.SetMonotonicCounter(db.monotonicCounter) + if err != nil { + // We don't know if the source counter was incremented or not. + // Keep the store counter incremented. It will be synced on next successful commit. + db.opts.Logger.Infof("ERROR: incrementing the trusted source counter: %v", err) + return errors.Wrap(err, "incrementing the trusted source counter") + } + + if prevSourceCount > prevStoreCount { + // Unrecoverable error. 
Should only be possible if someone else incremented the monotonic counter. + db.opts.Logger.Fatalf("Previous value of the trusted source counter (%v) is greater than expected (%v)", prevSourceCount, prevStoreCount) + } + if prevSourceCount < prevStoreCount { + db.opts.Logger.Infof("WARNING: tx commit: monotonic counter source lagged behind (store counter: %v, source counter: %v) and should have been synced now", prevStoreCount, prevSourceCount) + } + } + return t.Batch.Commit(nil) } From 08beb1e1776bb3f8d764c07d96afb20e53fe2478 Mon Sep 17 00:00:00 2001 From: Thomas Tendyck Date: Mon, 7 Oct 2024 16:23:58 +0200 Subject: [PATCH 2/2] fix linter warnings of new govet version --- internal/metamorphic/crossversion/crossversion_test.go | 2 +- iterator_test.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/metamorphic/crossversion/crossversion_test.go b/internal/metamorphic/crossversion/crossversion_test.go index 5ac365f8..66f7c6bf 100644 --- a/internal/metamorphic/crossversion/crossversion_test.go +++ b/internal/metamorphic/crossversion/crossversion_test.go @@ -365,7 +365,7 @@ func (f *pebbleVersions) String() string { if i > 0 { fmt.Fprint(&buf, " ") } - fmt.Fprintf(&buf, v.SHA) + fmt.Fprint(&buf, v.SHA) } return buf.String() } diff --git a/iterator_test.go b/iterator_test.go index 01f2508e..d4c7cdb9 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -2087,7 +2087,7 @@ func BenchmarkIteratorSeqSeekGEWithBounds(b *testing.B) { valid = iter.Next() } if iter.Error() != nil { - b.Fatalf(iter.Error().Error()) + b.Fatal(iter.Error().Error()) } } iter.Close() @@ -2398,17 +2398,17 @@ func TestRangeKeyMaskingRandomized(t *testing.T) { t.Fatalf("iteration didn't produce identical results") } if hasP1 && !bytes.Equal(iter1.Key(), iter2.Key()) { - t.Fatalf(fmt.Sprintf("iteration didn't produce identical point keys: %s, %s", iter1.Key(), iter2.Key())) + t.Fatalf("iteration didn't produce identical point keys: %s, %s", iter1.Key(), 
iter2.Key()) } if hasR1 { // Confirm that the range key is the same. b1, e1 := iter1.RangeBounds() b2, e2 := iter2.RangeBounds() if !bytes.Equal(b1, b2) || !bytes.Equal(e1, e2) { - t.Fatalf(fmt.Sprintf( + t.Fatalf( "iteration didn't produce identical range keys: [%s, %s], [%s, %s]", b1, e1, b2, e2, - )) + ) } } @@ -2416,7 +2416,7 @@ func TestRangeKeyMaskingRandomized(t *testing.T) { // Confirm that the returned point key wasn't hidden. for j, pkey := range keys { if bytes.Equal(iter1.Key(), pkey) && pointKeyHidden[j] { - t.Fatalf(fmt.Sprintf("hidden point key was exposed %s %d", pkey, keyTimeStamps[j])) + t.Fatalf("hidden point key was exposed %s %d", pkey, keyTimeStamps[j]) } } }