From 1e050fe4a101a47aa5c3b5cd627c3d5ca1b7f6dd Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 8 Aug 2024 20:58:56 -0400 Subject: [PATCH] Make blocking set keyspace query to fail by timeout --- cassandra/cluster.py | 2 +- cassandra/connection.py | 4 +-- cassandra/pool.py | 10 +++--- tests/integration/standard/test_cluster.py | 42 ++++++++++++++++++++++ tests/unit/test_host_connection_pool.py | 4 +-- 5 files changed, 52 insertions(+), 10 deletions(-) diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 06e6293ef8..b032149a79 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host): else: for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace): if keyspace is not None: - connection.set_keyspace_blocking(keyspace) + connection.set_keyspace_blocking(keyspace, self.control_connection_timeout) # prepare 10 statements at a time ks_statements = list(ks_statements) diff --git a/cassandra/connection.py b/cassandra/connection.py index ebdfe99993..722566c26e 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response): log.error(msg, self.endpoint, auth_response) raise ProtocolError(msg % (self.endpoint, auth_response)) - def set_keyspace_blocking(self, keyspace): + def set_keyspace_blocking(self, keyspace, timeout=None): if not keyspace or keyspace == self.keyspace: return query = QueryMessage(query='USE "%s"' % (keyspace,), consistency_level=ConsistencyLevel.ONE) try: - result = self.wait_for_response(query) + result = self.wait_for_response(query, timeout=timeout) except InvalidRequestException as ire: # the keyspace probably doesn't exist raise ire.to_exception() diff --git a/cassandra/pool.py b/cassandra/pool.py index 738fc8e6d6..593909751e 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session): self._keyspace = session.keyspace if self._keyspace: - first_connection.set_keyspace_blocking(self._keyspace) + first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout) if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) @@ -615,7 +615,7 @@ def _replace(self, connection): connection = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) if self._keyspace: - connection.set_keyspace_blocking(self._keyspace) + connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._connections[connection.features.shard_id] = connection except Exception: log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,)) @@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id): self.host ) if self._keyspace: - conn.set_keyspace_blocking(self._keyspace) + conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._connections[conn.features.shard_id] = conn if old_conn is not None: @@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session): self._keyspace = session.keyspace if self._keyspace: for conn in self._connections: - conn.set_keyspace_blocking(self._keyspace) + conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout) self._trash = set() self._next_trash_allowed_at = time.time() @@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self): try: conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released) if self._keyspace: - conn.set_keyspace_blocking(self._session.keyspace) + conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout) self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL with self._lock: new_connections = self._connections[:] + [conn] diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 43356dbd82..9eb6e3cf72 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -32,6 +32,7 @@ RetryPolicy, SimpleConvictionPolicy, HostDistance, AddressTranslator, TokenAwarePolicy, HostFilterPolicy) from cassandra import ConsistencyLevel +from cassandra.protocol import ProtocolHandler, QueryMessage from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider @@ -484,6 +485,47 @@ def test_refresh_schema_table(self): self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query()) cluster.shutdown() + def test_use_keyspace_blocking(self): + ks = "test_refresh_schema_type" + + cluster = TestCluster() + + class ConnectionWrapper(cluster.connection_class): + def __init__(self, *args, **kwargs): + super(ConnectionWrapper, self).__init__(*args, **kwargs) + + def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, result_metadata=None): + if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query: + orig_decoder = decoder + + def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode, + body, + decompressor, result_metadata): + time.sleep(cluster.control_connection_timeout + 0.1) + return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags, + opcode, body, decompressor, result_metadata) + + decoder = decode_patched + + return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata) + + cluster.connection_class = ConnectionWrapper + + cluster.connect().execute(""" + CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' } + """ % ks) + + try: + cluster.connect(ks) + except NoHostAvailable: + pass + except Exception as e: + self.fail(f"got unexpected exception {e}") + else: + self.fail("connection should fail, but was not") + def test_refresh_schema_type(self): if get_server_versions()[0] < (2, 1, 0): raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1') diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index efed55daa2..df801e474e 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -56,7 +56,7 @@ def test_borrow_and_return(self): c, request_id = pool.borrow_connection(timeout=0.01) self.assertIs(c, conn) self.assertEqual(1, conn.in_flight) - conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') + conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout) pool.return_connection(conn) self.assertEqual(0, conn.in_flight) @@ -256,7 +256,7 @@ def get_conn(): c, request_id = pool.borrow_connection(1.0) self.assertIs(conn, c) self.assertEqual(1, conn.in_flight) - conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') + conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout) pool.return_connection(c) t = Thread(target=get_conn)