Skip to content

Commit

Permalink
Make blocking set keyspace query to fail by timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev committed Aug 9, 2024
1 parent 7e0b02d commit 1e050fe
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host):
else:
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
if keyspace is not None:
connection.set_keyspace_blocking(keyspace)
connection.set_keyspace_blocking(keyspace, self.control_connection_timeout)

# prepare 10 statements at a time
ks_statements = list(ks_statements)
Expand Down
4 changes: 2 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response):
log.error(msg, self.endpoint, auth_response)
raise ProtocolError(msg % (self.endpoint, auth_response))

def set_keyspace_blocking(self, keyspace):
def set_keyspace_blocking(self, keyspace, timeout=None):
if not keyspace or keyspace == self.keyspace:
return

query = QueryMessage(query='USE "%s"' % (keyspace,),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
result = self.wait_for_response(query, timeout=timeout)
except InvalidRequestException as ire:
# the keyspace probably doesn't exist
raise ire.to_exception()
Expand Down
10 changes: 5 additions & 5 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace

if self._keyspace:
first_connection.set_keyspace_blocking(self._keyspace)
first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout)
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
self.host.sharding_info = first_connection.features.sharding_info
self._open_connections_for_all_shards(first_connection.features.shard_id)
Expand Down Expand Up @@ -615,7 +615,7 @@ def _replace(self, connection):
connection = self._session.cluster.connection_factory(self.host.endpoint,
on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
connection.set_keyspace_blocking(self._keyspace)
connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
self._connections[connection.features.shard_id] = connection
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
Expand Down Expand Up @@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id):
self.host
)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._connections[conn.features.shard_id] = conn
if old_conn is not None:
Expand Down Expand Up @@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._trash = set()
self._next_trash_allowed_at = time.time()
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self):
try:
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock:
new_connections = self._connections[:] + [conn]
Expand Down
42 changes: 42 additions & 0 deletions tests/integration/standard/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
RetryPolicy, SimpleConvictionPolicy, HostDistance,
AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
from cassandra import ConsistencyLevel
from cassandra.protocol import ProtocolHandler, QueryMessage

from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider
Expand Down Expand Up @@ -484,6 +485,47 @@ def test_refresh_schema_table(self):
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
cluster.shutdown()

def test_use_keyspace_blocking(self):
ks = "test_refresh_schema_type"

cluster = TestCluster()

class ConnectionWrapper(cluster.connection_class):
def __init__(self, *args, **kwargs):
super(ConnectionWrapper, self).__init__(*args, **kwargs)

def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
decoder=ProtocolHandler.decode_message, result_metadata=None):
if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query:
orig_decoder = decoder

def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode,
body,
decompressor, result_metadata):
time.sleep(cluster.control_connection_timeout + 0.1)
return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags,
opcode, body, decompressor, result_metadata)

decoder = decode_patched

return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata)

cluster.connection_class = ConnectionWrapper

cluster.connect().execute("""
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
""" % ks)

try:
cluster.connect(ks)
except NoHostAvailable:
pass
except Exception as e:
self.fail(f"got unexpected exception {e}")
else:
self.fail("connection should fail, but was not")

def test_refresh_schema_type(self):
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_host_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_borrow_and_return(self):
c, request_id = pool.borrow_connection(timeout=0.01)
self.assertIs(c, conn)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)

pool.return_connection(conn)
self.assertEqual(0, conn.in_flight)
Expand Down Expand Up @@ -256,7 +256,7 @@ def get_conn():
c, request_id = pool.borrow_connection(1.0)
self.assertIs(conn, c)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)
pool.return_connection(c)

t = Thread(target=get_conn)
Expand Down

0 comments on commit 1e050fe

Please sign in to comment.