From 6224a0b111cadf512e5fbcdf047d1dd2e8be66e6 Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Fri, 12 Jul 2024 13:41:22 +0300 Subject: [PATCH 1/2] Support of keyspace field for BATCH message --- cassandra_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ conn.go | 4 ++++ frame.go | 14 ++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..2838dbc24 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -3288,3 +3288,46 @@ func TestQuery_NamedValues(t *testing.T) { t.Fatal(err) } } + +func TestBatchKeyspaceField(t *testing.T) { + session := createSession(t) + defer session.Close() + + if session.cfg.ProtoVersion < protoVersion5 { + t.Skip("keyspace for BATCH message is not supported in protocol < 5") + } + + err := createTable(session, "CREATE TABLE batch_keyspace(id int, value text, PRIMARY KEY (id))") + if err != nil { + t.Fatal(err) + } + + ids := []int{1, 2} + texts := []string{"val1", "val2"} + + b := session.NewBatch(LoggedBatch) + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0]) + b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1]) + err = session.ExecuteBatch(b) + if err != nil { + t.Fatal(err) + } + + var ( + id int + text string + ) + + iter := session.Query("SELECT * FROM batch_keyspace").Iter() + defer iter.Close() + + for i := 0; iter.Scan(&id, &text); i++ { + if id != ids[i] { + t.Fatalf("expected id %v, got %v", ids[i], id) + } + + if text != texts[i] { + t.Fatalf("expected text %v, got %v", texts[i], text) + } + } +} diff --git a/conn.go b/conn.go index 3daca6250..da834cc20 100644 --- a/conn.go +++ b/conn.go @@ -1554,6 +1554,10 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { customPayload: batch.CustomPayload, } + if c.version > protoVersion4 { + req.keyspace = c.currentKeyspace + } + stmts := make(map[string]string, len(batch.Entries)) for i := 0; i < n; i++ { diff --git a/frame.go b/frame.go index d374ae574..72e71b493 100644 --- a/frame.go +++ b/frame.go @@ -1659,6 +1659,9 @@ type writeBatchFrame struct { //v4+ customPayload map[string][]byte + + //v5+ + keyspace string } func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error { @@ -1718,6 +1721,13 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload flags |= flagDefaultTimestamp } + if w.keyspace != "" { + if f.proto < protoVersion5 { + panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher")) + } + flags |= flagWithKeyspace + } + if f.proto > protoVersion4 { f.writeUint(uint32(flags)) } else { @@ -1737,6 +1747,10 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload } f.writeLong(ts) } + + if w.keyspace != "" { + f.writeString(w.keyspace) + } } return f.finish() From 4374c11ca7bfdb33292c5cac3d54294a34cc58fb Mon Sep 17 00:00:00 2001 From: Bohdan Siryk Date: Tue, 10 Sep 2024 15:19:09 +0300 Subject: [PATCH 2/2] added SetKeyspace method to Batch that allows to specify the keyspace that the query should be executed on --- cassandra_test.go | 29 +++++++++++++++++++---------- conn.go | 12 ++++++++---- frame.go | 8 +++++++- session.go | 6 ++++++ 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index 2838dbc24..fd3e317b7 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -32,6 +32,7 @@ import ( "context" "errors" "fmt" + "github.com/stretchr/testify/require" "io" "math" "math/big" @@ -3297,7 +3298,20 @@ func TestBatchKeyspaceField(t *testing.T) { t.Skip("keyspace for BATCH message is not supported in protocol < 5") } - err := createTable(session, "CREATE TABLE batch_keyspace(id int, value text, PRIMARY KEY (id))") + const keyspaceStmt = ` + CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test + WITH replication = { + 'class': 'SimpleStrategy', + 'replication_factor': '1' + }; +` + + err := session.Query(keyspaceStmt).Exec() + if err != nil { + t.Fatal(err) + } + + err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY (id))") if err != nil { t.Fatal(err) } @@ -3305,7 +3319,7 @@ func TestBatchKeyspaceField(t *testing.T) { ids := []int{1, 2} texts := []string{"val1", "val2"} - b := session.NewBatch(LoggedBatch) + b := session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test") b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0]) b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1]) err = session.ExecuteBatch(b) @@ -3318,16 +3332,11 @@ func TestBatchKeyspaceField(t *testing.T) { text string ) - iter := session.Query("SELECT * FROM batch_keyspace").Iter() + iter := session.Query("SELECT * FROM gocql_keyspace_override_test.batch_keyspace").Iter() defer iter.Close() for i := 0; iter.Scan(&id, &text); i++ { - if id != ids[i] { - t.Fatalf("expected id %v, got %v", ids[i], id) - } - - if text != texts[i] { - t.Fatalf("expected text %v, got %v", texts[i], text) - } + require.Equal(t, id, ids[i]) + require.Equal(t, text, texts[i]) } } diff --git a/conn.go b/conn.go index da834cc20..4c2006171 100644 --- a/conn.go +++ b/conn.go @@ -1235,7 +1235,7 @@ type inflightPrepare struct { preparedStatment *preparedStatment } -func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { +func (c *Conn) prepareStatementWithKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) { stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt) flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare { flight := &inflightPrepare{ @@ -1253,7 +1253,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) statement: stmt, } if c.version > protoVersion4 { - prep.keyspace = c.currentKeyspace + prep.keyspace = keyspace } // we won the race to do the load, if our context is canceled we shouldnt @@ -1310,6 +1310,10 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) } } +func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) { + return c.prepareStatementWithKeyspace(ctx, stmt, tracer, c.currentKeyspace) +} + func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error { if named, ok := value.(*namedValue); ok { dst.name = named.name @@ -1555,7 +1559,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { } if c.version > protoVersion4 { - req.keyspace = c.currentKeyspace + req.keyspace = batch.keyspace } stmts := make(map[string]string, len(batch.Entries)) @@ -1565,7 +1569,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { b := &req.statements[i] if len(entry.Args) > 0 || entry.binding != nil { - info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) + info, err := c.prepareStatementWithKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace) if err != nil { return &Iter{err: err} } diff --git a/frame.go b/frame.go index 72e71b493..dc38c90c8 100644 --- a/frame.go +++ b/frame.go @@ -1173,9 +1173,15 @@ func (f *framer) parseResultPrepared() frame { frame := &resultPreparedFrame{ frameHeader: *f.header, preparedID: f.readShortBytes(), - reqMeta: f.parsePreparedMetadata(), } + if f.proto > protoVersion4 { + // TODO handle result_metadata_id for native proto 5 + _ = f.readShortBytes() + } + + frame.reqMeta = f.parsePreparedMetadata() + if f.proto < protoVersion2 { return frame } diff --git a/session.go b/session.go index a600b95f3..e8b053458 100644 --- a/session.go +++ b/session.go @@ -2042,6 +2042,12 @@ func (b *Batch) releaseAfterExecution() { // that would race with speculative executions. } +// SetKeyspace allows to specify the keyspace that the query should be executed in. +func (b *Batch) SetKeyspace(keyspace string) *Batch { + b.keyspace = keyspace + return b +} + type BatchType byte const (