Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bind outbound connection to address from configuration #124

Merged
merged 1 commit into from
Mar 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"reflect"
"strconv"
Expand Down Expand Up @@ -39,10 +40,16 @@ type Config struct {
// used to Initialize, Validate
configHandlers []configHandler

DialTimeout time.Duration `opt:"dial_timeout" default:"1s"`

// Deadlines for network reads and writes
ReadTimeout time.Duration `opt:"read_timeout" min:"100ms" max:"5m" default:"60s"`
WriteTimeout time.Duration `opt:"write_timeout" min:"100ms" max:"5m" default:"1s"`

// LocalAddr is the local address to use when dialing an nsqd.
// If empty, a local address is automatically chosen.
LocalAddr net.Addr `opt:"local_addr"`

// Duration between polling lookupd for new producers, and fractional jitter to add to
// the lookupd pool loop. this helps evenly distribute requests even if multiple consumers
// restart at the same time
Expand Down Expand Up @@ -456,6 +463,8 @@ func coerce(v interface{}, typ reflect.Type) (reflect.Value, error) {
v, err = coerceBool(v)
case "time.Duration":
v, err = coerceDuration(v)
case "net.Addr":
v, err = coerceAddr(v)
default:
v = nil
err = errors.New(fmt.Sprintf("invalid type %s", typ.String()))
Expand All @@ -476,14 +485,16 @@ func valueTypeCoerce(v interface{}, typ reflect.Type) reflect.Value {
tval.SetUint(val.Uint())
case "float32", "float64":
tval.SetFloat(val.Float())
default:
tval.Set(val)
}
return tval
}

func coerceString(v interface{}) (string, error) {
switch v.(type) {
switch v := v.(type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not shadow v here (and all the other instances below)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's golang idiom - reuse name of variables in a type switch statements
https://golang.org/doc/effective_go.html#type_switch

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL - thanks!

case string:
return v.(string), nil
return v, nil
case int, int16, int32, int64, uint, uint16, uint32, uint64:
return fmt.Sprintf("%d", v), nil
case float32, float64:
Expand All @@ -495,27 +506,37 @@ func coerceString(v interface{}) (string, error) {
}

func coerceDuration(v interface{}) (time.Duration, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return time.ParseDuration(v.(string))
return time.ParseDuration(v)
case int, int16, int32, int64:
// treat like ms
return time.Duration(reflect.ValueOf(v).Int()) * time.Millisecond, nil
case uint, uint16, uint32, uint64:
// treat like ms
return time.Duration(reflect.ValueOf(v).Uint()) * time.Millisecond, nil
case time.Duration:
return v.(time.Duration), nil
return v, nil
}
return 0, errors.New("invalid value type")
}

func coerceAddr(v interface{}) (net.Addr, error) {
switch v := v.(type) {
case string:
return net.ResolveTCPAddr("tcp", v)
case net.Addr:
return v, nil
}
return nil, errors.New("invalid value type")
}

func coerceBool(v interface{}) (bool, error) {
switch v.(type) {
switch v := v.(type) {
case bool:
return v.(bool), nil
return v, nil
case string:
return strconv.ParseBool(v.(string))
return strconv.ParseBool(v)
case int, int16, int32, int64:
return reflect.ValueOf(v).Int() != 0, nil
case uint, uint16, uint32, uint64:
Expand All @@ -525,25 +546,25 @@ func coerceBool(v interface{}) (bool, error) {
}

func coerceFloat64(v interface{}) (float64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseFloat(v.(string), 64)
return strconv.ParseFloat(v, 64)
case int, int16, int32, int64:
return float64(reflect.ValueOf(v).Int()), nil
case uint, uint16, uint32, uint64:
return float64(reflect.ValueOf(v).Uint()), nil
case float32:
return float64(v.(float32)), nil
return float64(v), nil
case float64:
return v.(float64), nil
return v, nil
}
return 0, errors.New("invalid value type")
}

func coerceInt64(v interface{}) (int64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseInt(v.(string), 10, 64)
return strconv.ParseInt(v, 10, 64)
case int, int16, int32, int64:
return reflect.ValueOf(v).Int(), nil
case uint, uint16, uint32, uint64:
Expand All @@ -553,9 +574,9 @@ func coerceInt64(v interface{}) (int64, error) {
}

func coerceUint64(v interface{}) (uint64, error) {
switch v.(type) {
switch v := v.(type) {
case string:
return strconv.ParseUint(v.(string), 10, 64)
return strconv.ParseUint(v, 10, 64)
case int, int16, int32, int64:
return uint64(reflect.ValueOf(v).Int()), nil
case uint, uint16, uint32, uint64:
Expand Down
22 changes: 18 additions & 4 deletions config_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package nsq

import "testing"
import (
"net"
"testing"
)

func TestConfigSet(t *testing.T) {
c := NewConfig()
Expand All @@ -11,7 +14,7 @@ func TestConfigSet(t *testing.T) {
t.Error("No error when setting `tls_v1` to an invalid value")
}
if err := c.Set("tls_v1", true); err != nil {
t.Errorf("Error setting `tls_v1` config. %v", err)
t.Errorf("Error setting `tls_v1` config. %s", err)
}

if err := c.Set("tls-insecure-skip-verify", true); err != nil {
Expand All @@ -21,11 +24,23 @@ func TestConfigSet(t *testing.T) {
t.Errorf("Error setting `tls-insecure-skip-verify` config: %v", c.TlsConfig)
}
if err := c.Set("tls-min-version", "tls1.2"); err != nil {
t.Errorf("Error setting `tls-min-version` config: %v", err)
t.Errorf("Error setting `tls-min-version` config: %s", err)
}
if err := c.Set("tls-min-version", "tls1.3"); err == nil {
t.Error("No error when setting `tls-min-version` to an invalid value")
}
if err := c.Set("local_addr", &net.TCPAddr{}); err != nil {
t.Errorf("Error setting `local_addr` config: %s", err)
}
if err := c.Set("local_addr", "1.2.3.4:27015"); err != nil {
t.Errorf("Error setting `local_addr` config: %s", err)
}
if err := c.Set("dial_timeout", "5s"); err != nil {
t.Errorf("Error setting `dial_timeout` config: %s", err)
}
if c.LocalAddr.String() != "1.2.3.4:27015" {
t.Error("Failed to assign `local_addr` config")
}
}

func TestConfigValidate(t *testing.T) {
Expand All @@ -37,5 +52,4 @@ func TestConfigValidate(t *testing.T) {
if err := c.Validate(); err == nil {
t.Error("no error set for invalid value")
}

}
7 changes: 6 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ func (c *Conn) getLogger() (logger, LogLevel, string) {
// Connect dials and bootstraps the nsqd connection
// (including IDENTIFY) and returns the IdentifyResponse
func (c *Conn) Connect() (*IdentifyResponse, error) {
conn, err := net.DialTimeout("tcp", c.addr, time.Second)
dialer := &net.Dialer{
LocalAddr: c.config.LocalAddr,
Timeout: c.config.DialTimeout,
}

conn, err := dialer.Dial("tcp", c.addr)
if err != nil {
return nil, err
}
Expand Down
9 changes: 9 additions & 0 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strconv"
Expand Down Expand Up @@ -140,6 +141,9 @@ func TestConsumerTLSClientCertViaSet(t *testing.T) {

func consumerTest(t *testing.T, cb func(c *Config)) {
config := NewConfig()
laddr := "127.0.0.2"
// so that the test can simulate binding consumer to specified address
config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0")
// so that the test can simulate reaching max requeues and a call to LogFailedMessage
config.DefaultRequeueDelay = 0
// so that the test wont timeout from backing off
Expand Down Expand Up @@ -187,6 +191,11 @@ func consumerTest(t *testing.T, cb func(c *Config)) {
t.Fatal("should not be able to connect to the same NSQ twice")
}

conn := q.conns()[0]
if !strings.HasPrefix(conn.conn.LocalAddr().String(), laddr) {
t.Fatal("connection should be bound to the specified address:", conn.conn.LocalAddr())
}

err = q.DisconnectFromNSQD("1.2.3.4:4150")
if err == nil {
t.Fatal("should not be able to disconnect from an unknown nsqd")
Expand Down
11 changes: 11 additions & 0 deletions producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"errors"
"io/ioutil"
"log"
"net"
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -40,6 +42,10 @@ func (h *ConsumerHandler) HandleMessage(message *Message) error {

func TestProducerConnection(t *testing.T) {
config := NewConfig()
laddr := "127.0.0.2"

config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0")

w, _ := NewProducer("127.0.0.1:4150", config)
w.SetLogger(nullLogger, LogLevelInfo)

Expand All @@ -48,6 +54,11 @@ func TestProducerConnection(t *testing.T) {
t.Fatalf("should lazily connect")
}

conn := w.conn.(*Conn)
if !strings.HasPrefix(conn.conn.LocalAddr().String(), laddr) {
t.Fatal("producer connection should be bound to specified address:", conn.conn.LocalAddr())
}

w.Stop()

err = w.Publish("write_test", []byte("fail test"))
Expand Down