Skip to content

Commit

Permalink
Merge pull request #124 from yashkin/local-address-options
Browse files Browse the repository at this point in the history
Bind outbound connection to address from configuration
  • Loading branch information
mreiferson committed Mar 21, 2015
2 parents 596f983 + 740994f commit f9995d9
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 21 deletions.
53 changes: 37 additions & 16 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"reflect"
"strconv"
Expand Down Expand Up @@ -39,10 +40,16 @@ type Config struct {
// used to Initialize, Validate
configHandlers []configHandler

DialTimeout time.Duration `opt:"dial_timeout" default:"1s"`

// Deadlines for network reads and writes
ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"`
WriteTimeout time.Duration `opt:"write_timeout" min:"100ms" max:"5m" default:"1s"`

// LocalAddr is the local address to use when dialing an nsqd.
// If empty, a local address is automatically chosen.
LocalAddr net.Addr `opt:"local_addr"`

// Duration between polling lookupd for new producers, and fractional jitter to add to
// the lookupd pool loop. this helps evenly distribute requests even if multiple consumers
// restart at the same time
Expand Down Expand Up @@ -456,6 +463,8 @@ func coerce(v interface{}, typ reflect.Type) (reflect.Value, error) {
v, err = coerceBool(v)
case "time.Duration":
v, err = coerceDuration(v)
case "net.Addr":
v, err = coerceAddr(v)
default:
v = nil
err = errors.New(fmt.Sprintf("invalid type %s", typ.String()))
Expand All @@ -476,14 +485,16 @@ func valueTypeCoerce(v interface{}, typ reflect.Type) reflect.Value {
tval.SetUint(val.Uint())
case "float32", "float64":
tval.SetFloat(val.Float())
default:
tval.Set(val)
}
return tval
}

func coerceString(v interface{}) (string, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return v.(string), nil
return v, nil
case int, int16, int32, int64, uint, uint16, uint32, uint64:
return fmt.Sprintf("%d", v), nil
case float32, float64:
Expand All @@ -495,27 +506,37 @@ func coerceString(v interface{}) (string, error) {
}

func coerceDuration(v interface{}) (time.Duration, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return time.ParseDuration(v.(string))
return time.ParseDuration(v)
case int, int16, int32, int64:
// treat like ms
return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil
case uint, uint16, uint32, uint64:
// treat like ms
return time.Duration(reflect.ValueOf(v).Uint()) * time.Millisecond, nil
case time.Duration:
return v.(time.Duration), nil
return v, nil
}
return 0, errors.New("invalid value type")
}

func coerceAddr(v interface{}) (net.Addr, error) {
switch v := v.(type) {
case string:
return net.ResolveTCPAddr("tcp", v)
case net.Addr:
return v, nil
}
return nil, errors.New("invalid value type")
}

func coerceBool(v interface{}) (bool, error) {
switch v.(type) {
switch v := v.(type) {
case bool:
return v.(bool), nil
return v, nil
case string:
return strconv.ParseBool(v.(string))
return strconv.ParseBool(v)
case int, int16, int32, int64:
return reflect.ValueOf(v).Int() != 0, nil
case uint, uint16, uint32, uint64:
Expand All @@ -525,25 +546,25 @@ func coerceBool(v interface{}) (bool, error) {
}

func coerceFloat64(v interface{}) (float64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseFloat(v.(string), 64)
return strconv.ParseFloat(v, 64)
case int, int16, int32, int64:
return float64(reflect.ValueOf(v).Int()), nil
case uint, uint16, uint32, uint64:
return float64(reflect.ValueOf(v).Uint()), nil
case float32:
return float64(v.(float32)), nil
return float64(v), nil
case float64:
return v.(float64), nil
return v, nil
}
return 0, errors.New("invalid value type")
}

func coerceInt64(v interface{}) (int64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseInt(v.(string), 10, 64)
return strconv.ParseInt(v, 10, 64)
case int, int16, int32, int64:
return reflect.ValueOf(v).Int(), nil
case uint, uint16, uint32, uint64:
Expand All @@ -553,9 +574,9 @@ func coerceInt64(v interface{}) (int64, error) {
}

func coerceUint64(v interface{}) (uint64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseUint(v.(string), 10, 64)
return strconv.ParseUint(v, 10, 64)
case int, int16, int32, int64:
return uint64(reflect.ValueOf(v).Int()), nil
case uint, uint16, uint32, uint64:
Expand Down
22 changes: 18 additions & 4 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package nsq

import "testing"
import (
"net"
"testing"
)

func TestConfigSet(t *testing.T) {
c := NewConfig()
Expand All @@ -11,7 +14,7 @@ func TestConfigSet(t *testing.T) {
t.Error("No error when setting `tls_v1` to an invalid value")
}
if err := c.Set("tls_v1", true); err != nil {
t.Errorf("Error setting `tls_v1` config. %v", err)
t.Errorf("Error setting `tls_v1` config. %s", err)
}

if err := c.Set("tls-insecure-skip-verify", true); err != nil {
Expand All @@ -21,11 +24,23 @@ func TestConfigSet(t *testing.T) {
t.Errorf("Error setting `tls-insecure-skip-verify` config: %v", c.TlsConfig)
}
if err := c.Set("tls-min-version", "tls1.2"); err != nil {
t.Errorf("Error setting `tls-min-version` config: %v", err)
t.Errorf("Error setting `tls-min-version` config: %s", err)
}
if err := c.Set("tls-min-version", "tls1.3"); err == nil {
t.Error("No error when setting `tls-min-version` to an invalid value")
}
if err := c.Set("local_addr", &net.TCPAddr{}); err != nil {
t.Errorf("Error setting `local_addr` config: %s", err)
}
if err := c.Set("local_addr", "1.2.3.4:27015"); err != nil {
t.Errorf("Error setting `local_addr` config: %s", err)
}
if err := c.Set("dial_timeout", "5s"); err != nil {
t.Errorf("Error setting `dial_timeout` config: %s", err)
}
if c.LocalAddr.String() != "1.2.3.4:27015" {
t.Error("Failed to assign `local_addr` config")
}
}

func TestConfigValidate(t *testing.T) {
Expand All @@ -37,5 +52,4 @@ func TestConfigValidate(t *testing.T) {
if err := c.Validate(); err == nil {
t.Error("no error set for invalid value")
}

}
7 changes: 6 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ func (c *Conn) getLogger() (logger, LogLevel, string) {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
conn, err := net.DialTimeout("tcp", c.addr, time.Second)
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
if err != nil {
return nil, err
}
Expand Down
9 changes: 9 additions & 0 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strconv"
Expand Down Expand Up @@ -140,6 +141,9 @@ func TestConsumerTLSClientCertViaSet(t *testing.T) {

func consumerTest(t *testing.T, cb func(c *Config)) {
config := NewConfig()
laddr := "127.0.0.2"
// so that the test can simulate binding consumer to specified address
config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0")
// so that the test can simulate reaching max requeues and a call to LogFailedMessage
config.DefaultRequeueDelay = 0
// so that the test wont timeout from backing off
Expand Down Expand Up @@ -187,6 +191,11 @@ func consumerTest(t *testing.T, cb func(c *Config)) {
t.Fatal("should not be able to connect to the same NSQ twice")
}

conn := q.conns()[0]
if !strings.HasPrefix(conn.conn.LocalAddr().String(), laddr) {
t.Fatal("connection should be bound to the specified address:", conn.conn.LocalAddr())
}

err = q.DisconnectFromNSQD("1.2.3.4:4150")
if err == nil {
t.Fatal("should not be able to disconnect from an unknown nsqd")
Expand Down
11 changes: 11 additions & 0 deletions producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"errors"
"io/ioutil"
"log"
"net"
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -40,6 +42,10 @@ func (h *ConsumerHandler) HandleMessage(message *Message) error {

func TestProducerConnection(t *testing.T) {
config := NewConfig()
laddr := "127.0.0.2"

config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0")

w, _ := NewProducer("127.0.0.1:4150", config)
w.SetLogger(nullLogger, LogLevelInfo)

Expand All @@ -48,6 +54,11 @@ func TestProducerConnection(t *testing.T) {
t.Fatalf("should lazily connect")
}

conn := w.conn.(*Conn)
if !strings.HasPrefix(conn.conn.LocalAddr().String(), laddr) {
t.Fatal("producer connection should be bound to specified address:", conn.conn.LocalAddr())
}

w.Stop()

err = w.Publish("write_test", []byte("fail test"))
Expand Down

0 comments on commit f9995d9

Please sign in to comment.