Skip to content

Commit

Permalink
Add HostAddressTranslator
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshartig committed Apr 7, 2022
1 parent e9dc5b2 commit dd21b22
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 52 deletions.
32 changes: 31 additions & 1 deletion address_translators.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,52 @@ import "net"
// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
// Deprecated: Use HostAddressTranslator
type AddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int) (net.IP, int)
}

// Deprecated: Use HostAddressTranslatorFunc
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)

func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
return fn(addr, port)
}

// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
// IdentityTranslator will do nothing but return what it was provided. It is
// essentially a no-op.
// Deprecated: Use HostIdentityTranslator
func IdentityTranslator() AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return addr, port
})
}

// HostAddressTranslator provides a way to translate node addresses (and ports)
// that are discovered or received as a node event. This can be useful in an ec2
// environment, for instance, to translate public IPs to private IPs. The
// HostInfo is provided to allow for inspecting the various attributes of the
// node but depending on the context the HostInfo might not be filled in
// completely.
type HostAddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int, host *HostInfo) (net.IP, int)
}

type HostAddressTranslatorFunc func(addr net.IP, port int, host *HostInfo) (net.IP, int)

func (fn HostAddressTranslatorFunc) Translate(addr net.IP, port int, host *HostInfo) (net.IP, int) {
return fn(addr, port, host)
}

// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func HostIdentityTranslator() HostAddressTranslator {
return HostAddressTranslatorFunc(func(addr net.IP, port int, host *HostInfo) (net.IP, int) {
return addr, port
})
}
28 changes: 28 additions & 0 deletions address_translators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,31 @@ func TestIdentityAddressTranslator_HostProvided(t *testing.T) {
}
assertEqual(t, "translated port", 9042, port)
}

func TestHostIdentityAddressTranslator_NilAddrAndZeroPort(t *testing.T) {
var tr HostAddressTranslator = HostIdentityTranslator()
hostIP := net.ParseIP("")
if hostIP != nil {
t.Errorf("expected host ip to be (nil) but was (%+v) instead", hostIP)
}

addr, port := tr.Translate(hostIP, 0, &HostInfo{})
if addr != nil {
t.Errorf("expected translated host to be (nil) but was (%+v) instead", addr)
}
assertEqual(t, "translated port", 0, port)
}

func TestHostIdentityAddressTranslator_HostProvided(t *testing.T) {
var tr HostAddressTranslator = HostIdentityTranslator()
hostIP := net.ParseIP("10.1.2.3")
if hostIP == nil {
t.Error("expected host ip not to be (nil)")
}

addr, port := tr.Translate(hostIP, 9042, &HostInfo{})
if !hostIP.Equal(addr) {
t.Errorf("expected translated addr to be (%+v) but was (%+v) instead", hostIP, addr)
}
assertEqual(t, "translated port", 9042, port)
}
23 changes: 18 additions & 5 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ type ClusterConfig struct {

// AddressTranslator will translate addresses found on peer discovery and/or
// node change events.
// Deprecated: Use HostAddressTranslator
AddressTranslator AddressTranslator

// HostAddressTranslator will translate addresses found on peer discovery and/or
// node change events.
HostAddressTranslator HostAddressTranslator

// If IgnorePeerAddr is true and the address in system.peers does not match
// the supplied host by either initial hosts or discovered via events then the
// host will be replaced with the supplied address.
Expand Down Expand Up @@ -210,11 +215,19 @@ func (cfg *ClusterConfig) CreateSession() (*Session, error) {
return NewSession(*cfg)
}

// translateAddressPort is a helper method that will use the given AddressTranslator
// if defined, to translate the given address and port into a possibly new address
// and port, If no AddressTranslator or if an error occurs, the given address and
// port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
// translateAddressPort is a helper method that will use the given
// HostAddressTranslator or AddressTranslator, if defined, to translate the given
// address and port into a possibly new address and port, If no
// HostAddressTranslator orAddressTranslator or if an error occurs, the given
// address and port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int, host *HostInfo) (net.IP, int) {
if cfg.HostAddressTranslator != nil {
newAddr, newPort := cfg.HostAddressTranslator.Translate(addr, port, host)
if gocqlDebug {
cfg.logger().Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
}
return newAddr, newPort
}
if cfg.AddressTranslator == nil || len(addr) == 0 {
return addr, port
}
Expand Down
30 changes: 27 additions & 3 deletions cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,47 @@ func TestNewCluster_WithHosts(t *testing.T) {
func TestClusterConfig_translateAddressAndPort_NilTranslator(t *testing.T) {
cfg := NewCluster()
assertNil(t, "cluster config address translator", cfg.AddressTranslator)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234, &HostInfo{})
assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr))
assertEqual(t, "translated host and port", 1234, newPort)
}

