Skip to content

Commit 8e0f29a

Browse files
wprzytulapiodul
andcommitted
deser/row: impl DeserializeRow for tuples
Co-authored-by: Piotr Dulikowski <[email protected]>
1 parent 82b5d96 commit 8e0f29a

File tree

1 file changed

+189
-3
lines changed
  • scylla-cql/src/types/deserialize

1 file changed

+189
-3
lines changed

scylla-cql/src/types/deserialize/row.rs

Lines changed: 189 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::fmt::Display;
44

55
use thiserror::Error;
66

7-
use super::{DeserializationError, FrameSlice, TypeCheckError};
7+
use super::value::DeserializeValue;
8+
use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError};
89
use crate::frame::response::result::{ColumnSpec, ColumnType};
910

1011
/// Represents a raw, unparsed column value.
@@ -122,6 +123,90 @@ impl<'frame> DeserializeRow<'frame> for ColumnIterator<'frame> {
122123
}
123124
}
124125

126+
make_error_replace_rust_name!(
127+
_typck_error_replace_rust_name,
128+
TypeCheckError,
129+
BuiltinTypeCheckError
130+
);
131+
132+
make_error_replace_rust_name!(
133+
deser_error_replace_rust_name,
134+
DeserializationError,
135+
BuiltinDeserializationError
136+
);
137+
138+
// tuples
139+
//
140+
/// This is the new encouraged way for deserializing a row.
141+
/// If only you know the exact column types in advance, you had better deserialize the row
142+
/// to a tuple. The new deserialization framework will take care of all type checking
143+
/// and needed conversions, issuing meaningful errors in case something goes wrong.
144+
macro_rules! impl_tuple {
145+
($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => {
146+
impl<'frame, $($Ti),*> DeserializeRow<'frame> for ($($Ti,)*)
147+
where
148+
$($Ti: DeserializeValue<'frame>),*
149+
{
150+
fn type_check(specs: &[ColumnSpec]) -> Result<(), TypeCheckError> {
151+
const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len();
152+
153+
let column_types_iter = || specs.iter().map(|spec| spec.typ.clone());
154+
if let [$($idf),*] = &specs {
155+
$(
156+
<$Ti as DeserializeValue<'frame>>::type_check(&$idf.typ)
157+
.map_err(|err| mk_typck_err::<Self>(column_types_iter(), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
158+
column_index: $idx,
159+
column_name: specs[$idx].name.clone(),
160+
err
161+
}))?;
162+
)*
163+
Ok(())
164+
} else {
165+
Err(mk_typck_err::<Self>(column_types_iter(), BuiltinTypeCheckErrorKind::WrongColumnCount {
166+
rust_cols: TUPLE_LEN, cql_cols: specs.len()
167+
}))
168+
}
169+
}
170+
171+
fn deserialize(mut row: ColumnIterator<'frame>) -> Result<Self, DeserializationError> {
172+
const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len();
173+
174+
let ret = (
175+
$({
176+
let column = row.next().unwrap_or_else(|| unreachable!(
177+
"Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row {}",
178+
TUPLE_LEN,
179+
$idx
180+
)).map_err(deser_error_replace_rust_name::<Self>)?;
181+
182+
<$Ti as DeserializeValue<'frame>>::deserialize(&column.spec.typ, column.slice)
183+
.map_err(|err| mk_deser_err::<Self>(BuiltinDeserializationErrorKind::ColumnDeserializationFailed {
184+
column_index: column.index,
185+
column_name: column.spec.name.clone(),
186+
err,
187+
}))?
188+
},)*
189+
);
190+
assert!(
191+
row.next().is_none(),
192+
"Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row is bigger",
193+
TUPLE_LEN,
194+
);
195+
Ok(ret)
196+
}
197+
}
198+
}
199+
}
200+
201+
use super::value::impl_tuple_multiple;
202+
203+
// Implements row-to-tuple deserialization for all tuple sizes up to 16.
204+
impl_tuple_multiple!(
205+
T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15;
206+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15;
207+
t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15
208+
);
209+
125210
// Error facilities
126211

127212
/// Failed to type check incoming result column types again given Rust type,
@@ -161,11 +246,47 @@ fn mk_typck_err_named(
161246
/// Describes why type checking incoming result column types again given Rust type failed.
162247
#[derive(Debug, Clone)]
163248
#[non_exhaustive]
164-
pub enum BuiltinTypeCheckErrorKind {}
249+
pub enum BuiltinTypeCheckErrorKind {
250+
/// The Rust type expects `rust_cols` columns, but the statement operates on `cql_cols`.
251+
WrongColumnCount {
252+
/// The number of values that the Rust type provides.
253+
rust_cols: usize,
254+
255+
/// The number of columns that the statement operates on.
256+
cql_cols: usize,
257+
},
258+
259+
/// Column type check failed between Rust type and DB type at given position (=in given column).
260+
ColumnTypeCheckFailed {
261+
/// Index of the column.
262+
column_index: usize,
263+
264+
/// Name of the column, as provided by the DB.
265+
column_name: String,
266+
267+
/// Inner type check error due to the type mismatch.
268+
err: TypeCheckError,
269+
},
270+
}
165271

166272
impl Display for BuiltinTypeCheckErrorKind {
167273
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168-
Ok(())
274+
match self {
275+
BuiltinTypeCheckErrorKind::WrongColumnCount {
276+
rust_cols,
277+
cql_cols,
278+
} => {
279+
write!(f, "wrong column count: the statement operates on {cql_cols} columns, but the given rust types contains {rust_cols}")
280+
}
281+
BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed {
282+
column_index,
283+
column_name,
284+
err,
285+
} => write!(
286+
f,
287+
"mismatched types in column {column_name} at index {column_index}: {err}"
288+
),
289+
}
169290
}
170291
}
171292

@@ -201,6 +322,18 @@ pub(super) fn mk_deser_err_named(
201322
#[derive(Debug, Clone)]
202323
#[non_exhaustive]
203324
pub enum BuiltinDeserializationErrorKind {
325+
/// One of the columns failed to deserialize.
326+
ColumnDeserializationFailed {
327+
/// Index of the column that failed to deserialize.
328+
column_index: usize,
329+
330+
/// Name of the column that failed to deserialize.
331+
column_name: String,
332+
333+
/// The error that caused the column deserialization to fail.
334+
err: DeserializationError,
335+
},
336+
204337
/// One of the raw columns failed to deserialize, most probably
205338
/// due to the invalid column structure inside a row in the frame.
206339
RawColumnDeserializationFailed {
@@ -218,6 +351,16 @@ pub enum BuiltinDeserializationErrorKind {
218351
impl Display for BuiltinDeserializationErrorKind {
219352
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
220353
match self {
354+
BuiltinDeserializationErrorKind::ColumnDeserializationFailed {
355+
column_index,
356+
column_name,
357+
err,
358+
} => {
359+
write!(
360+
f,
361+
"failed to deserialize column {column_name} at index {column_index}: {err}"
362+
)
363+
}
221364
BuiltinDeserializationErrorKind::RawColumnDeserializationFailed {
222365
column_index,
223366
column_name,
@@ -242,6 +385,49 @@ mod tests {
242385
use super::super::tests::{serialize_cells, spec};
243386
use super::{ColumnIterator, DeserializeRow};
244387

388+
#[test]
389+
fn test_tuple_deserialization() {
390+
// Empty tuple
391+
deserialize::<()>(&[], &Bytes::new()).unwrap();
392+
393+
// 1-elem tuple
394+
let (a,) = deserialize::<(i32,)>(
395+
&[spec("i", ColumnType::Int)],
396+
&serialize_cells([val_int(123)]),
397+
)
398+
.unwrap();
399+
assert_eq!(a, 123);
400+
401+
// 3-elem tuple
402+
let (a, b, c) = deserialize::<(i32, i32, i32)>(
403+
&[
404+
spec("i1", ColumnType::Int),
405+
spec("i2", ColumnType::Int),
406+
spec("i3", ColumnType::Int),
407+
],
408+
&serialize_cells([val_int(123), val_int(456), val_int(789)]),
409+
)
410+
.unwrap();
411+
assert_eq!((a, b, c), (123, 456, 789));
412+
413+
// Make sure that column type mismatch is detected
414+
deserialize::<(i32, String, i32)>(
415+
&[
416+
spec("i1", ColumnType::Int),
417+
spec("i2", ColumnType::Int),
418+
spec("i3", ColumnType::Int),
419+
],
420+
&serialize_cells([val_int(123), val_int(456), val_int(789)]),
421+
)
422+
.unwrap_err();
423+
424+
// Make sure that borrowing types compile and work correctly
425+
let specs = &[spec("s", ColumnType::Text)];
426+
let byts = serialize_cells([val_str("abc")]);
427+
let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap();
428+
assert_eq!(s, "abc");
429+
}
430+
245431
#[test]
246432
fn test_deserialization_as_column_iterator() {
247433
let col_specs = [

0 commit comments

Comments
 (0)