Skip to content

Commit b02f021

Browse files
wprzytulapiodul
andcommitted
value: impl DeserializeValue for UDTs
Co-authored-by: Piotr Dulikowski <[email protected]>
1 parent 2bf4ca0 commit b02f021

File tree

1 file changed

+201
-1
lines changed
  • scylla-cql/src/types/deserialize

1 file changed

+201
-1
lines changed

scylla-cql/src/types/deserialize/value.rs

Lines changed: 201 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,111 @@ impl_tuple_multiple!(
10951095
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15
10961096
);
10971097

1098+
// udts
1099+
1100+
/// An iterator over fields of a User Defined Type.
1101+
///
1102+
/// # Note
1103+
///
1104+
/// A serialized UDT will generally have one value for each field, but it is
1105+
/// allowed to have fewer. This iterator differentiates null values
1106+
/// from non-existent values in the following way:
1107+
///
1108+
/// - `None` - missing from the serialized form
1109+
/// - `Some(None)` - present, but null
1110+
/// - `Some(Some(...))` - non-null, present value
1111+
pub struct UdtIterator<'frame> {
1112+
all_fields: &'frame [(String, ColumnType)],
1113+
type_name: &'frame str,
1114+
keyspace: &'frame str,
1115+
remaining_fields: &'frame [(String, ColumnType)],
1116+
raw_iter: BytesSequenceIterator<'frame>,
1117+
}
1118+
1119+
impl<'frame> UdtIterator<'frame> {
1120+
fn new(
1121+
fields: &'frame [(String, ColumnType)],
1122+
type_name: &'frame str,
1123+
keyspace: &'frame str,
1124+
slice: FrameSlice<'frame>,
1125+
) -> Self {
1126+
Self {
1127+
all_fields: fields,
1128+
remaining_fields: fields,
1129+
type_name,
1130+
keyspace,
1131+
raw_iter: BytesSequenceIterator::new(slice),
1132+
}
1133+
}
1134+
1135+
#[inline]
1136+
pub fn fields(&self) -> &'frame [(String, ColumnType)] {
1137+
self.remaining_fields
1138+
}
1139+
}
1140+
1141+
impl<'frame> DeserializeValue<'frame> for UdtIterator<'frame> {
1142+
fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> {
1143+
match typ {
1144+
ColumnType::UserDefinedType { .. } => Ok(()),
1145+
_ => Err(mk_typck_err::<Self>(typ, UdtTypeCheckErrorKind::NotUdt)),
1146+
}
1147+
}
1148+
1149+
fn deserialize(
1150+
typ: &'frame ColumnType,
1151+
v: Option<FrameSlice<'frame>>,
1152+
) -> Result<Self, DeserializationError> {
1153+
let v = ensure_not_null_frame_slice::<Self>(typ, v)?;
1154+
let (fields, type_name, keyspace) = match typ {
1155+
ColumnType::UserDefinedType {
1156+
field_types,
1157+
type_name,
1158+
keyspace,
1159+
} => (field_types.as_ref(), type_name.as_ref(), keyspace.as_ref()),
1160+
_ => {
1161+
unreachable!("Typecheck should have prevented this scenario!")
1162+
}
1163+
};
1164+
Ok(Self::new(fields, type_name, keyspace, v))
1165+
}
1166+
}
1167+
1168+
impl<'frame> Iterator for UdtIterator<'frame> {
1169+
type Item = (
1170+
&'frame (String, ColumnType),
1171+
Result<Option<Option<FrameSlice<'frame>>>, DeserializationError>,
1172+
);
1173+
1174+
fn next(&mut self) -> Option<Self::Item> {
1175+
// TODO: Should we fail when there are too many fields?
1176+
let (head, fields) = self.remaining_fields.split_first()?;
1177+
self.remaining_fields = fields;
1178+
let raw_res = match self.raw_iter.next() {
1179+
// The field is there and it was parsed correctly
1180+
Some(Ok(raw)) => Ok(Some(raw)),
1181+
1182+
// There were some bytes but they didn't parse as correct field value
1183+
Some(Err(err)) => Err(mk_deser_err::<Self>(
1184+
&ColumnType::UserDefinedType {
1185+
type_name: self.type_name.to_owned(),
1186+
keyspace: self.keyspace.to_owned(),
1187+
field_types: self.all_fields.to_owned(),
1188+
},
1189+
BuiltinDeserializationErrorKind::GenericParseError(err),
1190+
)),
1191+
1192+
// The field is just missing from the serialized form
1193+
None => Ok(None),
1194+
};
1195+
Some((head, raw_res))
1196+
}
1197+
1198+
fn size_hint(&self) -> (usize, Option<usize>) {
1199+
self.raw_iter.size_hint()
1200+
}
1201+
}
1202+
10981203
// Utilities
10991204

11001205
fn ensure_not_null_frame_slice<'frame, T>(
@@ -1182,6 +1287,39 @@ impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> {
11821287
}
11831288
}
11841289

