forked from apache/cassandra-gocql-driver
-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use ConnectionConfig fields when establishing session to initial cont…
…act point When host information was missing, driver used resolved IP address as TLS.ServerName Instead it should connect to Server specified in ConnectionConfig and use NodeDomain as SNI.
- Loading branch information
Showing
3 changed files
with
208 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
package scyllacloud | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/pem" | ||
"fmt" | ||
"math/big" | ||
"net" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/gocql/gocql" | ||
) | ||
|
||
func TestHostSNIDialer_ServerNameIdentifiers(t *testing.T) { | ||
connectionStateCh := make(chan tls.ConnectionState, 1) | ||
|
||
server := httptest.NewUnstartedServer(nil) | ||
server.TLS = &tls.Config{ | ||
VerifyConnection: func(state tls.ConnectionState) error { | ||
connectionStateCh <- state | ||
return nil | ||
}, | ||
} | ||
server.StartTLS() | ||
defer server.Close() | ||
|
||
clientCert, clientKey, err := generateClientCert(server) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
certPem, err := encodeCertificates(clientCert) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
privateKeyPem, err := encodePrivateKey(clientKey) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
caCert, err := encodeCertificates(server.Certificate()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
dialer := &gocql.ScyllaShardAwareDialer{Dialer: net.Dialer{}} | ||
connConfig := &ConnectionConfig{ | ||
Datacenters: map[string]*Datacenter{ | ||
"us-east-1": { | ||
CertificateAuthorityData: caCert, | ||
Server: server.Listener.Addr().String(), | ||
NodeDomain: "node.domain.com", | ||
// Test TLS server serving cert is signed for example.com domain, | ||
// we skip verification here since we don't care about it in the test. | ||
InsecureSkipTLSVerify: true, | ||
}, | ||
}, | ||
AuthInfos: map[string]*AuthInfo{ | ||
"admin": { | ||
ClientCertificateData: certPem, | ||
ClientKeyData: privateKeyPem, | ||
}, | ||
}, | ||
Contexts: map[string]*Context{ | ||
"default": { | ||
DatacenterName: "us-east-1", | ||
AuthInfoName: "admin", | ||
}, | ||
}, | ||
CurrentContext: "default", | ||
} | ||
|
||
tt := []struct { | ||
name string | ||
hostInfo func() *gocql.HostInfo | ||
expectedSNI string | ||
}{ | ||
{ | ||
name: "node domain as SNI when host info is unknown", | ||
hostInfo: func() *gocql.HostInfo { | ||
return &gocql.HostInfo{} | ||
}, | ||
expectedSNI: "node.domain.com", | ||
}, | ||
{ | ||
name: "host SNI when host is known", | ||
hostInfo: func() *gocql.HostInfo { | ||
hi := &gocql.HostInfo{} | ||
hi.SetHostID("host-1-uuid") | ||
hi.SetDatacenter("us-east-1") | ||
return hi | ||
}, | ||
expectedSNI: "host-1-uuid.node.domain.com", | ||
}, | ||
} | ||
|
||
for i := range tt { | ||
tc := tt[i] | ||
t.Run(tc.name, func(t *testing.T) { | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
hostDialer := NewSniHostDialer(connConfig, dialer) | ||
|
||
_, err = hostDialer.DialHost(ctx, tc.hostInfo()) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
receivedState := <-connectionStateCh | ||
if receivedState.ServerName != tc.expectedSNI { | ||
t.Errorf("expected %s SNI, got %s", tc.expectedSNI, receivedState.ServerName) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func generateClientCert(server *httptest.Server) (*x509.Certificate, *rsa.PrivateKey, error) { | ||
now := time.Now() | ||
|
||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("can't generate private key: %w", err) | ||
} | ||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) | ||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
template := &x509.Certificate{ | ||
Subject: pkix.Name{ | ||
CommonName: "client", | ||
}, | ||
IsCA: false, | ||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, | ||
NotBefore: now.Add(-1 * time.Second), | ||
NotAfter: now.Add(time.Hour), | ||
SignatureAlgorithm: x509.SHA512WithRSA, | ||
BasicConstraintsValid: true, | ||
SerialNumber: serialNumber, | ||
} | ||
|
||
derBytes, err := x509.CreateCertificate(rand.Reader, template, server.Certificate(), &privateKey.PublicKey, server.TLS.Certificates[0].PrivateKey) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("can't create certificate: %w", err) | ||
} | ||
|
||
certs, err := x509.ParseCertificates(derBytes) | ||
if err != nil { | ||
return nil, nil, fmt.Errorf("can't parse der encoded certificate: %w", err) | ||
} | ||
if len(certs) != 1 { | ||
return nil, nil, fmt.Errorf("expected to parse 1 certificate from der bytes but %d were present", len(certs)) | ||
} | ||
|
||
return certs[0], privateKey, nil | ||
} | ||
|
||
func encodeCertificates(certificates ...*x509.Certificate) ([]byte, error) { | ||
buffer := bytes.Buffer{} | ||
for _, certificate := range certificates { | ||
err := pem.Encode(&buffer, &pem.Block{ | ||
Type: "CERTIFICATE", | ||
Bytes: certificate.Raw, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("can't pem encode certificate: %w", err) | ||
} | ||
} | ||
return buffer.Bytes(), nil | ||
} | ||
|
||
func encodePrivateKey(key *rsa.PrivateKey) ([]byte, error) { | ||
buffer := bytes.Buffer{} | ||
err := pem.Encode(&buffer, &pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(key), | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("can't pem encode rsa private key: %w", err) | ||
} | ||
|
||
return buffer.Bytes(), nil | ||
} |