From 0e9c637d89162347f9f3757a45ec5076468146e3 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 27 Aug 2024 16:12:09 -0400 Subject: [PATCH] Expand cluster validation and move SSL initialization to the validation step --- cluster.go | 144 ++++++++++++++++++++++++++++++++++++++++++++-- connectionpool.go | 72 +---------------------- session.go | 3 +- 3 files changed, 144 insertions(+), 75 deletions(-) diff --git a/cluster.go b/cluster.go index 1b8470466..269052e1f 100644 --- a/cluster.go +++ b/cluster.go @@ -6,8 +6,13 @@ package gocql import ( "context" + "crypto/tls" + "crypto/x509" "errors" + "fmt" + "io/ioutil" "net" + "sync/atomic" "time" ) @@ -144,7 +149,8 @@ type ClusterConfig struct { // SslOpts configures TLS use when HostDialer is not set. // SslOpts is ignored if HostDialer is set. - SslOpts *SslOptions + SslOpts *SslOptions + actualSslOpts atomic.Value // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. // Default: true, only enabled for protocol 3 and above. @@ -349,6 +355,27 @@ func (cfg *ClusterConfig) filterHost(host *HostInfo) bool { return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host)) } +func (cfg *ClusterConfig) ValidateAndInitSSL() error { + if cfg.SslOpts == nil { + return nil + } + actualTLSConfig, err := setupTLSConfig(cfg.SslOpts) + if err != nil { + return fmt.Errorf("failed to initialize ssl configuration: %s", err.Error()) + } + + cfg.actualSslOpts.Store(actualTLSConfig) + return nil +} + +func (cfg *ClusterConfig) getActualTLSConfig() *tls.Config { + val, ok := cfg.actualSslOpts.Load().(*tls.Config) + if !ok { + return nil + } + return val.Clone() +} + func (cfg *ClusterConfig) Validate() error { if len(cfg.Hosts) == 0 { return ErrNoHosts @@ -363,7 +390,7 @@ func (cfg *ClusterConfig) Validate() error { } if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 { - return errors.New("InitialReconnectionPolicy.GetMaxRetries returns non-positive number") + return errors.New("InitialReconnectionPolicy.GetMaxRetries returns negative number") } if cfg.ReconnectionPolicy == nil { @@ -371,10 +398,70 @@ func (cfg *ClusterConfig) Validate() error { } if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 { - return errors.New("ReconnectionPolicy.GetMaxRetries returns non-positive number") + return errors.New("ReconnectionPolicy.GetMaxRetries returns negative number") } - return nil + if cfg.PageSize < 0 { + return errors.New("PageSize should be positive number or zero") + } + + if cfg.MaxRoutingKeyInfo < 0 { + return errors.New("MaxRoutingKeyInfo should be positive number or zero") + } + + if cfg.MaxPreparedStmts < 0 { + return errors.New("MaxPreparedStmts should be positive number or zero") + } + + if cfg.SocketKeepalive < 0 { + return errors.New("SocketKeepalive should be positive time.Duration or zero") + } + + if cfg.MaxRequestsPerConn < 0 { + return errors.New("MaxRequestsPerConn should be positive number or zero") + } + + if cfg.NumConns < 0 { + return errors.New("NumConns should be positive non-zero number or zero") + } + + if cfg.Port <= 0 || cfg.Port > 65535 { + return errors.New("Port should be a valid port number: a number between 1 and 65535") + } + + if cfg.WriteTimeout < 0 { + return errors.New("WriteTimeout should be positive time.Duration or zero") + } + + if cfg.Timeout < 0 { + return errors.New("Timeout should be positive time.Duration or zero") + } + + if cfg.ConnectTimeout < 0 { + return errors.New("ConnectTimeout should be positive time.Duration or zero") + } + + if cfg.MetadataSchemaRequestTimeout < 0 { + return errors.New("MetadataSchemaRequestTimeout should be positive time.Duration or zero") + } + + if cfg.WriteCoalesceWaitTime < 0 { + return errors.New("WriteCoalesceWaitTime should be positive time.Duration or zero") + } + + if cfg.ReconnectInterval < 0 { + return errors.New("ReconnectInterval should be positive time.Duration or zero") + } + + if cfg.MaxWaitSchemaAgreement < 0 { + return errors.New("MaxWaitSchemaAgreement should be positive time.Duration or zero") + } + + if cfg.ProtoVersion < 0 { + return errors.New("ProtoVersion should be positive number or zero") + } + + return cfg.ValidateAndInitSSL() } var ( @@ -382,3 +469,52 @@ var ( ErrNoConnectionsStarted = errors.New("no connections were made when creating the session") ErrHostQueryFailed = errors.New("unable to populate Hosts") ) + +func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) { + // Config.InsecureSkipVerify | EnableHostVerification | Result + // Config is nil | true | verify host + // Config is nil | false | do not verify host + // false | false | verify host + // true | false | do not verify host + // false | true | verify host + // true | true | verify host + var tlsConfig *tls.Config + if sslOpts.Config == nil { + tlsConfig = &tls.Config{ + InsecureSkipVerify: !sslOpts.EnableHostVerification, + } + } else { + // use clone to avoid race. + tlsConfig = sslOpts.Config.Clone() + } + + if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification { + tlsConfig.InsecureSkipVerify = false + } + + // ca cert is optional + if sslOpts.CaPath != "" { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + + pem, err := ioutil.ReadFile(sslOpts.CaPath) + if err != nil { + return nil, fmt.Errorf("unable to open CA certs: %v", err) + } + + if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) { + return nil, errors.New("failed parsing or CA certs") + } + } + + if sslOpts.CertPath != "" || sslOpts.KeyPath != "" { + mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath) + if err != nil { + return nil, fmt.Errorf("unable to load X509 key pair: %v", err) + } + tlsConfig.Certificates = append(tlsConfig.Certificates, mycert) + } + + return tlsConfig, nil +} diff --git a/connectionpool.go b/connectionpool.go index 6b9111f36..2cdec6ea4 100644 --- a/connectionpool.go +++ b/connectionpool.go @@ -5,11 +5,7 @@ package gocql import ( - "crypto/tls" - "crypto/x509" - "errors" "fmt" - "io/ioutil" "math/rand" "net" "sync" @@ -32,55 +28,6 @@ type SetTablets interface { SetTablets(tablets []*TabletInfo) } -func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) { - // Config.InsecureSkipVerify | EnableHostVerification | Result - // Config is nil | true | verify host - // Config is nil | false | do not verify host - // false | false | verify host - // true | false | do not verify host - // false | true | verify host - // true | true | verify host - var tlsConfig *tls.Config - if sslOpts.Config == nil { - tlsConfig = &tls.Config{ - InsecureSkipVerify: !sslOpts.EnableHostVerification, - } - } else { - // use clone to avoid race. - tlsConfig = sslOpts.Config.Clone() - } - - if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification { - tlsConfig.InsecureSkipVerify = false - } - - // ca cert is optional - if sslOpts.CaPath != "" { - if tlsConfig.RootCAs == nil { - tlsConfig.RootCAs = x509.NewCertPool() - } - - pem, err := ioutil.ReadFile(sslOpts.CaPath) - if err != nil { - return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err) - } - - if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) { - return nil, errors.New("connectionpool: failed parsing or CA certs") - } - } - - if sslOpts.CertPath != "" || sslOpts.KeyPath != "" { - mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath) - if err != nil { - return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err) - } - tlsConfig.Certificates = append(tlsConfig.Certificates, mycert) - } - - return tlsConfig, nil -} - type policyConnPool struct { session *Session @@ -93,22 +40,9 @@ type policyConnPool struct { } func connConfig(cfg *ClusterConfig) (*ConnConfig, error) { - var ( - err error - hostDialer HostDialer - ) + hostDialer := cfg.HostDialer - hostDialer = cfg.HostDialer - var tlsConfig *tls.Config if hostDialer == nil { - // TODO(zariel): move tls config setup into session init. - if cfg.SslOpts != nil { - tlsConfig, err = setupTLSConfig(cfg.SslOpts) - if err != nil { - return nil, err - } - } - dialer := cfg.Dialer if dialer == nil { d := net.Dialer{ @@ -123,7 +57,7 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) { hostDialer = &scyllaDialer{ dialer: dialer, logger: cfg.logger(), - tlsConfig: tlsConfig, + tlsConfig: cfg.getActualTLSConfig(), cfg: cfg, } } @@ -141,7 +75,7 @@ func connConfig(cfg *ClusterConfig) (*ConnConfig, error) { AuthProvider: cfg.AuthProvider, Keepalive: cfg.SocketKeepalive, Logger: cfg.logger(), - tlsConfig: tlsConfig, + tlsConfig: cfg.getActualTLSConfig(), }, nil } diff --git a/session.go b/session.go index 40fe71960..3754c9115 100644 --- a/session.go +++ b/session.go @@ -120,9 +120,8 @@ func addrsToHosts(addrs []string, defaultPort int, logger StdLogger) ([]*HostInf // NewSession wraps an existing Node. func NewSession(cfg ClusterConfig) (*Session, error) { if err := cfg.Validate(); err != nil { - return nil, err + return nil, fmt.Errorf("gocql: unable to create session: cluster config validation failed: %v", err) } - // TODO: we should take a context in here at some point ctx, cancel := context.WithCancel(context.TODO())