Skip to content

Commit 92a4324

Browse files
wprzytulapiodul
andcommitted
value: impl DeserializeValue for Map
Co-authored-by: Piotr Dulikowski <[email protected]>
1 parent e13bfa9 commit 92a4324

File tree

1 file changed

+306
-3
lines changed
  • scylla-cql/src/types/deserialize

1 file changed

+306
-3
lines changed

scylla-cql/src/types/deserialize/value.rs

Lines changed: 306 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Provides types for dealing with CQL value deserialization.
22
33
use std::{
4-
collections::{BTreeSet, HashSet},
4+
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
55
hash::{BuildHasher, Hash},
66
net::IpAddr,
77
};
@@ -823,6 +823,169 @@ where
823823
}
824824
}
825825

826+
/// An iterator over a CQL map.
827+
pub struct MapIterator<'frame, K, V> {
828+
coll_typ: &'frame ColumnType,
829+
k_typ: &'frame ColumnType,
830+
v_typ: &'frame ColumnType,
831+
raw_iter: FixedLengthBytesSequenceIterator<'frame>,
832+
phantom_data_k: std::marker::PhantomData<K>,
833+
phantom_data_v: std::marker::PhantomData<V>,
834+
}
835+
836+
impl<'frame, K, V> MapIterator<'frame, K, V> {
837+
fn new(
838+
coll_typ: &'frame ColumnType,
839+
k_typ: &'frame ColumnType,
840+
v_typ: &'frame ColumnType,
841+
count: usize,
842+
slice: FrameSlice<'frame>,
843+
) -> Self {
844+
Self {
845+
coll_typ,
846+
k_typ,
847+
v_typ,
848+
raw_iter: FixedLengthBytesSequenceIterator::new(count, slice),
849+
phantom_data_k: std::marker::PhantomData,
850+
phantom_data_v: std::marker::PhantomData,
851+
}
852+
}
853+
}
854+
855+
impl<'frame, K, V> DeserializeValue<'frame> for MapIterator<'frame, K, V>
856+
where
857+
K: DeserializeValue<'frame>,
858+
V: DeserializeValue<'frame>,
859+
{
860+
fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> {
861+
match typ {
862+
ColumnType::Map(k_t, v_t) => {
863+
<K as DeserializeValue<'frame>>::type_check(k_t).map_err(|err| {
864+
mk_typck_err::<Self>(typ, MapTypeCheckErrorKind::KeyTypeCheckFailed(err))
865+
})?;
866+
<V as DeserializeValue<'frame>>::type_check(v_t).map_err(|err| {
867+
mk_typck_err::<Self>(typ, MapTypeCheckErrorKind::ValueTypeCheckFailed(err))
868+
})?;
869+
Ok(())
870+
}
871+
_ => Err(mk_typck_err::<Self>(typ, MapTypeCheckErrorKind::NotMap)),
872+
}
873+
}
874+
875+
fn deserialize(
876+
typ: &'frame ColumnType,
877+
v: Option<FrameSlice<'frame>>,
878+
) -> Result<Self, DeserializationError> {
879+
let mut v = ensure_not_null_frame_slice::<Self>(typ, v)?;
880+
let count = types::read_int_length(v.as_slice_mut()).map_err(|err| {
881+
mk_deser_err::<Self>(
882+
typ,
883+
MapDeserializationErrorKind::LengthDeserializationFailed(
884+
DeserializationError::new(err),
885+
),
886+
)
887+
})?;
888+
let (k_typ, v_typ) = match typ {
889+
ColumnType::Map(k_t, v_t) => (k_t, v_t),
890+
_ => {
891+
unreachable!("Typecheck should have prevented this scenario!")
892+
}
893+
};
894+
Ok(Self::new(typ, k_typ, v_typ, 2 * count, v))
895+
}
896+
}
897+
898+
impl<'frame, K, V> Iterator for MapIterator<'frame, K, V>
899+
where
900+
K: DeserializeValue<'frame>,
901+
V: DeserializeValue<'frame>,
902+
{
903+
type Item = Result<(K, V), DeserializationError>;
904+
905+
fn next(&mut self) -> Option<Self::Item> {
906+
let raw_k = match self.raw_iter.next() {
907+
Some(Ok(raw_k)) => raw_k,
908+
Some(Err(err)) => {
909+
return Some(Err(mk_deser_err::<Self>(
910+
self.coll_typ,
911+
BuiltinDeserializationErrorKind::GenericParseError(err),
912+
)));
913+
}
914+
None => return None,
915+
};
916+
let raw_v = match self.raw_iter.next() {
917+
Some(Ok(raw_v)) => raw_v,
918+
Some(Err(err)) => {
919+
return Some(Err(mk_deser_err::<Self>(
920+
self.coll_typ,
921+
BuiltinDeserializationErrorKind::GenericParseError(err),
922+
)));
923+
}
924+
None => return None,
925+
};
926+
927+
let do_next = || -> Result<(K, V), DeserializationError> {
928+
let k = K::deserialize(self.k_typ, raw_k).map_err(|err| {
929+
mk_deser_err::<Self>(
930+
self.coll_typ,
931+
MapDeserializationErrorKind::KeyDeserializationFailed(err),
932+
)
933+
})?;
934+
let v = V::deserialize(self.v_typ, raw_v).map_err(|err| {
935+
mk_deser_err::<Self>(
936+
self.coll_typ,
937+
MapDeserializationErrorKind::ValueDeserializationFailed(err),
938+
)
939+
})?;
940+
Ok((k, v))
941+
};
942+
Some(do_next())
943+
}
944+
945+
fn size_hint(&self) -> (usize, Option<usize>) {
946+
self.raw_iter.size_hint()
947+
}
948+
}
949+
950+
impl<'frame, K, V> DeserializeValue<'frame> for BTreeMap<K, V>
951+
where
952+
K: DeserializeValue<'frame> + Ord,
953+
V: DeserializeValue<'frame>,
954+
{
955+
fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> {
956+
MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::<Self>)
957+
}
958+
959+
fn deserialize(
960+
typ: &'frame ColumnType,
961+
v: Option<FrameSlice<'frame>>,
962+
) -> Result<Self, DeserializationError> {
963+
MapIterator::<'frame, K, V>::deserialize(typ, v)
964+
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
965+
.map_err(deser_error_replace_rust_name::<Self>)
966+
}
967+
}
968+
969+
impl<'frame, K, V, S> DeserializeValue<'frame> for HashMap<K, V, S>
970+
where
971+
K: DeserializeValue<'frame> + Eq + Hash,
972+
V: DeserializeValue<'frame>,
973+
S: BuildHasher + Default + 'frame,
974+
{
975+
fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> {
976+
MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::<Self>)
977+
}
978+
979+
fn deserialize(
980+
typ: &'frame ColumnType,
981+
v: Option<FrameSlice<'frame>>,
982+
) -> Result<Self, DeserializationError> {
983+
MapIterator::<'frame, K, V>::deserialize(typ, v)
984+
.and_then(|it| it.collect::<Result<_, DeserializationError>>())
985+
.map_err(deser_error_replace_rust_name::<Self>)
986+
}
987+
}
988+
826989
// Utilities
827990