1290+
/// Iterates over a sequence of `[bytes]` items from a frame subslice.
1291+
///
1292+
/// The `[bytes]` items are parsed until the end of subslice is reached.
1293+
#[derive(Clone, Copy, Debug)]
1294+
pub struct BytesSequenceIterator<'frame> {
1295+
slice: FrameSlice<'frame>,
1296+
}
1297+
1298+
impl<'frame> BytesSequenceIterator<'frame> {
1299+
fn new(slice: FrameSlice<'frame>) -> Self {
1300+
Self { slice }
1301+
}
1302+
}
1303+
1304+
impl<'frame> From<FrameSlice<'frame>> for BytesSequenceIterator<'frame> {
1305+
#[inline]
1306+
fn from(slice: FrameSlice<'frame>) -> Self {
1307+
Self::new(slice)
1308+
}
1309+
}
1310+
1311+
impl<'frame> Iterator for BytesSequenceIterator<'frame> {
1312+
type Item = Result<Option<FrameSlice<'frame>>, ParseError>;
1313+
1314+
fn next(&mut self) -> Option<Self::Item> {
1315+
if self.slice.as_slice().is_empty() {
1316+
None
1317+
} else {
1318+
Some(self.slice.read_cql_bytes())
1319+
}
1320+
}
1321+
}
1322+
11851323
// Error facilities
11861324

11871325
/// Type checking of one of the built-in types failed.
@@ -1250,6 +1388,9 @@ pub enum BuiltinTypeCheckErrorKind {
12501388

12511389
/// A type check failure specific to a CQL tuple.
12521390
TupleError(TupleTypeCheckErrorKind),
1391+
1392+
/// A type check failure specific to a CQL UDT.
1393+
UdtError(UdtTypeCheckErrorKind),
12531394
}
12541395

12551396
impl From<SetOrListTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
@@ -1273,6 +1414,13 @@ impl From<TupleTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
12731414
}
12741415
}
12751416

1417+
impl From<UdtTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
1418+
#[inline]
1419+
fn from(value: UdtTypeCheckErrorKind) -> Self {
1420+
BuiltinTypeCheckErrorKind::UdtError(value)
1421+
}
1422+
}
1423+
12761424
impl Display for BuiltinTypeCheckErrorKind {
12771425
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
12781426
match self {
@@ -1282,6 +1430,7 @@ impl Display for BuiltinTypeCheckErrorKind {
12821430
BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f),
12831431
BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f),
12841432
BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f),
1433+
BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f),
12851434
}
12861435
}
12871436
}
@@ -1393,6 +1542,25 @@ impl Display for TupleTypeCheckErrorKind {
13931542
}
13941543
}
13951544

1545+
/// Describes why type checking of a user defined type failed.
1546+
#[derive(Debug, Clone)]
1547+
#[non_exhaustive]
1548+
pub enum UdtTypeCheckErrorKind {
1549+
/// The CQL type is not a user defined type.
1550+
NotUdt,
1551+
}
1552+
1553+
impl Display for UdtTypeCheckErrorKind {
1554+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1555+
match self {
1556+
UdtTypeCheckErrorKind::NotUdt => write!(
1557+
f,
1558+
"the CQL type the Rust type was attempted to be type checked against is not a UDT"
1559+
),
1560+
}
1561+
}
1562+
}
1563+
13961564
/// Deserialization of one of the built-in types failed.
13971565
#[derive(Debug, Error)]
13981566
#[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")]
@@ -1609,7 +1777,7 @@ mod tests {
16091777
use crate::frame::value::{
16101778
Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint,
16111779
};
1612-
use crate::types::deserialize::{DeserializationError, FrameSlice};
1780+
use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError};
16131781
use crate::types::serialize::value::SerializeValue;
16141782
use crate::types::serialize::CellWriter;
16151783

@@ -2142,6 +2310,38 @@ mod tests {
21422310
assert_eq!(tup, (42, "foo", None));
21432311
}
21442312

2313+
#[test]
2314+
fn test_custom_type_parser() {
2315+
#[derive(Default, Debug, PartialEq, Eq)]
2316+
struct SwappedPair<A, B>(B, A);
2317+
impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair<A, B>
2318+
where
2319+
A: DeserializeValue<'frame>,
2320+
B: DeserializeValue<'frame>,
2321+
{
2322+
fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> {
2323+
<(B, A) as DeserializeValue<'frame>>::type_check(typ)
2324+
}
2325+
2326+
fn deserialize(
2327+
typ: &'frame ColumnType,
2328+
v: Option<FrameSlice<'frame>>,
2329+
) -> Result<Self, DeserializationError> {
2330+
<(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a))
2331+
}
2332+
}
2333+
2334+
let mut tuple_contents = BytesMut::new();
2335+
append_bytes(&mut tuple_contents, "foo".as_bytes());
2336+
append_bytes(&mut tuple_contents, &42i32.to_be_bytes());
2337+
let tuple = make_bytes(&tuple_contents);
2338+
2339+
let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]);
2340+
2341+
let tup = deserialize::<SwappedPair<i32, &str>>(&typ, &tuple).unwrap();
2342+
assert_eq!(tup, SwappedPair("foo", 42));
2343+
}
2344+
21452345
// Checks that both new and old serialization framework
21462346
// produces the same results in this case
21472347
fn compat_check<T>(typ: &ColumnType, raw: Bytes)

0 commit comments

Comments
 (0)