From 281917781768a26cfc31446e09a8d1665656cc42 Mon Sep 17 00:00:00 2001 From: Maciej Zimnoch Date: Wed, 2 Nov 2022 17:39:47 +0100 Subject: [PATCH] UPSTREAM: : Use ConnectionConfig fields when establishing session to initial contact point When host information was missing, driver used resolved IP address as TLS.ServerName Instead it should connect to Server specified in ConnectionConfig and use NodeDomain as SNI. --- host_source_scylla.go | 7 + scyllacloud/config.go | 15 +- scyllacloud/hostdialer.go | 22 ++- scyllacloud/hostdialer_test.go | 350 +++++++++++++++++++++++++++++++++ 4 files changed, 388 insertions(+), 6 deletions(-) create mode 100644 host_source_scylla.go create mode 100644 scyllacloud/hostdialer_test.go diff --git a/host_source_scylla.go b/host_source_scylla.go new file mode 100644 index 000000000..be49315eb --- /dev/null +++ b/host_source_scylla.go @@ -0,0 +1,7 @@ +package gocql + +func (h *HostInfo) SetDatacenter(dc string) { + h.mu.Lock() + defer h.mu.Unlock() + h.dataCenter = dc +} diff --git a/scyllacloud/config.go b/scyllacloud/config.go index 52d4c56f1..722ed22fe 100644 --- a/scyllacloud/config.go +++ b/scyllacloud/config.go @@ -171,8 +171,19 @@ func (cc *ConnectionConfig) getDataOrReadFile(data []byte, path string) ([]byte, } func (cc *ConnectionConfig) GetClientCertificate() (*tls.Certificate, error) { - confContext := cc.Contexts[cc.CurrentContext] - authInfo := cc.AuthInfos[confContext.AuthInfoName] + if len(cc.CurrentContext) == 0 { + return nil, fmt.Errorf("current context is empty") + } + + confContext, ok := cc.Contexts[cc.CurrentContext] + if !ok { + return nil, fmt.Errorf("context %q does not exists", cc.CurrentContext) + } + + authInfo, ok := cc.AuthInfos[confContext.AuthInfoName] + if !ok { + return nil, fmt.Errorf("autoInfo %q does not exists", confContext.AuthInfoName) + } clientCert, err := cc.getDataOrReadFile(authInfo.ClientCertificateData, authInfo.ClientCertificatePath) if err != nil { diff --git a/scyllacloud/hostdialer.go b/scyllacloud/hostdialer.go index bba54a55e..4e2acc268 100644 --- a/scyllacloud/hostdialer.go +++ b/scyllacloud/hostdialer.go @@ -31,7 +31,7 @@ func NewSniHostDialer(connConfig *ConnectionConfig, dialer gocql.Dialer) *SniHos func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) { hostID := host.HostID() if len(hostID) == 0 { - return s.dialInitialContactPoint(ctx, host) + return s.dialInitialContactPoint(ctx) } dcName := host.DataCenter() @@ -77,7 +77,7 @@ func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*go }) } -func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) { +func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context) (*gocql.DialedHost, error) { insecureSkipVerify := false for _, dc := range s.connConfig.Datacenters { if dc.InsecureSkipTLSVerify { @@ -96,8 +96,22 @@ func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql return nil, fmt.Errorf("can't get root CA from configuration: %w", err) } - return s.connect(ctx, s.dialer, host.HostnameAndPort(), &tls.Config{ - ServerName: host.Hostname(), + if len(s.connConfig.CurrentContext) == 0 { + return nil, fmt.Errorf("current context is empty") + } + + contextConf, ok := s.connConfig.Contexts[s.connConfig.CurrentContext] + if !ok { + return nil, fmt.Errorf("context %q does not exists", s.connConfig.CurrentContext) + } + + dcConf, ok := s.connConfig.Datacenters[contextConf.DatacenterName] + if !ok { + return nil, fmt.Errorf("datacenter %q does not exists", contextConf.DatacenterName) + } + + return s.connect(ctx, s.dialer, dcConf.Server, &tls.Config{ + ServerName: dcConf.NodeDomain, RootCAs: ca, InsecureSkipVerify: insecureSkipVerify, Certificates: []tls.Certificate{*clientCertificate}, diff --git a/scyllacloud/hostdialer_test.go b/scyllacloud/hostdialer_test.go new file mode 100644 index 000000000..af5de8aef --- /dev/null +++ b/scyllacloud/hostdialer_test.go @@ -0,0 +1,350 @@ +package scyllacloud + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http/httptest" + "reflect" + "testing" + "time" + + "github.com/gocql/gocql" +) + +const ( + testTimeout = time.Second +) + +func TestHostSNIDialer_InvalidConnectionConfig(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, serverCertPem, clientCertPem, clientKeyPem, err := setupTLSServer(nil) + if err != nil { + t.Fatal(err) + } + + dialer := &gocql.ScyllaShardAwareDialer{Dialer: net.Dialer{}} + + tt := []struct { + name string + connConfig *ConnectionConfig + hostInfo *gocql.HostInfo + expectedError error + }{ + { + name: "empty current context", + connConfig: func() *ConnectionConfig { + cc := newBasicConnectionConf("127.0.0.1:9142", serverCertPem, clientCertPem, clientKeyPem) + cc.CurrentContext = "" + return cc + }(), + hostInfo: &gocql.HostInfo{}, + expectedError: fmt.Errorf("can't get client certificate from configuration: %w", fmt.Errorf("current context is empty")), + }, + { + name: "current context is unknown", + connConfig: func() *ConnectionConfig { + cc := newBasicConnectionConf("127.0.0.1:9142", serverCertPem, clientCertPem, clientKeyPem) + cc.CurrentContext = "unknown-context" + return cc + }(), + hostInfo: &gocql.HostInfo{}, + expectedError: fmt.Errorf("can't get client certificate from configuration: %w", fmt.Errorf(`context "unknown-context" does not exists`)), + }, + { + name: "unknown default datacenter", + connConfig: func() *ConnectionConfig { + cc := newBasicConnectionConf("127.0.0.1:9142", serverCertPem, clientCertPem, clientKeyPem) + cc.Contexts[cc.CurrentContext].DatacenterName = "unknown-datacenter" + return cc + }(), + hostInfo: &gocql.HostInfo{}, + expectedError: fmt.Errorf(`datacenter "unknown-datacenter" does not exists`), + }, + { + name: "unknown host datacenter", + connConfig: newBasicConnectionConf("127.0.0.1:9142", serverCertPem, clientCertPem, clientKeyPem), + hostInfo: func() *gocql.HostInfo { + hi := &gocql.HostInfo{} + hi.SetDatacenter("unknown-datacenter") + hi.SetHostID("host-id") + return hi + }(), + expectedError: fmt.Errorf(`datacenter "unknown-datacenter" configuration not found in connection bundle`), + }, + } + for i := range tt { + tc := tt[i] + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + hostDialer := NewSniHostDialer(tc.connConfig, dialer) + _, err := hostDialer.DialHost(ctx, tc.hostInfo) + if !reflect.DeepEqual(err, tc.expectedError) { + t.Errorf("expected error to be %q, got %q", tc.expectedError, err) + } + }) + } +} + +func TestHostSNIDialer_ServerNameIdentifiers(t *testing.T) { + t.Parallel() + + tt := []struct { + name string + hostInfo func() *gocql.HostInfo + expectedSNI string + }{ + { + name: "node domain as SNI when host info is unknown", + hostInfo: func() *gocql.HostInfo { + return &gocql.HostInfo{} + }, + expectedSNI: "node.scylladb.com", + }, + { + name: "host SNI when host is known", + hostInfo: func() *gocql.HostInfo { + hi := &gocql.HostInfo{} + hi.SetHostID("host-1-uuid") + hi.SetDatacenter("us-east-1") + return hi + }, + expectedSNI: "host-1-uuid.node.scylladb.com", + }, + } + + for i := range tt { + tc := tt[i] + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + server, serverCertPem, clientCertPem, clientKeyPem, err := setupTLSServer([]string{"host-1-uuid.node.scylladb.com", "node.scylladb.com"}) + if err != nil { + t.Fatal(err) + } + + connectionStateCh := make(chan tls.ConnectionState, 1) + server.TLS.VerifyConnection = func(state tls.ConnectionState) error { + connectionStateCh <- state + return nil + } + + server.StartTLS() + defer server.Close() + + dialer := &gocql.ScyllaShardAwareDialer{Dialer: net.Dialer{}} + connConfig := newBasicConnectionConf(server.Listener.Addr().String(), serverCertPem, clientCertPem, clientKeyPem) + hostDialer := NewSniHostDialer(connConfig, dialer) + + _, err = hostDialer.DialHost(ctx, tc.hostInfo()) + if err != nil { + t.Fatal(err) + } + + select { + case receivedState := <-connectionStateCh: + if receivedState.ServerName != tc.expectedSNI { + t.Errorf("expected %s SNI, got %s", tc.expectedSNI, receivedState.ServerName) + } + case <-ctx.Done(): + t.Fatal("expected to receive connection, but timed out") + } + }) + } +} + +func setupTLSServer(dnsDomains []string) (*httptest.Server, []byte, []byte, []byte, error) { + clientCert, clientKey, err := generateClientCert() + if err != nil { + return nil, nil, nil, nil, err + } + + clientCertPem, err := encodeCertificates(clientCert) + if err != nil { + return nil, nil, nil, nil, err + } + + clientKeyPem, err := encodePrivateKey(clientKey) + if err != nil { + return nil, nil, nil, nil, err + } + + clientCAPool := x509.NewCertPool() + clientCAPool.AppendCertsFromPEM(clientCertPem) + + serverCert, serverKey, err := generateServingCert(dnsDomains) + if err != nil { + return nil, nil, nil, nil, err + } + + serverCertPem, err := encodeCertificates(serverCert) + if err != nil { + return nil, nil, nil, nil, err + } + + serverKeyPem, err := encodePrivateKey(serverKey) + if err != nil { + return nil, nil, nil, nil, err + } + + servingCert, err := tls.X509KeyPair(serverCertPem, serverKeyPem) + if err != nil { + return nil, nil, nil, nil, err + } + + server := httptest.NewUnstartedServer(nil) + + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{servingCert}, + ClientCAs: clientCAPool, + ClientAuth: tls.RequestClientCert, + } + + return server, serverCertPem, clientCertPem, clientKeyPem, nil +} + +func newBasicConnectionConf(server string, serverCertPem, clientCertPem, clientKeyPem []byte) *ConnectionConfig { + return &ConnectionConfig{ + Datacenters: map[string]*Datacenter{ + "us-east-1": { + CertificateAuthorityData: serverCertPem, + Server: server, + NodeDomain: "node.scylladb.com", + }, + }, + AuthInfos: map[string]*AuthInfo{ + "admin": { + ClientCertificateData: clientCertPem, + ClientKeyData: clientKeyPem, + }, + }, + Contexts: map[string]*Context{ + "default": { + DatacenterName: "us-east-1", + AuthInfoName: "admin", + }, + }, + CurrentContext: "default", + } +} + +func generateServingCert(dnsNames []string) (*x509.Certificate, *rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 1028) + if err != nil { + return nil, nil, fmt.Errorf("can't generate private key: %w", err) + } + + commonName := "serving-cert" + cert, err := generateSelfSignedX509Certificate(commonName, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, dnsNames, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, nil, err + } + + return cert, privateKey, nil +} + +func generateClientCert() (*x509.Certificate, *rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 1028) + if err != nil { + return nil, nil, fmt.Errorf("can't generate private key: %w", err) + } + + commonName := "client" + cert, err := generateSelfSignedX509Certificate(commonName, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, nil, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, nil, err + } + + return cert, privateKey, nil +} + +func generateSerialNumber() (*big.Int, error) { + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, err + } + + return serialNumber, nil +} + +func generateSelfSignedX509Certificate(cn string, extKeyUsage []x509.ExtKeyUsage, dnsNames []string, pub, priv interface{}) (*x509.Certificate, error) { + now := time.Now() + + serialNumber, err := generateSerialNumber() + if err != nil { + return nil, err + } + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: cn, + }, + IsCA: false, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: extKeyUsage, + NotBefore: now.Add(-1 * time.Second), + NotAfter: now.Add(time.Hour), + SignatureAlgorithm: x509.SHA512WithRSA, + BasicConstraintsValid: true, + SerialNumber: serialNumber, + DNSNames: dnsNames, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + if err != nil { + return nil, fmt.Errorf("can't create certificate: %w", err) + } + + certs, err := x509.ParseCertificates(derBytes) + if err != nil { + return nil, fmt.Errorf("can't parse der encoded certificate: %w", err) + } + if len(certs) != 1 { + return nil, fmt.Errorf("expected to parse 1 certificate from der bytes but %d were present", len(certs)) + } + + return certs[0], nil +} + +func encodeCertificates(certificates ...*x509.Certificate) ([]byte, error) { + buffer := bytes.Buffer{} + for _, certificate := range certificates { + err := pem.Encode(&buffer, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certificate.Raw, + }) + if err != nil { + return nil, fmt.Errorf("can't pem encode certificate: %w", err) + } + } + return buffer.Bytes(), nil +} + +func encodePrivateKey(key *rsa.PrivateKey) ([]byte, error) { + buffer := bytes.Buffer{} + err := pem.Encode(&buffer, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + if err != nil { + return nil, fmt.Errorf("can't pem encode rsa private key: %w", err) + } + + return buffer.Bytes(), nil +}