Skip to content

Commit

Permalink
Fix etcd compat with result limits / pagination
Browse files Browse the repository at this point in the history
* Sort list results by name, not revision. List continuation (start key)
  functionality requires that keys be returned in ascending order.
* Only count keys remaining after the start key, not the total number of
  keys in the prefix.
* Return current revision in header along with error when unable to
  range on key.
* Don't ignore start key when listing with revision=0

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
  • Loading branch information
brandond committed May 8, 2024
1 parent 7484a03 commit 4070872
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 68 deletions.
36 changes: 13 additions & 23 deletions pkg/drivers/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
var _ server.Dialect = (*Generic)(nil)

var (
columns = "kv.id AS theid, kv.name, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
columns = "kv.id AS theid, kv.name AS thename, kv.created, kv.deleted, kv.create_revision, kv.prev_revision, kv.lease, kv.value, kv.old_value"
revSQL = `
SELECT MAX(rkv.id) AS id
FROM kine AS rkv`
Expand All @@ -39,16 +39,6 @@ var (
FROM kine AS crkv
WHERE crkv.name = 'compact_rev_key'`

idOfKey = `
AND
mkv.id <= ? AND
mkv.id > (
SELECT MAX(ikv.id) AS id
FROM kine AS ikv
WHERE
ikv.name = ? AND
ikv.id <= ?)`

listSQL = fmt.Sprintf(`
SELECT *
FROM (
Expand All @@ -66,7 +56,7 @@ var (
kv.deleted = 0 OR
?
) AS lkv
ORDER BY lkv.theid ASC
ORDER BY lkv.thename ASC
`, revSQL, compactRevSQL, columns)
)

Expand Down Expand Up @@ -216,21 +206,21 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig
FROM kine AS kv
WHERE kv.id = ?`, columns), paramCharacter, numbered),

GetCurrentSQL: q(fmt.Sprintf(listSQL, ""), paramCharacter, numbered),
GetCurrentSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ?"), paramCharacter, numbered),
ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered),
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?"), paramCharacter, numbered),

CountCurrentSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered),
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ?")), paramCharacter, numbered),

CountRevisionSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered),
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.name > ? AND mkv.id <= ?")), paramCharacter, numbered),

AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
Expand Down Expand Up @@ -343,12 +333,12 @@ func (d *Generic) DeleteRevision(ctx context.Context, revision int64) error {
return err
}

func (d *Generic) ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error) {
func (d *Generic) ListCurrent(ctx context.Context, prefix, startKey string, limit int64, includeDeleted bool) (*sql.Rows, error) {
sql := d.GetCurrentSQL
if limit > 0 {
sql = fmt.Sprintf("%s LIMIT %d", sql, limit)
}
return d.query(ctx, sql, prefix, includeDeleted)
return d.query(ctx, sql, prefix, startKey, includeDeleted)
}

func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error) {
Expand All @@ -364,27 +354,27 @@ func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revi
if limit > 0 {
sql = fmt.Sprintf("%s LIMIT %d", sql, limit)
}
return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted)
return d.query(ctx, sql, prefix, startKey, revision, includeDeleted)
}

func (d *Generic) CountCurrent(ctx context.Context, prefix string) (int64, int64, error) {
func (d *Generic) CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountCurrentSQL, prefix, false)
row := d.queryRow(ctx, d.CountCurrentSQL, prefix, startKey, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}

func (d *Generic) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
func (d *Generic) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountRevisionSQL, prefix, revision, false)
row := d.queryRow(ctx, d.CountRevisionSQL, prefix, startKey, revision, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/drivers/nats/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (b *Backend) CurrentRevision(ctx context.Context) (int64, error) {
}

// Count returns an exact count of the number of matching keys and the current revision of the database.
func (b *Backend) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix, revision)
func (b *Backend) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix, startKey, revision)
if err != nil {
return 0, 0, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/drivers/nats/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err := b.Count(ctx, "/", 0)
srev, count, err := b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 4, count)

time.Sleep(time.Second)

srev, count, err = b.Count(ctx, "/", 0)
srev, count, err = b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 3, count)
Expand All @@ -149,7 +149,7 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err = b.Count(ctx, "/", 0)
srev, count, err = b.Count(ctx, "/", "", 0)
noErr(t, err)
expEqual(t, 6, srev)
expEqual(t, 4, count)
Expand Down
2 changes: 1 addition & 1 deletion pkg/drivers/nats/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ type keySeq struct {
seq uint64
}

func (e *KeyValue) Count(ctx context.Context, prefix string, revision int64) (int64, error) {
func (e *KeyValue) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, error) {
it := e.bt.Iter()

if prefix != "" {
Expand Down
8 changes: 4 additions & 4 deletions pkg/drivers/nats/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit
}

// Count returns an exact count of the number of matching keys and the current revision of the database
func (b *BackendLogger) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
func (b *BackendLogger) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) {
start := time.Now()
defer func() {
dur := time.Since(start)
fStr := "COUNT %s, rev=%d => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, revision, revRet, count, err, dur)
fStr := "COUNT %s, start=%s, rev=%d => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, startKey, revision, revRet, count, err, dur)
}()

return b.backend.Count(ctx, prefix, revision)
return b.backend.Count(ctx, prefix, startKey, revision)
}

