Skip to content

Commit

Permalink
Merge pull request #1 from launchdarkly/ashanbrown/add-timestamp
Browse files Browse the repository at this point in the history
Support initial timestamp in shard iterator
  • Loading branch information
ashanbrown authored Jul 23, 2019
2 parents 7018c0c + 1f35253 commit 3e82db2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
25 changes: 19 additions & 6 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"fmt"
"io/ioutil"
"log"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
Expand Down Expand Up @@ -60,6 +62,7 @@ func New(streamName string, opts ...Option) (*Consumer, error) {
type Consumer struct {
streamName string
initialShardIteratorType string
initialTimestamp *time.Time
client kinesisiface.KinesisAPI
logger Logger
group Group
Expand Down Expand Up @@ -97,22 +100,29 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
close(shardc)
}()

wg := new(sync.WaitGroup)
// process each of the shards
for shard := range shardc {
wg.Add(1)
go func(shardID string) {
defer wg.Done()
if err := c.ScanShard(ctx, shardID, fn); err != nil {
select {
case errc <- fmt.Errorf("shard %s error: %v", shardID, err):
// first error to occur
cancel()
default:
// error has already occured
// error has already occurred
}
}
}(aws.StringValue(shard.ShardId))
}

close(errc)
go func() {
wg.Wait()
close(errc)
}()

return <-errc
}

Expand All @@ -126,7 +136,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e
}

// get shard iterator
shardIterator, err := c.getShardIterator(c.streamName, shardID, lastSeqNum)
shardIterator, err := c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
Expand All @@ -147,7 +157,7 @@ func (c *Consumer) ScanShard(ctx context.Context, shardID string, fn ScanFunc) e

// attempt to recover from GetRecords error by getting new shard iterator
if err != nil {
shardIterator, err = c.getShardIterator(c.streamName, shardID, lastSeqNum)
shardIterator, err = c.getShardIterator(ctx, c.streamName, shardID, lastSeqNum)
if err != nil {
return fmt.Errorf("get shard iterator error: %v", err)
}
Expand Down Expand Up @@ -190,7 +200,7 @@ func isShardClosed(nextShardIterator, currentShardIterator *string) bool {
return nextShardIterator == nil || currentShardIterator == nextShardIterator
}

func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string, error) {
func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, seqNum string) (*string, error) {
params := &kinesis.GetShardIteratorInput{
ShardId: aws.String(shardID),
StreamName: aws.String(streamName),
Expand All @@ -199,10 +209,13 @@ func (c *Consumer) getShardIterator(streamName, shardID, seqNum string) (*string
if seqNum != "" {
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAfterSequenceNumber)
params.StartingSequenceNumber = aws.String(seqNum)
} else if c.initialTimestamp != nil {
params.ShardIteratorType = aws.String(kinesis.ShardIteratorTypeAtTimestamp)
params.Timestamp = c.initialTimestamp
} else {
params.ShardIteratorType = aws.String(c.initialShardIteratorType)
}

res, err := c.client.GetShardIterator(params)
res, err := c.client.GetShardIteratorWithContext(aws.Context(ctx), params)
return res.ShardIterator, err
}
13 changes: 12 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package consumer

import "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
import (
"time"

"github.com/aws/aws-sdk-go/service/kinesis/kinesisiface"
)

// Option is used to override defaults when creating a new Consumer
type Option func(*Consumer)
Expand Down Expand Up @@ -39,3 +43,10 @@ func WithShardIteratorType(t string) Option {
c.initialShardIteratorType = t
}
}

// Timestamp overrides the starting point for the consumer
func WithTimestamp(t time.Time) Option {
return func(c *Consumer) {
c.initialTimestamp = &t
}
}

0 comments on commit 3e82db2

Please sign in to comment.