Skip to content

Commit

Permalink
Use ConnectionConfig fields when establishing session to initial cont…
Browse files Browse the repository at this point in the history
…act point

When host information was missing, driver used resolved IP address as TLS.ServerName
Instead it should connect to Server specified in ConnectionConfig and use NodeDomain
as SNI.
  • Loading branch information
zimnx committed Nov 3, 2022
1 parent e3e27db commit 9337308
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 4 deletions.
6 changes: 6 additions & 0 deletions host_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ func (h *HostInfo) DataCenter() string {
return dc
}

func (h *HostInfo) SetDatacenter(dc string) {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dc
}

func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack
Expand Down
12 changes: 8 additions & 4 deletions scyllacloud/hostdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewSniHostDialer(connConfig *ConnectionConfig, dialer gocql.Dialer) *SniHos
func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) {
hostID := host.HostID()
if len(hostID) == 0 {
return s.dialInitialContactPoint(ctx, host)
return s.dialInitialContactPoint(ctx)
}

dcName := host.DataCenter()
Expand Down Expand Up @@ -77,7 +77,7 @@ func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*go
})
}

func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) {
func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context) (*gocql.DialedHost, error) {
insecureSkipVerify := false
for _, dc := range s.connConfig.Datacenters {
if dc.InsecureSkipTLSVerify {
Expand All @@ -96,8 +96,12 @@ func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql
return nil, fmt.Errorf("can't get root CA from configuration: %w", err)
}

return s.connect(ctx, s.dialer, host.HostnameAndPort(), &tls.Config{
ServerName: host.Hostname(),
contextConf := s.connConfig.Contexts[s.connConfig.CurrentContext]
dcConf := s.connConfig.Datacenters[contextConf.DatacenterName]
sni := dcConf.NodeDomain

return s.connect(ctx, s.dialer, dcConf.Server, &tls.Config{
ServerName: sni,
RootCAs: ca,
InsecureSkipVerify: insecureSkipVerify,
Certificates: []tls.Certificate{*clientCertificate},
Expand Down
194 changes: 194 additions & 0 deletions scyllacloud/hostdialer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package scyllacloud

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http/httptest"
"testing"
"time"

"github.com/gocql/gocql"
)

func TestHostSNIDialer_ServerNameIdentifiers(t *testing.T) {
connectionStateCh := make(chan tls.ConnectionState, 1)

server := httptest.NewUnstartedServer(nil)
server.TLS = &tls.Config{
VerifyConnection: func(state tls.ConnectionState) error {
connectionStateCh <- state
return nil
},
}
server.StartTLS()
defer server.Close()

clientCert, clientKey, err := generateClientCert(server)
if err != nil {
t.Fatal(err)
}

certPem, err := encodeCertificates(clientCert)
if err != nil {
t.Fatal(err)
}

privateKeyPem, err := encodePrivateKey(clientKey)
if err != nil {
t.Fatal(err)
}

caCert, err := encodeCertificates(server.Certificate())
if err != nil {
t.Fatal(err)
}

dialer := &gocql.ScyllaShardAwareDialer{Dialer: net.Dialer{}}
connConfig := &ConnectionConfig{
Datacenters: map[string]*Datacenter{
"us-east-1": {
CertificateAuthorityData: caCert,
Server: server.Listener.Addr().String(),
NodeDomain: "node.domain.com",
// Test TLS server serving cert is signed for example.com domain,
// we skip verification here since we don't care about it in the test.
InsecureSkipTLSVerify: true,
},
},
AuthInfos: map[string]*AuthInfo{
"admin": {
ClientCertificateData: certPem,
ClientKeyData: privateKeyPem,
},
},
Contexts: map[string]*Context{
"default": {
DatacenterName: "us-east-1",
AuthInfoName: "admin",
},
},
CurrentContext: "default",
}

tt := []struct {
name string
hostInfo func() *gocql.HostInfo
expectedSNI string
}{
{
name: "node domain as SNI when host info is unknown",
hostInfo: func() *gocql.HostInfo {
return &gocql.HostInfo{}
},
expectedSNI: "node.domain.com",
},
{
name: "host SNI when host is known",
hostInfo: func() *gocql.HostInfo {
hi := &gocql.HostInfo{}
hi.SetHostID("host-1-uuid")
hi.SetDatacenter("us-east-1")
return hi
},
expectedSNI: "host-1-uuid.node.domain.com",
},
}

for i := range tt {
tc := tt[i]
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hostDialer := NewSniHostDialer(connConfig, dialer)

_, err = hostDialer.DialHost(ctx, tc.hostInfo())
if err != nil {
t.Fatal(err)
}

receivedState := <-connectionStateCh
if receivedState.ServerName != tc.expectedSNI {
t.Errorf("expected %s SNI, got %s", tc.expectedSNI, receivedState.ServerName)
}
})
}
}

func generateClientCert(server *httptest.Server) (*x509.Certificate, *rsa.PrivateKey, error) {
now := time.Now()

privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, fmt.Errorf("can't generate private key: %w", err)
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, err
}
template := &x509.Certificate{
Subject: pkix.Name{
CommonName: "client",
},
IsCA: false,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
NotBefore: now.Add(-1 * time.Second),
NotAfter: now.Add(time.Hour),
SignatureAlgorithm: x509.SHA512WithRSA,
BasicConstraintsValid: true,
SerialNumber: serialNumber,
}

derBytes, err := x509.CreateCertificate(rand.Reader, template, server.Certificate(), &privateKey.PublicKey, server.TLS.Certificates[0].PrivateKey)
if err != nil {
return nil, nil, fmt.Errorf("can't create certificate: %w", err)
}

certs, err := x509.ParseCertificates(derBytes)
if err != nil {
return nil, nil, fmt.Errorf("can't parse der encoded certificate: %w", err)
}
if len(certs) != 1 {
return nil, nil, fmt.Errorf("expected to parse 1 certificate from der bytes but %d were present", len(certs))
}

return certs[0], privateKey, nil
}

func encodeCertificates(certificates ...*x509.Certificate) ([]byte, error) {
buffer := bytes.Buffer{}
for _, certificate := range certificates {
err := pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE",
Bytes: certificate.Raw,
})
if err != nil {
return nil, fmt.Errorf("can't pem encode certificate: %w", err)
}
}
return buffer.Bytes(), nil
}

func encodePrivateKey(key *rsa.PrivateKey) ([]byte, error) {
buffer := bytes.Buffer{}
err := pem.Encode(&buffer, &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
if err != nil {
return nil, fmt.Errorf("can't pem encode rsa private key: %w", err)
}

return buffer.Bytes(), nil
}

0 comments on commit 9337308

Please sign in to comment.