Skip to content

Custom proto revision #1553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clickhouse_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ type Options struct {

scheme string
ReadTimeout time.Duration

// ClientTCPProtocolVersion specifies the custom protocol revision, as defined in lib/proto/const.go
// if not specified, the latest supported protocol revision, proto.DBMS_TCP_PROTOCOL_VERSION , is used.
ClientTCPProtocolVersion uint64
}

func (o *Options) fromDSN(in string) error {
Expand Down Expand Up @@ -399,5 +403,9 @@ func (o Options) setDefaults() *Options {
o.Addr = []string{"localhost:8123"}
}
}
if o.ClientTCPProtocolVersion == 0 {
o.ClientTCPProtocolVersion = ClientTCPProtocolVersion
}

return &o
}
6 changes: 5 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
compressor = compress.NewWriter(compress.LevelZero, compress.None)
}

if opt.ClientTCPProtocolVersion < proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO || opt.ClientTCPProtocolVersion > proto.DBMS_TCP_PROTOCOL_VERSION {
return nil, fmt.Errorf("unsupported protocol revision")
}

var (
connect = &connect{
id: num,
Expand All @@ -99,7 +103,7 @@ func dial(ctx context.Context, addr string, num int, opt *Options) (*connect, er
debugf: debugf,
buffer: new(chproto.Buffer),
reader: chproto.NewReader(conn),
revision: ClientTCPProtocolVersion,
revision: opt.ClientTCPProtocolVersion,
structMap: &structMap{},
compression: compression,
connectedAt: time.Now(),
Expand Down
4 changes: 2 additions & 2 deletions conn_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (c *connect) handshake(auth Auth) error {
{
c.buffer.PutByte(proto.ClientHello)
handshake := &proto.ClientHandshake{
ProtocolVersion: ClientTCPProtocolVersion,
ProtocolVersion: c.revision,
ClientName: c.opt.ClientInfo.String(),
ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet
}
Expand All @@ -60,7 +60,7 @@ func (c *connect) handshake(auth Auth) error {
case proto.ServerException:
return c.exception()
case proto.ServerHello:
if err := c.server.Decode(c.reader); err != nil {
if err := c.server.Decode(c.reader, c.revision); err != nil {
return err
}
case proto.ServerEndOfStream:
Expand Down
2 changes: 1 addition & 1 deletion conn_send_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (c *connect) sendQuery(body string, o *QueryOptions) error {
c.debugf("[send query] compression=%q %s", c.compression, body)
c.buffer.PutByte(proto.ClientQuery)
q := proto.Query{
ClientTCPProtocolVersion: ClientTCPProtocolVersion,
ClientTCPProtocolVersion: c.revision,
ClientName: c.opt.ClientInfo.String(),
ClientVersion: proto.Version{ClientVersionMajor, ClientVersionMinor, ClientVersionPatch}, //nolint:govet
ID: o.queryID,
Expand Down
9 changes: 5 additions & 4 deletions lib/proto/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func CheckMinVersion(constraint Version, version Version) bool {
return true
}

func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
func (srv *ServerHandshake) Decode(reader *chproto.Reader, clientRevision uint64) (err error) {
if srv.Name, err = reader.Str(); err != nil {
return fmt.Errorf("could not read server name: %v", err)
}
Expand All @@ -98,7 +98,8 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
if srv.Revision, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read server revision: %v", err)
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
rev := min(clientRevision, srv.Revision)
if rev >= DBMS_MIN_REVISION_WITH_SERVER_TIMEZONE {
name, err := reader.Str()
if err != nil {
return fmt.Errorf("could not read server timezone: %v", err)
Expand All @@ -107,12 +108,12 @@ func (srv *ServerHandshake) Decode(reader *chproto.Reader) (err error) {
return fmt.Errorf("could not load time location: %v", err)
}
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
if rev >= DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME {
if srv.DisplayName, err = reader.Str(); err != nil {
return fmt.Errorf("could not read server display name: %v", err)
}
}
if srv.Revision >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
if rev >= DBMS_MIN_REVISION_WITH_VERSION_PATCH {
if srv.Version.Patch, err = reader.UVarInt(); err != nil {
return fmt.Errorf("could not read server patch: %v", err)
}
Expand Down
76 changes: 76 additions & 0 deletions tests/std/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
"time"

"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -231,6 +233,80 @@ func TestStdConnector(t *testing.T) {
require.NoError(t, err)
}

func TestCustomProtocolRevision(t *testing.T) {
env, err := GetStdTestEnvironment()
require.NoError(t, err)
useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false"))
require.NoError(t, err)
port := env.Port
var tlsConfig *tls.Config
if useSSL {
port = env.SslPort
tlsConfig = &tls.Config{}
}
baseOpts := clickhouse.Options{
Addr: []string{fmt.Sprintf("%s:%d", env.Host, port)},
Auth: clickhouse.Auth{
Database: "default",
Username: env.Username,
Password: env.Password,
},
Compression: &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
},
TLS: tlsConfig,
}
t.Run("unsupported proto versions", func(t *testing.T) {
badOpts := baseOpts
badOpts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO - 1
conn, _ := clickhouse.Open(&badOpts)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.Error(t, err)
badOpts.ClientTCPProtocolVersion = proto.DBMS_TCP_PROTOCOL_VERSION + 1
conn, _ = clickhouse.Open(&badOpts)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.Error(t, err)
})

t.Run("minimal proto version", func(t *testing.T) {
opts := baseOpts
opts.ClientTCPProtocolVersion = proto.DBMS_MIN_REVISION_WITH_CLIENT_INFO
conn, err := clickhouse.Open(&opts)
require.NoError(t, err)
require.NotNil(t, conn)
err = conn.Ping(t.Context())
require.NoError(t, err)

defer func() {
_ = conn.Exec(t.Context(), "DROP TABLE insert_example")
}()
err = conn.Exec(t.Context(), "DROP TABLE IF EXISTS insert_example")

err = conn.Exec(t.Context(), `
CREATE TABLE insert_example (
Col1 UInt64
) Engine = MergeTree() ORDER BY tuple()
`)
require.NoError(t, err)
var batch driver.Batch
batch, err = conn.PrepareBatch(t.Context(), "INSERT INTO insert_example (Col1)")
require.NoError(t, err)
require.NoError(t, batch.Append(10))
require.NoError(t, batch.Send())

rows, err := conn.Query(t.Context(), "SELECT Col1 FROM insert_example")
require.NoError(t, err)
count := 0
for rows.Next() {
count++
}
assert.Equal(t, 1, count)
})

}

func TestBlockBufferSize(t *testing.T) {
env, err := GetStdTestEnvironment()
require.NoError(t, err)
Expand Down