From 77f8b9b2e1a01b212257452b5e4f1544c1dadc54 Mon Sep 17 00:00:00 2001 From: Jehiah Czebotar Date: Wed, 25 Jun 2014 11:42:51 -0400 Subject: [PATCH] per-connection TLS config and set ServerName --- config.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++- config_test.go | 9 ++++++- conn.go | 19 +++++++++++---- consumer.go | 2 +- 4 files changed, 87 insertions(+), 7 deletions(-) diff --git a/config.go b/config.go index c1f0d0e2..82c815ef 100644 --- a/config.go +++ b/config.go @@ -2,8 +2,10 @@ package nsq import ( "crypto/tls" + "crypto/x509" "errors" "fmt" + "io/ioutil" "log" "os" "reflect" @@ -13,6 +15,11 @@ import ( "unsafe" ) +type configHandler interface { + Handles(option string) bool + Set(c *Config, option string, value interface{}) error +} + // Config is a struct of NSQ options // // The only valid way to create a Config is via NewConfig, using a struct literal will panic. @@ -21,7 +28,8 @@ import ( // // Use Set(key string, value interface{}) as an alternate way to set parameters type Config struct { - initialized bool + initialized bool + configHandlers []configHandler // Deadlines for network reads and writes ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"` @@ -57,6 +65,7 @@ type Config struct { SampleRate int32 `opt:"sample_rate" min:"0" max:"99"` // TLS Settings + // use tls-root-ca-file and tls-insecure-skip-verify to set tls config options TlsV1 bool `opt:"tls_v1"` TlsConfig *tls.Config `opt:"tls_config"` @@ -89,6 +98,7 @@ type Config struct { // This must be used to initialize Config structs. Values can be set directly, or through Config.Set() func NewConfig() *Config { c := &Config{} + c.configHandlers = append(c.configHandlers, &tlsHandler{}) c.initialized = true if err := c.setDefaults(); err != nil { panic(err.Error()) @@ -119,6 +129,12 @@ func (c *Config) Set(option string, value interface{}) error { c.assertInitialized() + for _, h := range c.configHandlers { + if h.Handles(option) { + return h.Set(c, option, value) + } + } + val := reflect.ValueOf(c).Elem() typ := val.Type() for i := 0; i < typ.NumField(); i++ { @@ -234,6 +250,52 @@ func (c *Config) setDefaults() error { return nil } +type tlsHandler struct { +} + +func (t *tlsHandler) Handles(option string) bool { + switch option { + case "tls-root-ca-file", "tls-insecure-skip-verify": + return true + } + return false +} +func (t *tlsHandler) Set(c *Config, option string, value interface{}) error { + if c.TlsConfig == nil { + c.TlsConfig = &tls.Config{} + } + val := reflect.ValueOf(c.TlsConfig).Elem() + + switch option { + case "tls-root-ca-file": + filename, ok := value.(string) + if !ok { + return fmt.Errorf("ERROR: %v is not a string", value) + } + tlsCertPool := x509.NewCertPool() + ca_cert_file, err := ioutil.ReadFile(filename) + if err != nil { + return fmt.Errorf("ERROR: failed to read custom Certificate Authority file %s", err) + } + if !tlsCertPool.AppendCertsFromPEM(ca_cert_file) { + return fmt.Errorf("ERROR: failed to append certificates from Certificate Authority file") + } + c.TlsConfig.ClientCAs = tlsCertPool + return nil + case "tls-insecure-skip-verify": + fieldVal := val.FieldByName("InsecureSkipVerify") + dest := unsafeValueOf(fieldVal) + coercedVal, err := coerce(value, fieldVal.Type()) + if err != nil { + return fmt.Errorf("failed to coerce option %s (%v) - %s", + option, value, err) + } + dest.Set(coercedVal) + return nil + } + return fmt.Errorf("unknown option %s", option) +} + // because Config contains private structs we can't use reflect.Value // directly, instead we need to "unsafely" address the variable func unsafeValueOf(val reflect.Value) reflect.Value { diff --git a/config_test.go b/config_test.go index 95445dc7..78a9887b 100644 --- a/config_test.go +++ b/config_test.go @@ -11,7 +11,14 @@ func TestConfigSet(t *testing.T) { t.Error("No error when setting `tls_v1` to an invalid value") } if err := c.Set("tls_v1", true); err != nil { - t.Errorf("Error setting `tls_v1` config: %v", err) + t.Errorf("Error setting `tls_v1` config. %v", err) + } + + if err := c.Set("tls-insecure-skip-verify", true); err != nil { + t.Errorf("Error setting `tls-insecure-skip-verify` config. %v", err) + } + if c.TlsConfig.InsecureSkipVerify != true { + t.Errorf("Error setting `tls-insecure-skip-verify` config: %v", c.TlsConfig) } } diff --git a/conn.go b/conn.go index d756382f..03c09439 100644 --- a/conn.go +++ b/conn.go @@ -96,7 +96,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn { return &Conn{ addr: addr, - config: config, + config: config, delegate: delegate, maxRdyCount: 2500, @@ -348,9 +348,20 @@ func (c *Conn) identify() (*IdentifyResponse, error) { return resp, nil } -func (c *Conn) upgradeTLS(conf *tls.Config) error { - c.tlsConn = tls.Client(c.conn, conf) - err := c.tlsConn.Handshake() +func (c *Conn) upgradeTLS(tlsConf *tls.Config) error { + // create a local copy of the config to set ServerName for this connection + var conf tls.Config + if tlsConf != nil { + conf = *tlsConf + } + host, _, err := net.SplitHostPort(c.addr) + if err != nil { + return err + } + conf.ServerName = host + + c.tlsConn = tls.Client(c.conn, &conf) + err = c.tlsConn.Handshake() if err != nil { return err } diff --git a/consumer.go b/consumer.go index 47ac2dcc..a5b4ee79 100644 --- a/consumer.go +++ b/consumer.go @@ -205,7 +205,7 @@ func (r *Consumer) getMaxInFlight() int32 { // will allow in-flight, and updates all existing connections as appropriate. // // For example, ChangeMaxInFlight(0) would pause message flow -// +// // If already connected, it updates the reader RDY state for each connection. func (r *Consumer) ChangeMaxInFlight(maxInFlight int) { if r.getMaxInFlight() == int32(maxInFlight) {