Skip to content

Commit 6b46906

Browse files
absurdfarce authored and dkropachev committed
PYTHON-1352 Add vector type, codec + support for parsing CQL type (datastax#1161)
1 parent 3a9ac29 commit 6b46906

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

cassandra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def emit(self, record):
2323

2424
logging.getLogger('cassandra').addHandler(NullHandler())
2525

26-
__version_info__ = (3, 27, 0)
26+
__version_info__ = (3, 28, 0b1)
2727
__version__ = '.'.join(map(str, __version_info__))
2828

2929

cassandra/cqltypes.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,15 @@ def parse_casstype_args(typestring):
226226
else:
227227
names.append(None)
228228

229-
ctype = lookup_casstype_simple(tok)
229+
try:
230+
ctype = int(tok)
231+
except ValueError:
232+
ctype = lookup_casstype_simple(tok)
230233
types.append(ctype)
231234

232235
# return the first (outer) type, which will have all parameters applied
233236
return args[0][0][0]
234237

235-
236238
def lookup_casstype(casstype):
237239
"""
238240
Given a Cassandra type as a string (possibly including parameters), hand
@@ -286,7 +288,7 @@ class _CassandraType(object, metaclass=CassandraTypeType):
286288
"""
287289

288290
def __repr__(self):
289-
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
291+
return '<%s>' % (self.cql_parameterized_type())
290292

291293
@classmethod
292294
def from_binary(cls, byts, protocol_version):
@@ -1402,3 +1404,31 @@ def serialize(cls, v, protocol_version):
14021404
buf.write(int8_pack(cls._encode_precision(bound.precision)))
14031405

14041406
return buf.getvalue()
1407+
1408+
class VectorType(_CassandraType):
    """
    Fixed-length vector whose elements all share one subtype.

    The unparameterized class is only a template; ``apply_parameters``
    builds the concrete subclass carrying ``subtype`` and ``vector_size``.
    Elements are written back to back with no per-element length prefix, so
    only subtypes with a fixed serialized width can be round-tripped.
    """
    typename = 'org.apache.cassandra.db.marshal.VectorType'
    # Number of elements in the vector; set on the parameterized subclass.
    vector_size = 0
    # Element type; set on the parameterized subclass.
    subtype = None

    @classmethod
    def apply_parameters(cls, params, names):
        """
        Build the parameterized vector class.

        `params` must hold exactly two entries: the element type (anything
        accepted by ``lookup_casstype``) and the vector length. `names` is
        accepted for interface compatibility and is not used.

        Raises ValueError when `params` does not have exactly two entries.
        """
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if len(params) != 2:
            raise ValueError(
                "VectorType requires exactly two parameters (subtype, size), got %d" % len(params))
        subtype = lookup_casstype(params[0])
        vsize = int(params[1])
        return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype})

    @classmethod
    def deserialize(cls, byts, protocol_version):
        """
        Split `byts` into ``vector_size`` equal slices and decode each one
        with the subtype. Returns a list of decoded elements.

        Raises ValueError when the buffer cannot be divided evenly.
        """
        count = cls.vector_size
        if not count:
            return []
        # Derive the element width from the buffer rather than assuming
        # 4 bytes, so 8-byte subtypes decode correctly as well.
        width, leftover = divmod(len(byts), count)
        if leftover or not width:
            raise ValueError(
                "Cannot split %d bytes into %d equally sized vector elements" % (len(byts), count))
        return [cls.subtype.deserialize(byts[start:start + width], protocol_version)
                for start in range(0, width * count, width)]

    @classmethod
    def serialize(cls, v, protocol_version):
        """
        Concatenate the subtype serialization of every item in `v`.

        Raises ValueError when `v` does not contain exactly ``vector_size``
        items, since the encoded form carries no element count of its own.
        """
        items = list(v)
        if len(items) != cls.vector_size:
            raise ValueError(
                "Expected a vector of %s elements, got %d" % (cls.vector_size, len(items)))
        return b''.join(cls.subtype.serialize(item, protocol_version) for item in items)

    @classmethod
    def cql_parameterized_type(cls):
        """Return ``"<typename><<subtype typename>, <vector_size>>"`` for this class."""
        return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)

tests/unit/test_types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
EmptyValue, LongType, SetType, UTF8Type,
2626
cql_typename, int8_pack, int64_pack, lookup_casstype,
2727
lookup_casstype_simple, parse_casstype_args,
28-
int32_pack, Int32Type, ListType, MapType
28+
int32_pack, Int32Type, ListType, MapType, VectorType,
29+
FloatType
2930
)
3031
from cassandra.encoder import cql_quote
3132
from cassandra.pool import Host
@@ -188,6 +189,12 @@ class BarType(FooType):
188189
self.assertEqual(UTF8Type, ctype.subtypes[2])
189190
self.assertEqual([b'city', None, b'zip'], ctype.names)
190191

192+
def test_parse_casstype_vector(self):
    """Parsing a vector type string yields a VectorType subclass with its size and subtype set."""
    type_string = (
        "org.apache.cassandra.db.marshal.VectorType("
        "org.apache.cassandra.db.marshal.FloatType, 3)"
    )
    vector_cls = parse_casstype_args(type_string)
    self.assertTrue(issubclass(vector_cls, VectorType))
    self.assertEqual(3, vector_cls.vector_size)
    self.assertEqual(FloatType, vector_cls.subtype)
197+
191198
def test_empty_value(self):
192199
self.assertEqual(str(EmptyValue()), 'EMPTY')
193200

@@ -301,6 +308,19 @@ def test_cql_quote(self):
301308
self.assertEqual(cql_quote('test'), "'test'")
302309
self.assertEqual(cql_quote(0), '0')
303310

311+
def test_vector_round_trip(self):
    """Four floats serialize to 16 bytes and deserialize back to (almost) the same values."""
    type_string = (
        "org.apache.cassandra.db.marshal.VectorType("
        "org.apache.cassandra.db.marshal.FloatType, 4)"
    )
    vector_cls = parse_casstype_args(type_string)
    base = [3.4, 2.9, 41.6, 12.0]

    encoded = vector_cls.serialize(base, 0)
    self.assertEqual(16, len(encoded))

    decoded = vector_cls.deserialize(encoded, 0)
    self.assertEqual(len(base), len(decoded))
    # Lengths were just asserted equal, so pairing with zip visits every index.
    for expected, actual in zip(base, decoded):
        self.assertAlmostEqual(expected, actual, places=5)
320+
321+
def test_vector_cql_parameterized_type(self):
    """The parameterized name reports the subtype's typename and the vector size."""
    type_string = (
        "org.apache.cassandra.db.marshal.VectorType("
        "org.apache.cassandra.db.marshal.FloatType, 4)"
    )
    vector_cls = parse_casstype_args(type_string)
    expected_name = "org.apache.cassandra.db.marshal.VectorType<float, 4>"
    self.assertEqual(vector_cls.cql_parameterized_type(), expected_name)
304324

305325
ZERO = datetime.timedelta(0)
306326

0 commit comments

Comments
 (0)