828991
fn ensure_not_null_frame_slice<'frame, T>(
@@ -954,6 +1117,9 @@ pub enum BuiltinTypeCheckErrorKind {
9541117

9551118
/// A type check failure specific to a CQL set or list.
9561119
SetOrListError(SetOrListTypeCheckErrorKind),
1120+
1121+
/// A type check failure specific to a CQL map.
1122+
MapError(MapTypeCheckErrorKind),
9571123
}
9581124

9591125
impl From<SetOrListTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
@@ -963,13 +1129,21 @@ impl From<SetOrListTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
9631129
}
9641130
}
9651131

1132+
impl From<MapTypeCheckErrorKind> for BuiltinTypeCheckErrorKind {
1133+
#[inline]
1134+
fn from(value: MapTypeCheckErrorKind) -> Self {
1135+
BuiltinTypeCheckErrorKind::MapError(value)
1136+
}
1137+
}
1138+
9661139
impl Display for BuiltinTypeCheckErrorKind {
9671140
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
9681141
match self {
9691142
BuiltinTypeCheckErrorKind::MismatchedType { expected } => {
9701143
write!(f, "expected one of the CQL types: {expected:?}")
9711144
}
9721145
BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f),
1146+
BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f),
9731147
}
9741148
}
9751149
}
@@ -1002,6 +1176,34 @@ impl Display for SetOrListTypeCheckErrorKind {
10021176
}
10031177
}
10041178

