diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 8167b3b894..c2c0d9f905 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -832,9 +832,12 @@ def serialize_safe(cls, items, protocol_version): buf.write(pack(len(items))) inner_proto = max(3, protocol_version) for item in items: - itembytes = subtype.to_binary(item, inner_proto) - buf.write(pack(len(itembytes))) - buf.write(itembytes) + if item is None: + buf.write(pack(-1)) + else: + itembytes = subtype.to_binary(item, inner_proto) + buf.write(pack(len(itembytes))) + buf.write(itembytes) return buf.getvalue() @@ -902,12 +905,18 @@ def serialize_safe(cls, themap, protocol_version): raise TypeError("Got a non-map object for a map value") inner_proto = max(3, protocol_version) for key, val in items: - keybytes = key_type.to_binary(key, inner_proto) - valbytes = value_type.to_binary(val, inner_proto) - buf.write(pack(len(keybytes))) - buf.write(keybytes) - buf.write(pack(len(valbytes))) - buf.write(valbytes) + if key is not None: + keybytes = key_type.to_binary(key, inner_proto) + buf.write(pack(len(keybytes))) + buf.write(keybytes) + else: + buf.write(pack(-1)) + if val is not None: + valbytes = value_type.to_binary(val, inner_proto) + buf.write(pack(len(valbytes))) + buf.write(valbytes) + else: + buf.write(pack(-1)) return buf.getvalue() diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index bc26a3013e..4329574ba6 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -26,7 +26,7 @@ from cassandra.concurrent import execute_concurrent_with_args from cassandra.cqltypes import Int32Type, EMPTY from cassandra.query import dict_factory, ordered_dict_factory -from cassandra.util import sortedset, Duration +from cassandra.util import sortedset, Duration, OrderedMap from tests.unit.cython.utils import cythontest from tests.integration import use_singledc, execute_until_pass, notprotocolv1, \ @@ -723,6 +723,51 @@ def test_can_insert_tuples_with_nulls(self): self.assertEqual(('', None, None, b''), result[0].t) self.assertEqual(('', None, None, b''), s.execute(read)[0].t) + def test_insert_collection_with_null_fails(self): + """ + NULLs in list / sets / maps are forbidden. + This is a regression test - there was a bug that serialized None values + in collections as empty values instead of nulls. + """ + s = self.session + columns = [] + for collection_type in ['list', 'set']: + for simple_type in PRIMITIVE_DATATYPES_KEYS: + columns.append(f'{collection_type}_{simple_type} {collection_type}<{simple_type}>') + for simple_type in PRIMITIVE_DATATYPES_KEYS: + columns.append(f'map_k_{simple_type} map<{simple_type}, ascii>') + columns.append(f'map_v_{simple_type} map') + s.execute(f'CREATE TABLE collection_nulls (k int PRIMARY KEY, {", ".join(columns)})') + + def raises_simple_and_prepared(exc_type, query_str, args): + self.assertRaises(exc_type, lambda: s.execute(query_str, args)) + p = s.prepare(query_str.replace('%s', '?')) + self.assertRaises(exc_type, lambda: s.execute(p, args)) + + i = 0 + for simple_type in PRIMITIVE_DATATYPES_KEYS: + query_str = f'INSERT INTO collection_nulls (k, set_{simple_type}) VALUES (%s, %s)' + args = [i, sortedset([None, get_sample(simple_type)])] + raises_simple_and_prepared(InvalidRequest, query_str, args) + i += 1 + for simple_type in PRIMITIVE_DATATYPES_KEYS: + query_str = f'INSERT INTO collection_nulls (k, list_{simple_type}) VALUES (%s, %s)' + args = [i, [None, get_sample(simple_type)]] + raises_simple_and_prepared(InvalidRequest, query_str, args) + i += 1 + for simple_type in PRIMITIVE_DATATYPES_KEYS: + query_str = f'INSERT INTO collection_nulls (k, map_k_{simple_type}) VALUES (%s, %s)' + args = [i, OrderedMap([(get_sample(simple_type), 'abc'), (None, 'def')])] + raises_simple_and_prepared(InvalidRequest, query_str, args) + i += 1 + for simple_type in PRIMITIVE_DATATYPES_KEYS: + query_str = f'INSERT INTO collection_nulls (k, map_v_{simple_type}) VALUES (%s, %s)' + args = [i, OrderedMap([('abc', None), ('def', get_sample(simple_type))])] + raises_simple_and_prepared(InvalidRequest, query_str, args) + i += 1 + + + def test_can_insert_unicode_query_string(self): """ Test to ensure unicode strings can be used in a query