4848 int32_pack , int32_unpack , int64_pack , int64_unpack ,
4949 float_pack , float_unpack , double_pack , double_unpack ,
5050 varint_pack , varint_unpack , point_be , point_le ,
51- vints_pack , vints_unpack )
52- from cassandra import util , VectorDeserializationFailure
51+ vints_pack , vints_unpack , uvint_unpack , uvint_pack )
52+ from cassandra import util
5353
5454_little_endian_flag = 1 # we always serialize LE
5555import ipaddress
@@ -392,6 +392,9 @@ def cass_parameterized_type(cls, full=False):
392392 """
393393 return cls .cass_parameterized_type_with (cls .subtypes , full = full )
394394
395+ @classmethod
396+ def serial_size (cls ):
397+ return None
395398
396399# it's initially named with a _ to avoid registering it as a real type, but
397400# client programs may want to use the name still for isinstance(), etc
@@ -457,10 +460,12 @@ def serialize(uuid, protocol_version):
457460 except AttributeError :
458461 raise TypeError ("Got a non-UUID object for a UUID value" )
459462
463+ @classmethod
464+ def serial_size (cls ):
465+ return 16
460466
461467class BooleanType (_CassandraType ):
462468 typename = 'boolean'
463- serial_size = 1
464469
465470 @staticmethod
466471 def deserialize (byts , protocol_version ):
@@ -470,6 +475,10 @@ def deserialize(byts, protocol_version):
470475 def serialize (truth , protocol_version ):
471476 return int8_pack (truth )
472477
478+ @classmethod
479+ def serial_size (cls ):
480+ return 1
481+
473482class ByteType (_CassandraType ):
474483 typename = 'tinyint'
475484
@@ -500,7 +509,6 @@ def serialize(var, protocol_version):
500509
501510class FloatType (_CassandraType ):
502511 typename = 'float'
503- serial_size = 4
504512
505513 @staticmethod
506514 def deserialize (byts , protocol_version ):
@@ -510,10 +518,12 @@ def deserialize(byts, protocol_version):
510518 def serialize (byts , protocol_version ):
511519 return float_pack (byts )
512520
521+ @classmethod
522+ def serial_size (cls ):
523+ return 4
513524
514525class DoubleType (_CassandraType ):
515526 typename = 'double'
516- serial_size = 8
517527
518528 @staticmethod
519529 def deserialize (byts , protocol_version ):
@@ -523,10 +533,12 @@ def deserialize(byts, protocol_version):
523533 def serialize (byts , protocol_version ):
524534 return double_pack (byts )
525535
536+ @classmethod
537+ def serial_size (cls ):
538+ return 8
526539
527540class LongType (_CassandraType ):
528541 typename = 'bigint'
529- serial_size = 8
530542
531543 @staticmethod
532544 def deserialize (byts , protocol_version ):
@@ -536,10 +548,12 @@ def deserialize(byts, protocol_version):
536548 def serialize (byts , protocol_version ):
537549 return int64_pack (byts )
538550
551+ @classmethod
552+ def serial_size (cls ):
553+ return 8
539554
540555class Int32Type (_CassandraType ):
541556 typename = 'int'
542- serial_size = 4
543557
544558 @staticmethod
545559 def deserialize (byts , protocol_version ):
@@ -549,6 +563,9 @@ def deserialize(byts, protocol_version):
549563 def serialize (byts , protocol_version ):
550564 return int32_pack (byts )
551565
566+ @classmethod
567+ def serial_size (cls ):
568+ return 4
552569
553570class IntegerType (_CassandraType ):
554571 typename = 'varint'
@@ -645,14 +662,16 @@ def serialize(v, protocol_version):
645662
646663 return int64_pack (int (timestamp ))
647664
665+ @classmethod
666+ def serial_size (cls ):
667+ return 8
648668
649669class TimestampType (DateType ):
650670 pass
651671
652672
653673class TimeUUIDType (DateType ):
654674 typename = 'timeuuid'
655- serial_size = 16
656675
657676 def my_timestamp (self ):
658677 return util .unix_time_from_uuid1 (self .val )
@@ -668,6 +687,9 @@ def serialize(timeuuid, protocol_version):
668687 except AttributeError :
669688 raise TypeError ("Got a non-UUID object for a UUID value" )
670689
690+ @classmethod
691+ def serial_size (cls ):
692+ return 16
671693
672694class SimpleDateType (_CassandraType ):
673695 typename = 'date'
@@ -699,7 +721,6 @@ def serialize(val, protocol_version):
699721
700722class ShortType (_CassandraType ):
701723 typename = 'smallint'
702- serial_size = 2
703724
704725 @staticmethod
705726 def deserialize (byts , protocol_version ):
@@ -709,10 +730,14 @@ def deserialize(byts, protocol_version):
709730 def serialize (byts , protocol_version ):
710731 return int16_pack (byts )
711732
712-
713733class TimeType (_CassandraType ):
714734 typename = 'time'
715- serial_size = 8
735+ # Time should be a fixed size 8 byte type but Cassandra 5.0 code marks it as
736+ # variable size... and we have to match what the server expects since the server
737+ # uses that specification to encode data of that type.
738+ #@classmethod
739+ #def serial_size(cls):
740+ # return 8
716741
717742 @staticmethod
718743 def deserialize (byts , protocol_version ):
@@ -1418,6 +1443,11 @@ class VectorType(_CassandraType):
14181443 vector_size = 0
14191444 subtype = None
14201445
1446+ @classmethod
1447+ def serial_size (cls ):
1448+ serialized_size = cls .subtype .serial_size ()
1449+ return cls .vector_size * serialized_size if serialized_size is not None else None
1450+
14211451 @classmethod
14221452 def apply_parameters (cls , params , names ):
14231453 assert len (params ) == 2
@@ -1427,19 +1457,50 @@ def apply_parameters(cls, params, names):
14271457
14281458 @classmethod
14291459 def deserialize (cls , byts , protocol_version ):
1430- serialized_size = getattr (cls .subtype , "serial_size" , None )
1431- if not serialized_size :
1432- raise VectorDeserializationFailure ("Cannot determine serialized size for vector with subtype %s" % cls .subtype .__name__ )
1433- indexes = (serialized_size * x for x in range (0 , cls .vector_size ))
1434- return [cls .subtype .deserialize (byts [idx :idx + serialized_size ], protocol_version ) for idx in indexes ]
1460+ serialized_size = cls .subtype .serial_size ()
1461+ if serialized_size is not None :
1462+ expected_byte_size = serialized_size * cls .vector_size
1463+ if len (byts ) != expected_byte_size :
1464+ raise ValueError (
1465+ "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead" \
1466+ .format (cls .subtype .typename , cls .vector_size , expected_byte_size , len (byts )))
1467+ indexes = (serialized_size * x for x in range (0 , cls .vector_size ))
1468+ return [cls .subtype .deserialize (byts [idx :idx + serialized_size ], protocol_version ) for idx in indexes ]
1469+
1470+ idx = 0
1471+ rv = []
1472+ while (len (rv ) < cls .vector_size ):
1473+ try :
1474+ size , bytes_read = uvint_unpack (byts [idx :])
1475+ idx += bytes_read
1476+ rv .append (cls .subtype .deserialize (byts [idx :idx + size ], protocol_version ))
1477+ idx += size
1478+ except :
1479+ raise ValueError ("Error reading additional data during vector deserialization after successfully adding {} elements" \
1480+ .format (len (rv )))
1481+
1482+ # If we have any additional data in the serialized vector treat that as an error as well
1483+ if idx < len (byts ):
1484+ raise ValueError ("Additional bytes remaining after vector deserialization completed" )
1485+ return rv
14351486
14361487 @classmethod
14371488 def serialize (cls , v , protocol_version ):
1489+ v_length = len (v )
1490+ if cls .vector_size != v_length :
1491+ raise ValueError (
1492+ "Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}" \
1493+ .format (cls .vector_size , cls .subtype .typename , v_length ))
1494+
1495+ serialized_size = cls .subtype .serial_size ()
14381496 buf = io .BytesIO ()
14391497 for item in v :
1440- buf .write (cls .subtype .serialize (item , protocol_version ))
1498+ item_bytes = cls .subtype .serialize (item , protocol_version )
1499+ if serialized_size is None :
1500+ buf .write (uvint_pack (len (item_bytes )))
1501+ buf .write (item_bytes )
14411502 return buf .getvalue ()
14421503
14431504 @classmethod
14441505 def cql_parameterized_type (cls ):
1445- return "%s<%s, %s>" % (cls .typename , cls .subtype .typename , cls .vector_size )
1506+ return "%s<%s, %s>" % (cls .typename , cls .subtype .cql_parameterized_type () , cls .vector_size )
0 commit comments