Skip to content

Commit

Permalink
UPSTREAM: <carry>: Use ConnectionConfig fields when establishing sess…
Browse files Browse the repository at this point in the history
…ion to initial contact point

When host information was missing, driver used resolved IP address as TLS.ServerName
Instead it should connect to Server specified in ConnectionConfig and use NodeDomain
as SNI.
  • Loading branch information
zimnx committed Nov 7, 2022
1 parent e3e27db commit 2819177
Show file tree
Hide file tree
Showing 4 changed files with 388 additions and 6 deletions.
7 changes: 7 additions & 0 deletions host_source_scylla.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package gocql

func (h *HostInfo) SetDatacenter(dc string) {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dc
}
15 changes: 13 additions & 2 deletions scyllacloud/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,19 @@ func (cc *ConnectionConfig) getDataOrReadFile(data []byte, path string) ([]byte,
}

func (cc *ConnectionConfig) GetClientCertificate() (*tls.Certificate, error) {
confContext := cc.Contexts[cc.CurrentContext]
authInfo := cc.AuthInfos[confContext.AuthInfoName]
if len(cc.CurrentContext) == 0 {
return nil, fmt.Errorf("current context is empty")
}

confContext, ok := cc.Contexts[cc.CurrentContext]
if !ok {
return nil, fmt.Errorf("context %q does not exists", cc.CurrentContext)
}

authInfo, ok := cc.AuthInfos[confContext.AuthInfoName]
if !ok {
return nil, fmt.Errorf("autoInfo %q does not exists", confContext.AuthInfoName)
}

clientCert, err := cc.getDataOrReadFile(authInfo.ClientCertificateData, authInfo.ClientCertificatePath)
if err != nil {
Expand Down
22 changes: 18 additions & 4 deletions scyllacloud/hostdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func NewSniHostDialer(connConfig *ConnectionConfig, dialer gocql.Dialer) *SniHos
func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) {
hostID := host.HostID()
if len(hostID) == 0 {
return s.dialInitialContactPoint(ctx, host)
return s.dialInitialContactPoint(ctx)
}

dcName := host.DataCenter()
Expand Down Expand Up @@ -77,7 +77,7 @@ func (s *SniHostDialer) DialHost(ctx context.Context, host *gocql.HostInfo) (*go
})
}

func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql.HostInfo) (*gocql.DialedHost, error) {
func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context) (*gocql.DialedHost, error) {
insecureSkipVerify := false
for _, dc := range s.connConfig.Datacenters {
if dc.InsecureSkipTLSVerify {
Expand All @@ -96,8 +96,22 @@ func (s *SniHostDialer) dialInitialContactPoint(ctx context.Context, host *gocql
return nil, fmt.Errorf("can't get root CA from configuration: %w", err)
}

return s.connect(ctx, s.dialer, host.HostnameAndPort(), &tls.Config{
ServerName: host.Hostname(),
if len(s.connConfig.CurrentContext) == 0 {
return nil, fmt.Errorf("current context is empty")
}

contextConf, ok := s.connConfig.Contexts[s.connConfig.CurrentContext]
if !ok {
return nil, fmt.Errorf("context %q does not exists", s.connConfig.CurrentContext)
}

dcConf, ok := s.connConfig.Datacenters[contextConf.DatacenterName]
if !ok {
return nil, fmt.Errorf("datacenter %q does not exists", contextConf.DatacenterName)
}

return s.connect(ctx, s.dialer, dcConf.Server, &tls.Config{
ServerName: dcConf.NodeDomain,
RootCAs: ca,
InsecureSkipVerify: insecureSkipVerify,
Certificates: []tls.Certificate{*clientCertificate},
Expand Down
Loading

0 comments on commit 2819177

Please sign in to comment.