func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) {
Expand Down
10 changes: 5 additions & 5 deletions pkg/logstructured/logstructured.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ type Log interface {
CompactRevision(ctx context.Context) (int64, error)
CurrentRevision(ctx context.Context) (int64, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error)
Watch(ctx context.Context, prefix string) <-chan []*server.Event
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Append(ctx context.Context, event *server.Event) (int64, error)
DbSize(ctx context.Context) (int64, error)
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit

rev, events, err := l.log.List(ctx, prefix, startKey, limit, revision, false)
if err != nil {
return 0, nil, err
return rev, nil, err
}
if revision == 0 && len(events) == 0 {
// if no revision is requested and no events are returned, then
Expand All @@ -185,7 +185,7 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit
// been created.
currentRev, err := l.log.CurrentRevision(ctx)
if err != nil {
return 0, nil, err
return currentRev, nil, err
}
return l.List(ctx, prefix, startKey, limit, currentRev)
} else if revision != 0 {
Expand All @@ -199,11 +199,11 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit
return rev, kvs, nil
}

func (l *LogStructured) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
func (l *LogStructured) Count(ctx context.Context, prefix, startKey string, revision int64) (revRet int64, count int64, err error) {
defer func() {
logrus.Tracef("COUNT %s, rev=%d => rev=%d, count=%d, err=%v", prefix, revision, revRet, count, err)
}()
rev, count, err := l.log.Count(ctx, prefix, revision)
rev, count, err := l.log.Count(ctx, prefix, startKey, revision)
if err != nil {
return 0, 0, err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/logstructured/sqllog/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (s *SQLLog) List(ctx context.Context, prefix, startKey string, limit, revis
}

if revision == 0 {
rows, err = s.d.ListCurrent(ctx, prefix, limit, includeDeleted)
rows, err = s.d.ListCurrent(ctx, prefix, startKey, limit, includeDeleted)
} else {
rows, err = s.d.List(ctx, prefix, startKey, limit, revision, includeDeleted)
}
Expand Down Expand Up @@ -526,15 +526,15 @@ func canSkipRevision(rev, skip int64, skipTime time.Time) bool {
return rev == skip && time.Since(skipTime) > time.Second
}

func (s *SQLLog) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
func (s *SQLLog) Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error) {
if strings.HasSuffix(prefix, "/") {
prefix += "%"
}

if revision == 0 {
return s.d.CountCurrent(ctx, prefix)
return s.d.CountCurrent(ctx, prefix, startKey)
}
return s.d.Count(ctx, prefix, revision)
return s.d.Count(ctx, prefix, startKey, revision)
}

func (s *SQLLog) Append(ctx context.Context, event *server.Event) (int64, error) {
Expand Down
6 changes: 1 addition & 5 deletions pkg/server/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,12 @@ func (l *LimitedServer) get(ctx context.Context, r *etcdserverpb.RangeRequest) (
}

rev, kv, err := l.backend.Get(ctx, string(r.Key), string(r.RangeEnd), r.Limit, r.Revision)
if err != nil {
return nil, err
}

resp := &RangeResponse{
Header: txnHeader(rev),
}
if kv != nil {
resp.Kvs = []*KeyValue{kv}
resp.Count = 1
}
return resp, nil
return resp, err
}
25 changes: 8 additions & 17 deletions pkg/server/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,13 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
revision := r.Revision

if r.CountOnly {
rev, count, err := l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count)
return &RangeResponse{
rev, count, err := l.backend.Count(ctx, prefix, start, revision)
resp := &RangeResponse{
Header: txnHeader(rev),
Count: count,
}, nil
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count)
return resp, err
}

limit := r.Limit
Expand All @@ -40,18 +38,14 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
}

rev, kvs, err := l.backend.List(ctx, prefix, start, limit, revision)
if err != nil {
return nil, err
}

logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, revision, rev, len(kvs), r.Limit)
resp := &RangeResponse{
Header: txnHeader(rev),
Count: int64(len(kvs)),
Kvs: kvs,
}

// count the actual number of results if there are more items in the db.
// if the number of items returned exceeds the limit, count the keys remaining that follow the start key
if limit > 0 && resp.Count > r.Limit {
resp.More = true
resp.Kvs = kvs[0 : limit-1]
Expand All @@ -60,13 +54,10 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
revision = rev
}

rev, resp.Count, err = l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
rev, resp.Count, err = l.backend.Count(ctx, prefix, start, revision)
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, resp.Count)
resp.Header = txnHeader(rev)
}

return resp, nil
return resp, err
}
8 changes: 4 additions & 4 deletions pkg/server/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ type Backend interface {
Create(ctx context.Context, key string, value []byte, lease int64) (int64, error)
Delete(ctx context.Context, key string, revision int64) (int64, *KeyValue, bool, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *KeyValue, bool, error)
Watch(ctx context.Context, key string, revision int64) WatchResult
DbSize(ctx context.Context) (int64, error)
CurrentRevision(ctx context.Context) (int64, error)
}

type Dialect interface {
ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error)
ListCurrent(ctx context.Context, prefix, startKey string, limit int64, includeDeleted bool) (*sql.Rows, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error)
CountCurrent(ctx context.Context, prefix string) (int64, int64, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
CountCurrent(ctx context.Context, prefix, startKey string) (int64, int64, error)
Count(ctx context.Context, prefix, startKey string, revision int64) (int64, int64, error)
CurrentRevision(ctx context.Context) (int64, error)
After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error)
Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (int64, error)
Expand Down

0 comments on commit 4070872

Please sign in to comment.