diff --git a/policies.go b/policies.go index 8be5bcec3..8515c831c 100644 --- a/policies.go +++ b/policies.go @@ -948,19 +948,37 @@ func (host selectedHostPoolHost) Mark(err error) { } type dcAwareRR struct { - local string - localHosts cowHostList - remoteHosts cowHostList - lastUsedHostIdx uint64 + local string + localHosts cowHostList + remoteHosts cowHostList + lastUsedHostIdx uint64 + disableDCFailover bool +} + +type dcFailoverDisabledPolicy interface { + setDCFailoverDisabled() +} + +type dcAwarePolicyOption func(p dcFailoverDisabledPolicy) + +func HostPolicyOptionDisableDCFailover(p dcFailoverDisabledPolicy) { + p.setDCFailoverDisabled() } // DCAwareRoundRobinPolicy is a host selection policies which will prioritize and // return hosts which are in the local datacentre before returning hosts in all // other datercentres -func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy { - return &dcAwareRR{local: localDC} +func DCAwareRoundRobinPolicy(localDC string, opts ...dcAwarePolicyOption) HostSelectionPolicy { + p := &dcAwareRR{local: localDC, disableDCFailover: false} + for _, opt := range opts { + opt(p) + } + return p } +func (d *dcAwareRR) setDCFailoverDisabled() { + d.disableDCFailover = true +} func (d *dcAwareRR) Init(*Session) {} func (d *dcAwareRR) Reset() {} func (d *dcAwareRR) KeyspaceChanged(KeyspaceUpdateEvent) {} @@ -1035,7 +1053,12 @@ func roundRobbin(shift int, hosts ...[]*HostInfo) NextHost { func (d *dcAwareRR) Pick(q ExecutableQuery) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) - return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) + if !d.disableDCFailover { + return roundRobbin(int(nextStartOffset), d.localHosts.get(), d.remoteHosts.get()) + } else { + return roundRobbin(int(nextStartOffset), d.localHosts.get()) + } + } // RackAwareRoundRobinPolicy is a host selection policies which will prioritize and @@ -1047,15 +1070,19 @@ type rackAwareRR struct { // It is accessed atomically and needs to be aligned to 64 bits, so we // keep it first in the struct. Do not move it or add new struct members // before it. - lastUsedHostIdx uint64 - localDC string - localRack string - hosts []cowHostList + lastUsedHostIdx uint64 + localDC string + localRack string + hosts []cowHostList + disableDCFailover bool } -func RackAwareRoundRobinPolicy(localDC string, localRack string) HostSelectionPolicy { - hosts := make([]cowHostList, 3) - return &rackAwareRR{localDC: localDC, localRack: localRack, hosts: hosts} +func RackAwareRoundRobinPolicy(localDC string, localRack string, opts ...dcAwarePolicyOption) HostSelectionPolicy { + p := &rackAwareRR{localDC: localDC, localRack: localRack, hosts: make([]cowHostList, 3), disableDCFailover: false} + for _, opt := range opts { + opt(p) + } + return p } func (d *rackAwareRR) Init(*Session) {} @@ -1067,6 +1094,10 @@ func (d *rackAwareRR) MaxHostTier() uint { return 2 } +func (d *rackAwareRR) setDCFailoverDisabled() { + d.disableDCFailover = true +} + // Experimental, this interface and use may change func (d *rackAwareRR) SetTablets(tablets []*TabletInfo) {} @@ -1101,7 +1132,12 @@ func (d *rackAwareRR) HostDown(host *HostInfo) { d.RemoveHost(host) } func (d *rackAwareRR) Pick(q ExecutableQuery) NextHost { nextStartOffset := atomic.AddUint64(&d.lastUsedHostIdx, 1) - return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get()) + if !d.disableDCFailover { + return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get(), d.hosts[2].get()) + } else { + return roundRobbin(int(nextStartOffset), d.hosts[0].get(), d.hosts[1].get()) + } + } // ReadyPolicy defines a policy for when a HostSelectionPolicy can be used. After diff --git a/policies_test.go b/policies_test.go index 60f043b19..d8e83f341 100644 --- a/policies_test.go +++ b/policies_test.go @@ -601,6 +601,37 @@ func TestHostPolicy_DCAwareRR(t *testing.T) { } +func TestHostPolicy_DCAwareRR_wrongDc(t *testing.T) { + p := DCAwareRoundRobinPolicy("wrong_dc", HostPolicyOptionDisableDCFailover) + + hosts := [...]*HostInfo{ + {hostId: "0", connectAddress: net.ParseIP("10.0.0.1"), dataCenter: "local"}, + {hostId: "1", connectAddress: net.ParseIP("10.0.0.2"), dataCenter: "local"}, + {hostId: "2", connectAddress: net.ParseIP("10.0.0.3"), dataCenter: "remote"}, + {hostId: "3", connectAddress: net.ParseIP("10.0.0.4"), dataCenter: "remote"}, + } + + for _, host := range hosts { + p.AddHost(host) + } + + got := make(map[string]bool, len(hosts)) + + it := p.Pick(nil) + for h := it(); h != nil; h = it() { + id := h.Info().hostId + + if got[id] { + t.Fatalf("got duplicate host %s", id) + } + got[id] = true + } + + if len(got) != 0 { + t.Fatalf("expected %d hosts got %d", 0, len(got)) + } +} + // Tests of the token-aware host selection policy implementation with a // DC aware round-robin host selection policy fallback // with {"class": "NetworkTopologyStrategy", "a": 1, "b": 1, "c": 1} replication.