func TestClusterConfig_translateAddressAndPort_EmptyAddr(t *testing.T) {
cfg := NewCluster()
cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0)
newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0, &HostInfo{})
assertTrue(t, "translated address is still empty", len(newAddr) == 0)
assertEqual(t, "translated port", 0, newPort)
}

func TestClusterConfig_translateAddressAndPort_Success(t *testing.T) {
cfg := NewCluster()
cfg.AddressTranslator = staticAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345, &HostInfo{})
assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
assertEqual(t, "translated port", 5432, newPort)
}

func TestClusterConfig_translateHostAddressAndPort_NilTranslator(t *testing.T) {
cfg := NewCluster()
assertNil(t, "cluster config address translator", cfg.HostAddressTranslator)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 1234, &HostInfo{})
assertTrue(t, "same address as provided", net.ParseIP("10.0.0.1").Equal(newAddr))
assertEqual(t, "translated host and port", 1234, newPort)
}

func TestClusterConfig_translateHostAddressAndPort_EmptyAddr(t *testing.T) {
cfg := NewCluster()
cfg.HostAddressTranslator = staticHostAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.IP([]byte{}), 0, &HostInfo{})
assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
assertEqual(t, "translated port", 5432, newPort)
}

func TestClusterConfig_translateHostAddressAndPort_Success(t *testing.T) {
cfg := NewCluster()
cfg.HostAddressTranslator = staticHostAddressTranslator(net.ParseIP("10.10.10.10"), 5432)
newAddr, newPort := cfg.translateAddressPort(net.ParseIP("10.0.0.1"), 2345, &HostInfo{})
assertTrue(t, "translated address", net.ParseIP("10.10.10.10").Equal(newAddr))
assertEqual(t, "translated port", 5432, newPort)
}
6 changes: 6 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ func staticAddressTranslator(newAddr net.IP, newPort int) AddressTranslator {
})
}

func staticHostAddressTranslator(newAddr net.IP, newPort int) HostAddressTranslator {
return HostAddressTranslatorFunc(func(addr net.IP, port int, host *HostInfo) (net.IP, int) {
return newAddr, newPort
})
}

