diff --git a/interface.go b/interface.go index aa6e195..f5382bf 100644 --- a/interface.go +++ b/interface.go @@ -27,6 +27,7 @@ type WriteAPI interface { PutItem(ctx context.Context, pk, sk Attribute, item interface{}, opt ...PutOption) error DeleteItem(ctx context.Context, pk, sk string) error BatchDeleteItems(ctx context.Context, input []AttributeRecord) []AttributeRecord + UpdateItem(ctx context.Context, pk, sk Attribute, fields map[string]Attribute, opt ...UpdateOption) error } type TransactionAPI interface { diff --git a/tests/client_test.go b/tests/client_test.go index f263105..fc3a443 100644 --- a/tests/client_test.go +++ b/tests/client_test.go @@ -6,6 +6,7 @@ import ( "log" "os" "testing" + "time" "github.com/oolio-group/dynago" "github.com/ory/dockertest/v3" @@ -101,8 +102,18 @@ func TestNewClientLocalEndpoint(t *testing.T) { t.Fatalf("expected configuration to succeed, got %s", err) } - err = createTestTable(table) - if err != nil { - t.Fatalf("expected create table on local table to succeed, got %s", err) + maxRetries := 5 + for i := 0; i < maxRetries; i++ { + err = createTestTable(table) + if err == nil { + break + } + + if i == maxRetries-1 { + t.Fatalf("failed to create table after %d attempts: %s", maxRetries, err) + } + + t.Logf("Table creation attempt %d failed: %s. Retrying after 1 second...", i+1, err) + time.Sleep(1 * time.Second) } } diff --git a/tests/transact_items_test.go b/tests/transact_items_test.go index 8be681a..d14dd3e 100644 --- a/tests/transact_items_test.go +++ b/tests/transact_items_test.go @@ -79,6 +79,46 @@ func TestTransactItems(t *testing.T) { }, }, }, + { + title: "update multiple items with WithUpdateItem", + condition: "pk = :pk AND begins_with(sk, :sk)", + keys: map[string]types.AttributeValue{ + ":pk": &types.AttributeValueMemberS{Value: "terminal"}, + ":sk": &types.AttributeValueMemberS{Value: "merchant"}, + }, + newItems: []Terminal{ + { + Id: "2", + Pk: "terminal", + Sk: "merchant1", + }, + { + Id: "3", + Pk: "terminal", + Sk: "merchant2", + }, + }, + operations: []types.TransactWriteItem{ + table.WithUpdateItem("terminal", "merchant1", map[string]dynago.Attribute{ + "Id": dynago.StringValue("2-updated"), + }), + table.WithUpdateItem("terminal", "merchant2", map[string]dynago.Attribute{ + "Id": dynago.StringValue("3-updated"), + }), + }, + expected: []Terminal{ + { + Id: "2-updated", + Pk: "terminal", + Sk: "merchant1", + }, + { + Id: "3-updated", + Pk: "terminal", + Sk: "merchant2", + }, + }, + }, } for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { diff --git a/tests/updateitem_test.go b/tests/updateitem_test.go new file mode 100644 index 0000000..4cb0e09 --- /dev/null +++ b/tests/updateitem_test.go @@ -0,0 +1,222 @@ +package tests + +import ( + "context" + "fmt" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/oolio-group/dynago" +) + +type Account struct { + ID string + Balance int + Version uint + Status string + + Pk string + Sk string +} + +func TestUpdateItem(t *testing.T) { + table := prepareTable(t, dynamoEndpoint, "update_test") + testCases := []struct { + title string + item Account + updates map[string]dynago.Attribute + options []dynago.UpdateOption + expected Account + expectedErr error + }{ + { + title: "update fields success", + item: Account{ + ID: "1", + Balance: 100, + Status: "active", + Pk: "account_1", + Sk: "account_1", + }, + updates: map[string]dynago.Attribute{ + "Balance": dynago.NumberValue(200), + "Status": dynago.StringValue("inactive"), + }, + options: []dynago.UpdateOption{}, + expected: Account{ + ID: "1", + Balance: 200, + Status: "inactive", + Pk: "account_1", + Sk: "account_1", + }, + }, + { + title: "optimistic lock success", + item: Account{ + ID: "2", + Balance: 100, + Version: 1, + Pk: "account_2", + Sk: "account_2", + }, + updates: map[string]dynago.Attribute{ + "Balance": dynago.NumberValue(300), + }, + options: []dynago.UpdateOption{ + dynago.WithOptimisticLockForUpdate("Version", 1), + }, + expected: Account{ + ID: "2", + Balance: 300, + Version: 2, + Pk: "account_2", + Sk: "account_2", + }, + }, + { + title: "conditional update success", + item: Account{ + ID: "3", + Balance: 100, + Status: "active", + Pk: "account_3", + Sk: "account_3", + }, + updates: map[string]dynago.Attribute{ + "Status": dynago.StringValue("inactive"), + }, + options: []dynago.UpdateOption{ + dynago.WithConditionalUpdate("attribute_exists(Balance)"), + }, + expected: Account{ + ID: "3", + Balance: 100, + Status: "inactive", + Pk: "account_3", + Sk: "account_3", + }, + }, + { + title: "conditional update failure", + item: Account{ + ID: "4", + Balance: 100, + Pk: "account_4", + Sk: "account_4", + }, + updates: map[string]dynago.Attribute{ + "Status": dynago.StringValue("inactive"), + }, + options: []dynago.UpdateOption{ + dynago.WithConditionalUpdate("attribute_exists(NonExistentField)"), + }, + expectedErr: fmt.Errorf("ConditionalCheckFailedException"), + }, + } + + for _, tc := range testCases { + t.Run(tc.title, func(t *testing.T) { + t.Helper() + ctx := context.TODO() + + pk := dynago.StringValue(tc.item.Pk) + sk := dynago.StringValue(tc.item.Sk) + err := table.PutItem(ctx, pk, sk, &tc.item) + if err != nil { + t.Fatalf("unexpected error on initial put: %s", err) + } + + err = table.UpdateItem(ctx, pk, sk, tc.updates, tc.options...) + if err != nil { + if tc.expectedErr == nil { + t.Fatalf("unexpected error: %s", err) + } + if !strings.Contains(err.Error(), tc.expectedErr.Error()) { + t.Fatalf("expected op to fail with %s; got %s", tc.expectedErr, err) + } + return + } + + var out Account + err, found := table.GetItem(ctx, pk, sk, &out) + if err != nil { + t.Fatalf("unexpected error on get: %s", err) + } + if !found { + t.Errorf("expected to find item with pk %s and sk %s", tc.item.Pk, tc.item.Sk) + } + if !reflect.DeepEqual(tc.expected, out) { + t.Errorf("expected query to return %v; got %v", tc.expected, out) + } + }) + } +} + +func TestUpdateItemOptimisticLockConcurrency(t *testing.T) { + table := prepareTable(t, dynamoEndpoint, "update_optimistic_test") + account := Account{ID: "123", Balance: 0, Version: 0, Pk: "123", Sk: "123"} + ctx := context.Background() + pk := dynago.StringValue("123") + err := table.PutItem(ctx, pk, pk, account) + if err != nil { + t.Fatalf("unexpected error %s", err) + return + } + + update := func() error { + var acc Account + err, _ := table.GetItem(ctx, pk, pk, &acc) + if err != nil { + return err + } + + updates := map[string]dynago.Attribute{ + "Balance": dynago.NumberValue(int64(acc.Balance + 100)), + } + + return table.UpdateItem(ctx, pk, pk, updates, dynago.WithOptimisticLockForUpdate("Version", acc.Version)) + } + + var wg sync.WaitGroup + successCount := int32(0) + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + maxRetries := 10 + for i := 0; i < maxRetries; i++ { + err := update() + if err == nil { + atomic.AddInt32(&successCount, 1) + return + } + if strings.Contains(err.Error(), "ConditionalCheckFailedException") { + time.Sleep(100 * time.Millisecond) // Longer delay before retry + continue + } + t.Errorf("Unexpected error: %v", err) + return + } + t.Logf("Max retries reached, continuing") + }() + } + wg.Wait() + + var acc Account + err, _ = table.GetItem(ctx, pk, pk, &acc) + if err != nil { + t.Fatalf("unexpected error %s", err) + return + } + if acc.Balance != 1000 { + t.Errorf("expected account balance to be 1000 after 10 increments of 100; got %d", acc.Balance) + } + if acc.Version != 10 { + t.Errorf("expected account version to be 10 after 10 updates; got %d", acc.Version) + } +} diff --git a/transaction_items.go b/transaction_items.go index 4476282..0d4b2d5 100644 --- a/transaction_items.go +++ b/transaction_items.go @@ -2,7 +2,9 @@ package dynago import ( "context" + "fmt" "log" + "strings" "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" "github.com/aws/aws-sdk-go-v2/service/dynamodb" @@ -52,3 +54,49 @@ func (t *Client) TransactItems(ctx context.Context, input ...types.TransactWrite }) return err } + +func (t *Client) WithUpdateItem(pk string, sk string, updates map[string]Attribute, opts ...UpdateOption) types.TransactWriteItem { + var setExpressions []string + expressionAttributeNames := make(map[string]string) + expressionAttributeValues := make(map[string]Attribute) + + for key, value := range updates { + attrName := fmt.Sprintf("#%s", key) + attrValue := fmt.Sprintf(":%s", key) + + setExpressions = append(setExpressions, fmt.Sprintf("%s = %s", attrName, attrValue)) + expressionAttributeNames[attrName] = key + expressionAttributeValues[attrValue] = value + } + + updateExpression := fmt.Sprintf("SET %s", strings.Join(setExpressions, ", ")) + + input := &dynamodb.UpdateItemInput{ + TableName: &t.TableName, + Key: map[string]types.AttributeValue{ + "pk": &types.AttributeValueMemberS{Value: pk}, + "sk": &types.AttributeValueMemberS{Value: sk}, + }, + UpdateExpression: &updateExpression, + ExpressionAttributeNames: expressionAttributeNames, + ExpressionAttributeValues: expressionAttributeValues, + } + + for _, opt := range opts { + err := opt(input) + if err != nil { + panic(fmt.Sprintf("Failed to apply update option: %v", err)) + } + } + + return types.TransactWriteItem{ + Update: &types.Update{ + TableName: input.TableName, + Key: input.Key, + UpdateExpression: input.UpdateExpression, + ConditionExpression: input.ConditionExpression, + ExpressionAttributeNames: input.ExpressionAttributeNames, + ExpressionAttributeValues: input.ExpressionAttributeValues, + }, + } +} diff --git a/update_item.go b/update_item.go new file mode 100644 index 0000000..9f28274 --- /dev/null +++ b/update_item.go @@ -0,0 +1,101 @@ +package dynago + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type UpdateOption func(*dynamodb.UpdateItemInput) error + +func WithOptimisticLockForUpdate(key string, currentVersion uint) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + condition := "#version = :oldVersion" + input.ConditionExpression = &condition + if input.ExpressionAttributeNames == nil { + input.ExpressionAttributeNames = map[string]string{} + } + if input.ExpressionAttributeValues == nil { + input.ExpressionAttributeValues = map[string]Attribute{} + } + input.ExpressionAttributeNames["#version"] = key + input.ExpressionAttributeValues[":oldVersion"] = NumberValue(int64(currentVersion)) + + if input.UpdateExpression != nil { + versionUpdate := fmt.Sprintf("%s, %s = :newVersion", *input.UpdateExpression, key) + input.UpdateExpression = &versionUpdate + } else { + versionUpdate := fmt.Sprintf("SET %s = :newVersion", key) + input.UpdateExpression = &versionUpdate + } + input.ExpressionAttributeValues[":newVersion"] = NumberValue(int64(currentVersion + 1)) + return nil + } +} + +func WithConditionalUpdate(conditionExpr string) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ConditionExpression = &conditionExpr + return nil + } +} + +func WithReturnValues(returnValue types.ReturnValue) UpdateOption { + return func(input *dynamodb.UpdateItemInput) error { + input.ReturnValues = returnValue + return nil + } +} + +func (t *Client) UpdateItem(ctx context.Context, pk, sk Attribute, updates map[string]Attribute, opts ...UpdateOption) error { + var setExpressions []string + expressionAttributeNames := make(map[string]string) + expressionAttributeValues := make(map[string]Attribute) + + for key, value := range updates { + attrName := fmt.Sprintf("#%s", key) + attrValue := fmt.Sprintf(":%s", key) + + setExpressions = append(setExpressions, fmt.Sprintf("%s = %s", attrName, attrValue)) + expressionAttributeNames[attrName] = key + expressionAttributeValues[attrValue] = value + } + + updateExpression := fmt.Sprintf("SET %s", strings.Join(setExpressions, ", ")) + + input := &dynamodb.UpdateItemInput{ + TableName: &t.TableName, + Key: t.NewKeys(pk, sk), + UpdateExpression: &updateExpression, + ExpressionAttributeNames: expressionAttributeNames, + ExpressionAttributeValues: expressionAttributeValues, + } + + if len(opts) > 0 { + for _, opt := range opts { + err := opt(input) + if err != nil { + return err + } + } + } + + _, err := t.client.UpdateItem(ctx, input) + if err != nil { + log.Println("Failed to update item: " + err.Error()) + return err + } + + return nil +} + +type TransactUpdateItemsInput struct { + PartitionKeyValue Attribute + SortKeyValue Attribute + Updates map[string]Attribute + Options []UpdateOption +}