1179+
/// Describes why type checking of a map type failed.
1180+
#[derive(Debug, Clone)]
1181+
#[non_exhaustive]
1182+
pub enum MapTypeCheckErrorKind {
1183+
/// The CQL type is not a map.
1184+
NotMap,
1185+
/// Incompatible key types.
1186+
KeyTypeCheckFailed(TypeCheckError),
1187+
/// Incompatible value types.
1188+
ValueTypeCheckFailed(TypeCheckError),
1189+
}
1190+
1191+
impl Display for MapTypeCheckErrorKind {
1192+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1193+
match self {
1194+
MapTypeCheckErrorKind::NotMap => {
1195+
f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a map")
1196+
}
1197+
MapTypeCheckErrorKind::KeyTypeCheckFailed(err) => {
1198+
write!(f, "the map key types between the CQL type and the Rust type failed to type check against each other: {}", err)
1199+
},
1200+
MapTypeCheckErrorKind::ValueTypeCheckFailed(err) => {
1201+
write!(f, "the map value types between the CQL type and the Rust type failed to type check against each other: {}", err)
1202+
},
1203+
}
1204+
}
1205+
}
1206+
10051207
/// Deserialization of one of the built-in types failed.
10061208
#[derive(Debug, Error)]
10071209
#[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")]
@@ -1063,6 +1265,9 @@ pub enum BuiltinDeserializationErrorKind {
10631265

10641266
/// A deserialization failure specific to a CQL set or list.
10651267
SetOrListError(SetOrListDeserializationErrorKind),
1268+
1269+
/// A deserialization failure specific to a CQL map.
1270+
MapError(MapDeserializationErrorKind),
10661271
}
10671272

10681273
impl Display for BuiltinDeserializationErrorKind {
@@ -1091,6 +1296,7 @@ impl Display for BuiltinDeserializationErrorKind {
10911296
"the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16"
10921297
),
10931298
BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f),
1299+
BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f),
10941300
}
10951301
}
10961302
}
@@ -1126,12 +1332,48 @@ impl From<SetOrListDeserializationErrorKind> for BuiltinDeserializationErrorKind
11261332
}
11271333
}
11281334

