Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of keyspace field for BATCH message #1778

Open
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -3288,3 +3289,54 @@ func TestQuery_NamedValues(t *testing.T) {
t.Fatal(err)
}
}

func TestBatchKeyspaceField(t *testing.T) {
session := createSession(t)
defer session.Close()

if session.cfg.ProtoVersion < protoVersion5 {
t.Skip("keyspace for BATCH message is not supported in protocol < 5")
}

const keyspaceStmt = `
CREATE KEYSPACE IF NOT EXISTS gocql_keyspace_override_test
WITH replication = {
'class': 'SimpleStrategy',
'replication_factor': '1'
};
`

err := session.Query(keyspaceStmt).Exec()
if err != nil {
t.Fatal(err)
}

err = createTable(session, "CREATE TABLE IF NOT EXISTS gocql_keyspace_override_test.batch_keyspace(id int, value text, PRIMARY KEY (id))")
if err != nil {
t.Fatal(err)
}

ids := []int{1, 2}
texts := []string{"val1", "val2"}

b := session.NewBatch(LoggedBatch).SetKeyspace("gocql_keyspace_override_test")
b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should also do the test when one of queries also overrides the keyspace.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I see from the proto 5 spec we can not override the keyspace for the specific query in the Batch, only for the Batch itself.

4.1.7. BATCH

  Allows executing a list of queries (prepared or not) as a batch (note that
  only DML statements are accepted in a batch). The body of the message must
  be:
    <type><n><query_1>...<query_n><consistency><flags>[<serial_consistency>][<timestamp>][<keyspace>][<now_in_seconds>]

Copy link
Member

@lukasz-antoniak lukasz-antoniak Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that my comment might not be understandable. You can simulate hardcoding keyspace in CQL string itself. But, yeah it is not a very useful test indeed.

b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1])
err = session.ExecuteBatch(b)
if err != nil {
t.Fatal(err)
}

var (
id int
text string
)

iter := session.Query("SELECT * FROM gocql_keyspace_override_test.batch_keyspace").Iter()
defer iter.Close()

for i := 0; iter.Scan(&id, &text); i++ {
require.Equal(t, id, ids[i])
require.Equal(t, text, texts[i])
}
}
14 changes: 11 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ type inflightPrepare struct {
preparedStatment *preparedStatment
}

func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
func (c *Conn) prepareStatementWithKeyspace(ctx context.Context, stmt string, tracer Tracer, keyspace string) (*preparedStatment, error) {
stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, stmt)
flight, ok := c.session.stmtsLRU.execIfMissing(stmtCacheKey, func(lru *lru.Cache) *inflightPrepare {
flight := &inflightPrepare{
Expand All @@ -1253,7 +1253,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
statement: stmt,
}
if c.version > protoVersion4 {
prep.keyspace = c.currentKeyspace
prep.keyspace = keyspace
}

// we won the race to do the load, if our context is canceled we shouldnt
Expand Down Expand Up @@ -1310,6 +1310,10 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer)
}
}

func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) (*preparedStatment, error) {
return c.prepareStatementWithKeyspace(ctx, stmt, tracer, c.currentKeyspace)
}

func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error {
if named, ok := value.(*namedValue); ok {
dst.name = named.name
Expand Down Expand Up @@ -1554,14 +1558,18 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
customPayload: batch.CustomPayload,
}

if c.version > protoVersion4 {
req.keyspace = batch.keyspace
}

stmts := make(map[string]string, len(batch.Entries))

for i := 0; i < n; i++ {
entry := &batch.Entries[i]
b := &req.statements[i]

if len(entry.Args) > 0 || entry.binding != nil {
info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace)
info, err := c.prepareStatementWithKeyspace(batch.Context(), entry.Stmt, batch.trace, batch.keyspace)
if err != nil {
return &Iter{err: err}
}
Expand Down
22 changes: 21 additions & 1 deletion frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,15 @@ func (f *framer) parseResultPrepared() frame {
frame := &resultPreparedFrame{
frameHeader: *f.header,
preparedID: f.readShortBytes(),
reqMeta: f.parsePreparedMetadata(),
}

if f.proto > protoVersion4 {
// TODO handle result_metadata_id for native proto 5
_ = f.readShortBytes()
}

frame.reqMeta = f.parsePreparedMetadata()

if f.proto < protoVersion2 {
return frame
}
Expand Down Expand Up @@ -1659,6 +1665,9 @@ type writeBatchFrame struct {

//v4+
customPayload map[string][]byte

//v5+
keyspace string
}

func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error {
Expand Down Expand Up @@ -1718,6 +1727,13 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload
flags |= flagDefaultTimestamp
}

if w.keyspace != "" {
if f.proto < protoVersion5 {
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
}
flags |= flagWithKeyspace
}

if f.proto > protoVersion4 {
f.writeUint(uint32(flags))
} else {
Expand All @@ -1737,6 +1753,10 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload
}
f.writeLong(ts)
}

if w.keyspace != "" {
f.writeString(w.keyspace)
}
}

return f.finish()
Expand Down
6 changes: 6 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2042,6 +2042,12 @@ func (b *Batch) releaseAfterExecution() {
// that would race with speculative executions.
}

// SetKeyspace allows to specify the keyspace that the query should be executed in.
func (b *Batch) SetKeyspace(keyspace string) *Batch {
b.keyspace = keyspace
return b
}

type BatchType byte

const (
Expand Down