diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 31ecd15b6f..b230443d7e 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1220,30 +1220,7 @@ def __init__(self, self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) self.endpoint_factory.configure(self) - raw_contact_points = [] - for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]: - raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, port)) - - self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] - self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port) - for ep in self.endpoints_resolved} - - strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points) - self.endpoints_resolved.extend(list(chain( - *[ - [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] - for xs in strs_resolved_map.values() if xs is not None - ] - ))) - - self._endpoint_map_for_insights.update( - {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value] - for key, value in strs_resolved_map.items() if value is not None} - ) - - if contact_points and (not self.endpoints_resolved): - # only want to raise here if the user specified CPs but resolution failed - raise UnresolvableContactPoints(self._endpoint_map_for_insights) + self._resolve_hostnames() self.compression = compression @@ -1427,6 +1404,31 @@ def __init__(self, if application_version is not None: self.application_version = application_version + def _resolve_hostnames(self): + raw_contact_points = [] + for cp in [cp for cp in self.contact_points if not isinstance(cp, EndPoint)]: + raw_contact_points.append(cp if isinstance(cp, tuple) else (cp, self.port)) + + self.endpoints_resolved = [cp for cp in self.contact_points if isinstance(cp, EndPoint)] + self._endpoint_map_for_insights = {repr(ep): '{ip}:{port}'.format(ip=ep.address, port=ep.port) + for ep in self.endpoints_resolved} + strs_resolved_map = _resolve_contact_points_to_string_map(raw_contact_points) + self.endpoints_resolved.extend(list(chain( + *[ + [DefaultEndPoint(ip, port) for ip, port in xs if ip is not None] + for xs in strs_resolved_map.values() if xs is not None + ] + ))) + + self._endpoint_map_for_insights.update( + {key: ['{ip}:{port}'.format(ip=ip, port=port) for ip, port in value] + for key, value in strs_resolved_map.items() if value is not None} + ) + + if self.contact_points and (not self.endpoints_resolved): + # only want to raise here if the user specified CPs but resolution failed + raise UnresolvableContactPoints(self._endpoint_map_for_insights) + def _create_thread_pool_executor(self, **kwargs): """ Create a ThreadPoolExecutor for the cluster. In most cases, the built-in @@ -1720,6 +1722,20 @@ def protocol_downgrade(self, host_endpoint, previous_version): "http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint) self.protocol_version = new_version + def _add_resolved_hosts(self): + for endpoint in self.endpoints_resolved: + host, new = self.add_host(endpoint, signal=False) + if new: + host.set_up() + for listener in self.listeners: + listener.on_add(host) + + self.profile_manager.populate( + weakref.proxy(self), self.metadata.all_hosts()) + self.load_balancing_policy.populate( + weakref.proxy(self), self.metadata.all_hosts() + ) + def connect(self, keyspace=None, wait_for_all_pools=False): """ Creates and returns a new :class:`~.Session` object. @@ -1740,18 +1756,8 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.contact_points, self.protocol_version) self.connection_class.initialize_reactor() _register_cluster_shutdown(self) - for endpoint in self.endpoints_resolved: - host, new = self.add_host(endpoint, signal=False) - if new: - host.set_up() - for listener in self.listeners: - listener.on_add(host) - - self.profile_manager.populate( - weakref.proxy(self), self.metadata.all_hosts()) - self.load_balancing_policy.populate( - weakref.proxy(self), self.metadata.all_hosts() - ) + + self._add_resolved_hosts() try: self.control_connection.connect() @@ -3585,16 +3591,8 @@ def _set_new_connection(self, conn): if old: log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() - - def _reconnect_internal(self): - """ - Tries to connect to each host in the query plan until one succeeds - or every attempt fails. If successful, a new Connection will be - returned. Otherwise, :exc:`NoHostAvailable` will be raised - with an "errors" arg that is a dict mapping host addresses - to the exception that was raised when an attempt was made to open - a connection to that host. - """ + + def _connect_host_in_lbp(self): errors = {} lbp = ( self._cluster.load_balancing_policy @@ -3604,7 +3602,7 @@ def _reconnect_internal(self): for host in lbp.make_query_plan(): try: - return self._try_connect(host) + return (self._try_connect(host), None) except ConnectionException as exc: errors[str(host.endpoint)] = exc log.warning("[control connection] Error connecting to %s:", host, exc_info=True) @@ -3614,7 +3612,31 @@ def _reconnect_internal(self): log.warning("[control connection] Error connecting to %s:", host, exc_info=True) if self._is_shutdown: raise DriverException("[control connection] Reconnection in progress during shutdown") + + return (None, errors) + def _reconnect_internal(self): + """ + Tries to connect to each host in the query plan until one succeeds + or every attempt fails. If successful, a new Connection will be + returned. Otherwise, :exc:`NoHostAvailable` will be raised + with an "errors" arg that is a dict mapping host addresses + to the exception that was raised when an attempt was made to open + a connection to that host. + """ + (conn, _) = self._connect_host_in_lbp() + if conn is not None: + return conn + + # Try to re-resolve hostnames as a fallback when all hosts are unreachable + self._cluster._resolve_hostnames() + + self._cluster._add_resolved_hosts() + + (conn, errors) = self._connect_host_in_lbp() + if conn is not None: + return conn + raise NoHostAvailable("Unable to connect to any servers", errors) def _try_connect(self, host):