1335+
/// Describes why deserialization of a map type failed.
1336+
#[derive(Debug)]
1337+
#[non_exhaustive]
1338+
pub enum MapDeserializationErrorKind {
1339+
/// Failed to deserialize map's length.
1340+
LengthDeserializationFailed(DeserializationError),
1341+
1342+
/// One of the keys in the map failed to deserialize.
1343+
KeyDeserializationFailed(DeserializationError),
1344+
1345+
/// One of the values in the map failed to deserialize.
1346+
ValueDeserializationFailed(DeserializationError),
1347+
}
1348+
1349+
impl Display for MapDeserializationErrorKind {
1350+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1351+
match self {
1352+
MapDeserializationErrorKind::LengthDeserializationFailed(err) => {
1353+
write!(f, "failed to deserialize map's length: {}", err)
1354+
}
1355+
MapDeserializationErrorKind::KeyDeserializationFailed(err) => {
1356+
write!(f, "failed to deserialize one of the keys: {}", err)
1357+
}
1358+
MapDeserializationErrorKind::ValueDeserializationFailed(err) => {
1359+
write!(f, "failed to deserialize one of the values: {}", err)
1360+
}
1361+
}
1362+
}
1363+
}
1364+
1365+
impl From<MapDeserializationErrorKind> for BuiltinDeserializationErrorKind {
1366+
fn from(err: MapDeserializationErrorKind) -> Self {
1367+
Self::MapError(err)
1368+
}
1369+
}
1370+
11291371
#[cfg(test)]
11301372
mod tests {
11311373
use bytes::{BufMut, Bytes, BytesMut};
11321374
use uuid::Uuid;
11331375

1134-
use std::collections::{BTreeSet, HashSet};
1376+
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
11351377
use std::fmt::Debug;
11361378
use std::net::{IpAddr, Ipv6Addr};
11371379

@@ -1147,7 +1389,7 @@ mod tests {
11471389

11481390
use super::{
11491391
mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, ListlikeIterator,
1150-
MaybeEmpty,
1392+
MapIterator, MaybeEmpty,
11511393
};
11521394

11531395
#[test]
@@ -1530,6 +1772,19 @@ mod tests {
15301772
compat_check::<Vec<i32>>(&set_type, set.clone());
15311773
compat_check::<BTreeSet<i32>>(&set_type, set.clone());
15321774
compat_check::<HashSet<i32>>(&set_type, set);
1775+
1776+
let mut map = BytesMut::new();
1777+
map.put_i32(3);
1778+
append_bytes(&mut map, &123i32.to_be_bytes());
1779+
append_bytes(&mut map, "quick".as_bytes());
1780+
append_bytes(&mut map, &456i32.to_be_bytes());
1781+
append_bytes(&mut map, "brown".as_bytes());
1782+
append_bytes(&mut map, &789i32.to_be_bytes());
1783+
append_bytes(&mut map, "fox".as_bytes());
1784+
let map = make_bytes(&map);
1785+
let map_type = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Text));
1786+
compat_check::<BTreeMap<i32, String>>(&map_type, map.clone());
1787+
compat_check::<HashMap<i32, String>>(&map_type, map);
15331788
}
15341789

15351790
#[test]
@@ -1598,6 +1853,54 @@ mod tests {
15981853
);
15991854
}
16001855

1856+
#[test]
1857+
fn test_map() {
1858+
let mut collection_contents = BytesMut::new();
1859+
collection_contents.put_i32(3);
1860+
append_bytes(&mut collection_contents, &1i32.to_be_bytes());
1861+
append_bytes(&mut collection_contents, "quick".as_bytes());
1862+
append_bytes(&mut collection_contents, &2i32.to_be_bytes());
1863+
append_bytes(&mut collection_contents, "brown".as_bytes());
1864+
append_bytes(&mut collection_contents, &3i32.to_be_bytes());
1865+
append_bytes(&mut collection_contents, "fox".as_bytes());
1866+
1867+
let collection = make_bytes(&collection_contents);
1868+
1869+
let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii));
1870+
1871+
// iterator
1872+
let mut iter = deserialize::<MapIterator<i32, &str>>(&typ, &collection).unwrap();
1873+
assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick")));
1874+
assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown")));
1875+
assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox")));
1876+
assert_eq!(iter.next().transpose().unwrap(), None);
1877+
1878+
let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")];
1879+
let expected_string = vec![
1880+
(1, "quick".to_string()),
1881+
(2, "brown".to_string()),
1882+
(3, "fox".to_string()),
1883+
];
1884+
1885+
// hash set
1886+
let decoded_hash_str = deserialize::<HashMap<i32, &str>>(&typ, &collection).unwrap();
1887+
let decoded_hash_string = deserialize::<HashMap<i32, String>>(&typ, &collection).unwrap();
1888+
assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect());
1889+
assert_eq!(
1890+
decoded_hash_string,
1891+
expected_string.clone().into_iter().collect(),
1892+
);
1893+
1894+
// btree set
1895+
let decoded_btree_str = deserialize::<BTreeMap<i32, &str>>(&typ, &collection).unwrap();
1896+
let decoded_btree_string = deserialize::<BTreeMap<i32, String>>(&typ, &collection).unwrap();
1897+
assert_eq!(
1898+
decoded_btree_str,
1899+
expected_str.clone().into_iter().collect(),
1900+
);
1901+
assert_eq!(decoded_btree_string, expected_string.into_iter().collect(),);
1902+
}
1903+
16011904
// Checks that both new and old serialization framework
16021905
// produces the same results in this case
16031906
fn compat_check<T>(typ: &ColumnType, raw: Bytes)

0 commit comments

Comments
 (0)