Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TLS support #54

Merged
merged 1 commit into from
Jun 26, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package nsq

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"reflect"
Expand All @@ -13,6 +15,11 @@ import (
"unsafe"
)

type configHandler interface {
Handles(option string) bool
Set(c *Config, option string, value interface{}) error
}

// Config is a struct of NSQ options
//
// The only valid way to create a Config is via NewConfig, using a struct literal will panic.
Expand All @@ -21,7 +28,8 @@ import (
//
// Use Set(key string, value interface{}) as an alternate way to set parameters
type Config struct {
initialized bool
initialized bool
configHandlers []configHandler

// Deadlines for network reads and writes
ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"`
Expand Down Expand Up @@ -57,6 +65,7 @@ type Config struct {
SampleRate int32 `opt:"sample_rate" min:"0" max:"99"`

// TLS Settings
// use tls-root-ca-file and tls-insecure-skip-verify to set tls config options
TlsV1 bool `opt:"tls_v1"`
TlsConfig *tls.Config `opt:"tls_config"`

Expand Down Expand Up @@ -89,6 +98,7 @@ type Config struct {
// This must be used to initialize Config structs. Values can be set directly, or through Config.Set()
func NewConfig() *Config {
c := &Config{}
c.configHandlers = append(c.configHandlers, &tlsHandler{})
c.initialized = true
if err := c.setDefaults(); err != nil {
panic(err.Error())
Expand Down Expand Up @@ -119,6 +129,12 @@ func (c *Config) Set(option string, value interface{}) error {

c.assertInitialized()

for _, h := range c.configHandlers {
if h.Handles(option) {
return h.Set(c, option, value)
}
}

val := reflect.ValueOf(c).Elem()
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
Expand Down Expand Up @@ -234,6 +250,52 @@ func (c *Config) setDefaults() error {
return nil
}

type tlsHandler struct {
}

func (t *tlsHandler) Handles(option string) bool {
switch option {
case "tls-root-ca-file", "tls-insecure-skip-verify":
return true
}
return false
}
func (t *tlsHandler) Set(c *Config, option string, value interface{}) error {
if c.TlsConfig == nil {
c.TlsConfig = &tls.Config{}
}
val := reflect.ValueOf(c.TlsConfig).Elem()

switch option {
case "tls-root-ca-file":
filename, ok := value.(string)
if !ok {
return fmt.Errorf("ERROR: %v is not a string", value)
}
tlsCertPool := x509.NewCertPool()
ca_cert_file, err := ioutil.ReadFile(filename)
if err != nil {
return fmt.Errorf("ERROR: failed to read custom Certificate Authority file %s", err)
}
if !tlsCertPool.AppendCertsFromPEM(ca_cert_file) {
return fmt.Errorf("ERROR: failed to append certificates from Certificate Authority file")
}
c.TlsConfig.ClientCAs = tlsCertPool
return nil
case "tls-insecure-skip-verify":
fieldVal := val.FieldByName("InsecureSkipVerify")
dest := unsafeValueOf(fieldVal)
coercedVal, err := coerce(value, fieldVal.Type())
if err != nil {
return fmt.Errorf("failed to coerce option %s (%v) - %s",
option, value, err)
}
dest.Set(coercedVal)
return nil
}
return fmt.Errorf("unknown option %s", option)
}

// because Config contains private structs we can't use reflect.Value
// directly, instead we need to "unsafely" address the variable
func unsafeValueOf(val reflect.Value) reflect.Value {
Expand Down
9 changes: 8 additions & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@ func TestConfigSet(t *testing.T) {
t.Error("No error when setting `tls_v1` to an invalid value")
}
if err := c.Set("tls_v1", true); err != nil {
t.Errorf("Error setting `tls_v1` config: %v", err)
t.Errorf("Error setting `tls_v1` config. %v", err)
}

if err := c.Set("tls-insecure-skip-verify", true); err != nil {
t.Errorf("Error setting `tls-insecure-skip-verify` config. %v", err)
}
if c.TlsConfig.InsecureSkipVerify != true {
t.Errorf("Error setting `tls-insecure-skip-verify` config: %v", c.TlsConfig)
}
}

Expand Down
19 changes: 15 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func NewConn(addr string, config *Config, delegate ConnDelegate) *Conn {
return &Conn{
addr: addr,

config: config,
config: config,
delegate: delegate,

maxRdyCount: 2500,
Expand Down Expand Up @@ -348,9 +348,20 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return resp, nil
}

func (c *Conn) upgradeTLS(conf *tls.Config) error {
c.tlsConn = tls.Client(c.conn, conf)
err := c.tlsConn.Handshake()
func (c *Conn) upgradeTLS(tlsConf *tls.Config) error {
// create a local copy of the config to set ServerName for this connection
var conf tls.Config
if tlsConf != nil {
conf = *tlsConf
}
host, _, err := net.SplitHostPort(c.addr)
if err != nil {
return err
}
conf.ServerName = host

c.tlsConn = tls.Client(c.conn, &conf)
err = c.tlsConn.Handshake()
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (r *Consumer) getMaxInFlight() int32 {
// will allow in-flight, and updates all existing connections as appropriate.
//
// For example, ChangeMaxInFlight(0) would pause message flow
//
//
// If already connected, it updates the reader RDY state for each connection.
func (r *Consumer) ChangeMaxInFlight(maxInFlight int) {
if r.getMaxInFlight() == int32(maxInFlight) {
Expand Down