func assertTrue(t *testing.T, description string, value bool) {
t.Helper()
if !value {
Expand Down
91 changes: 56 additions & 35 deletions events.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,26 +190,16 @@ func (s *Session) handleNodeEvent(frames []frame) {
}
}

func (s *Session) addNewNode(ip net.IP, port int) {
// Get host info and apply any filters to the host
hostInfo, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
s.logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
} else if hostInfo == nil {
// ignore if it's null because we couldn't find it
return
}

if t := hostInfo.Version().nodeUpDelay(); t > 0 {
func (s *Session) addNewNode(host *HostInfo) {
if t := host.Version().nodeUpDelay(); t > 0 {
time.Sleep(t)
}

// should this handle token moving?
hostInfo = s.ring.addOrUpdate(hostInfo)
host = s.ring.addOrUpdate(host)

if !s.cfg.filterHost(hostInfo) {
s.startPoolFill(hostInfo)
if !s.cfg.filterHost(host) {
s.startPoolFill(host)
}

if s.control != nil && !s.cfg.IgnorePeerAddr {
Expand All @@ -223,30 +213,48 @@ func (s *Session) handleNewNode(ip net.IP, port int) {
s.logger.Printf("gocql: Session.handleNewNode: %s:%d\n", ip.String(), port)
}

ip, port = s.cfg.translateAddressPort(ip, port)

// if we already have the host and it's already up, then do nothing
host := s.ring.getHost(ip)
if host != nil && host.IsUp() {
return
}

s.addNewNode(ip, port)
host, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
s.logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
}
if host == nil {
s.logger.Printf("gocql: events: unable to find host info for (%s:%d): %v\n", ip, port, err)
return
}
// check again for the host because we might've just translated the IP
if host := s.ring.getHost(host.ConnectAddress()); host != nil && host.IsUp() {
return
}

s.addNewNode(host)
}

func (s *Session) handleRemovedNode(ip net.IP, port int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleRemovedNode: %s:%d\n", ip.String(), port)
}

ip, port = s.cfg.translateAddressPort(ip, port)

// we remove all nodes but only add ones which pass the filter
// search to see if the host is known
host := s.ring.getHost(ip)
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
host = &HostInfo{connectAddress: ip, port: port, peer: ip}
ip, port = s.cfg.translateAddressPort(ip, port, host)
host.connectAddress = ip
host.port = port
if host.invalidConnectAddr() {
s.logger.Printf("gocql: host ConnectAddress invalid in handleRemovedNode ip=%v: %v", ip, host)
return
}
}
s.ring.removeHost(ip)
ip = host.ConnectAddress()
s.ring.removeIP(ip)

host.setState(NodeDown)
if !s.cfg.filterHost(host) {
Expand All @@ -259,27 +267,32 @@ func (s *Session) handleRemovedNode(ip net.IP, port int) {
}
}

func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
func (s *Session) handleNodeUp(ip net.IP, port int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", ip.String(), port)
}

ip, port := s.cfg.translateAddressPort(eventIp, eventPort)

// if we already have the host and it's already up, then do nothing
host := s.ring.getHost(ip)
if host == nil {
s.addNewNode(ip, port)
if host != nil && host.IsUp() {
return
}

if s.cfg.filterHost(host) {
host, err := s.hostSource.getHostInfo(ip, port)
if err != nil {
s.logger.Printf("gocql: events: unable to fetch host info for (%s:%d): %v\n", ip, port, err)
return
}

if d := host.Version().nodeUpDelay(); d > 0 {
time.Sleep(d)
if host == nil {
s.logger.Printf("gocql: events: unable to find host info for (%s:%d): %v\n", ip, port, err)
return
}
s.startPoolFill(host)
// check again for the host because we might've just translated the IP
if host := s.ring.getHost(host.ConnectAddress()); host != nil && host.IsUp() {
return
}

s.addNewNode(host)
}

func (s *Session) startPoolFill(host *HostInfo) {
Expand Down Expand Up @@ -307,14 +320,22 @@ func (s *Session) handleNodeDown(ip net.IP, port int) {

host := s.ring.getHost(ip)
if host == nil {
host = &HostInfo{connectAddress: ip, port: port}
host = &HostInfo{connectAddress: ip, port: port, peer: ip}
ip, port = s.cfg.translateAddressPort(ip, port, host)
host.connectAddress = ip
host.port = port
if host.invalidConnectAddr() {
s.logger.Printf("gocql: host ConnectAddress invalid in handleNodeDown ip=%v: %v", ip, host)
return
}
}

host.setState(NodeDown)
if s.cfg.filterHost(host) {
return
}

ip = host.ConnectAddress()
s.policy.HostDown(host)
s.pool.hostDown(ip)
}
Loading

0 comments on commit dd21b22

Please sign in to comment.