Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make blocking set keyspace query to fail by timeout #362

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host):
else:
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
if keyspace is not None:
connection.set_keyspace_blocking(keyspace)
connection.set_keyspace_blocking(keyspace, self.control_connection_timeout)

# prepare 10 statements at a time
ks_statements = list(ks_statements)
Expand Down
4 changes: 2 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response):
log.error(msg, self.endpoint, auth_response)
raise ProtocolError(msg % (self.endpoint, auth_response))

def set_keyspace_blocking(self, keyspace):
def set_keyspace_blocking(self, keyspace, timeout=None):
if not keyspace or keyspace == self.keyspace:
return

query = QueryMessage(query='USE "%s"' % (keyspace,),
consistency_level=ConsistencyLevel.ONE)
try:
result = self.wait_for_response(query)
result = self.wait_for_response(query, timeout=timeout)
except InvalidRequestException as ire:
# the keyspace probably doesn't exist
raise ire.to_exception()
Expand Down
10 changes: 5 additions & 5 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace

if self._keyspace:
first_connection.set_keyspace_blocking(self._keyspace)
first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout)
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
self.host.sharding_info = first_connection.features.sharding_info
self._open_connections_for_all_shards(first_connection.features.shard_id)
Expand Down Expand Up @@ -615,7 +615,7 @@ def _replace(self, connection):
connection = self._session.cluster.connection_factory(self.host.endpoint,
on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
connection.set_keyspace_blocking(self._keyspace)
connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
self._connections[connection.features.shard_id] = connection
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
Expand Down Expand Up @@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id):
self.host
)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._connections[conn.features.shard_id] = conn
if old_conn is not None:
Expand Down Expand Up @@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session):
self._keyspace = session.keyspace
if self._keyspace:
for conn in self._connections:
conn.set_keyspace_blocking(self._keyspace)
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)

self._trash = set()
self._next_trash_allowed_at = time.time()
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self):
try:
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
with self._lock:
new_connections = self._connections[:] + [conn]
Expand Down
42 changes: 42 additions & 0 deletions tests/integration/standard/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
RetryPolicy, SimpleConvictionPolicy, HostDistance,
AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
from cassandra import ConsistencyLevel
from cassandra.protocol import ProtocolHandler, QueryMessage

from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider
Expand Down Expand Up @@ -484,6 +485,47 @@ def test_refresh_schema_table(self):
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
cluster.shutdown()

def test_use_keyspace_blocking(self):
ks = "test_refresh_schema_type"

cluster = TestCluster()

class ConnectionWrapper(cluster.connection_class):
def __init__(self, *args, **kwargs):
super(ConnectionWrapper, self).__init__(*args, **kwargs)

def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
decoder=ProtocolHandler.decode_message, result_metadata=None):
if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query:
orig_decoder = decoder

def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode,
body,
decompressor, result_metadata):
time.sleep(cluster.control_connection_timeout + 0.1)
return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags,
opcode, body, decompressor, result_metadata)

decoder = decode_patched

return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata)

cluster.connection_class = ConnectionWrapper

cluster.connect().execute("""
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
""" % ks)

try:
cluster.connect(ks)
except NoHostAvailable:
pass
except Exception as e:
self.fail(f"got unexpected exception {e}")
else:
self.fail("connection should fail, but was not")

def test_refresh_schema_type(self):
if get_server_versions()[0] < (2, 1, 0):
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_host_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_borrow_and_return(self):
c, request_id = pool.borrow_connection(timeout=0.01)
self.assertIs(c, conn)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)

pool.return_connection(conn)
self.assertEqual(0, conn.in_flight)
Expand Down Expand Up @@ -256,7 +256,7 @@ def get_conn():
c, request_id = pool.borrow_connection(1.0)
self.assertIs(conn, c)
self.assertEqual(1, conn.in_flight)
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)
pool.return_connection(c)

t = Thread(target=get_conn)
Expand Down
Loading