Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand cluster validation and move SSL initialization to the validation step #245

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 140 additions & 4 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ package gocql

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -144,7 +149,8 @@ type ClusterConfig struct {

// SslOpts configures TLS use when HostDialer is not set.
// SslOpts is ignored if HostDialer is set.
SslOpts *SslOptions
SslOpts *SslOptions
actualSslOpts atomic.Value

// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
// Default: true, only enabled for protocol 3 and above.
Expand Down Expand Up @@ -349,6 +355,27 @@ func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
}

func (cfg *ClusterConfig) ValidateAndInitSSL() error {
if cfg.SslOpts == nil {
return nil
}
actualTLSConfig, err := setupTLSConfig(cfg.SslOpts)
if err != nil {
return fmt.Errorf("failed to initialize ssl configuration: %s", err.Error())
}

cfg.actualSslOpts.Store(actualTLSConfig)
return nil
}

func (cfg *ClusterConfig) getActualTLSConfig() *tls.Config {
val, ok := cfg.actualSslOpts.Load().(*tls.Config)
if !ok {
return nil
}
return val.Clone()
}

func (cfg *ClusterConfig) Validate() error {
if len(cfg.Hosts) == 0 {
return ErrNoHosts
Expand All @@ -363,22 +390,131 @@ func (cfg *ClusterConfig) Validate() error {
}

if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("InitialReconnectionPolicy.GetMaxRetries returns non-positive number")
return errors.New("InitialReconnectionPolicy.GetMaxRetries returns negative number")
}

if cfg.ReconnectionPolicy == nil {
return errors.New("ReconnectionPolicy is nil")
}

if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("ReconnectionPolicy.GetMaxRetries returns non-positive number")
return errors.New("ReconnectionPolicy.GetMaxRetries returns negative number")
}

return nil
if cfg.PageSize < 0 {
sylwiaszunejko marked this conversation as resolved.
Show resolved Hide resolved
return errors.New("PageSize should be positive number or zero")
}

if cfg.MaxRoutingKeyInfo < 0 {
return errors.New("MaxRoutingKeyInfo should be positive number or zero")
}

if cfg.MaxPreparedStmts < 0 {
return errors.New("MaxPreparedStmts should be positive number or zero")
}

if cfg.SocketKeepalive < 0 {
return errors.New("SocketKeepalive should be positive time.Duration or zero")
}

if cfg.MaxRequestsPerConn < 0 {
return errors.New("MaxRequestsPerConn should be positive number or zero")
}

if cfg.NumConns < 0 {
return errors.New("NumConns should be positive non-zero number or zero")
}

if cfg.Port <= 0 || cfg.Port > 65535 {
return errors.New("Port should be a valid port number: a number between 1 and 65535")
}

if cfg.WriteTimeout < 0 {
return errors.New("WriteTimeout should be positive time.Duration or zero")
}

if cfg.Timeout < 0 {
return errors.New("Timeout should be positive time.Duration or zero")
}

if cfg.ConnectTimeout < 0 {
return errors.New("ConnectTimeout should be positive time.Duration or zero")
}

if cfg.MetadataSchemaRequestTimeout < 0 {
return errors.New("MetadataSchemaRequestTimeout should be positive time.Duration or zero")
}

if cfg.WriteCoalesceWaitTime < 0 {
return errors.New("WriteCoalesceWaitTime should be positive time.Duration or zero")
}

if cfg.ReconnectInterval < 0 {
return errors.New("ReconnectInterval should be positive time.Duration or zero")
}

if cfg.MaxWaitSchemaAgreement < 0 {
return errors.New("MaxWaitSchemaAgreement should be positive time.Duration or zero")
}

if cfg.ProtoVersion < 0 {
return errors.New("ProtoVersion should be positive number or zero")
}

return cfg.ValidateAndInitSSL()
}

var (
ErrNoHosts = errors.New("no hosts provided")
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
ErrHostQueryFailed = errors.New("unable to populate Hosts")
)

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil {
tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}

if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
}

// ca cert is optional
if sslOpts.CaPath != "" {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}

pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("unable to open CA certs: %v", err)
}

if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("failed parsing or CA certs")
}
}

if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("unable to load X509 key pair: %v", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
}

return tlsConfig, nil
}
72 changes: 3 additions & 69 deletions connectionpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
package gocql

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net"
"sync"
Expand All @@ -32,55 +28,6 @@ type SetTablets interface {
SetTablets(tablets []*TabletInfo)
}

func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil {
tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}

if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
}

// ca cert is optional
if sslOpts.CaPath != "" {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}

pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err)
}

if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("connectionpool: failed parsing or CA certs")
}
}

if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
}

return tlsConfig, nil
}

type policyConnPool struct {
session *Session

Expand All @@ -93,22 +40,9 @@ type policyConnPool struct {
}

func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
var (
err error
hostDialer HostDialer
)
hostDialer := cfg.HostDialer

hostDialer = cfg.HostDialer
var tlsConfig *tls.Config
if hostDialer == nil {
// TODO(zariel): move tls config setup into session init.
if cfg.SslOpts != nil {
tlsConfig, err = setupTLSConfig(cfg.SslOpts)
if err != nil {
return nil, err
}
}

dialer := cfg.Dialer
if dialer == nil {
d := net.Dialer{
Expand All @@ -123,7 +57,7 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
hostDialer = &scyllaDialer{
dialer: dialer,
logger: cfg.logger(),
tlsConfig: tlsConfig,
tlsConfig: cfg.getActualTLSConfig(),
cfg: cfg,
}
}
Expand All @@ -141,7 +75,7 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
tlsConfig: tlsConfig,
tlsConfig: cfg.getActualTLSConfig(),
}, nil
}

Expand Down
3 changes: 1 addition & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInf
// NewSession wraps an existing Node.
func NewSession(cfg ClusterConfig) (*Session, error) {
if err := cfg.Validate(); err != nil {
return nil, err
return nil, fmt.Errorf("gocql: unable to create session: cluster config validation failed: %v", err)
}

// TODO: we should take a context in here at some point
ctx, cancel := context.WithCancel(context.TODO())

Expand Down
Loading