Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions scylla-cql/src/deserialize/row_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ fn test_struct_deserialization_loose_ordering() {
d: i32,
#[scylla(default_when_null)]
e: &'a str,
#[scylla(allow_missing)]
f: &'a str,
}

// Original order of columns
Expand All @@ -124,6 +126,7 @@ fn test_struct_deserialization_loose_ordering() {
c: String::new(),
d: 0,
e: "def",
f: "",
}
);

Expand All @@ -144,6 +147,7 @@ fn test_struct_deserialization_loose_ordering() {
c: String::new(),
d: 0,
e: "",
f: "",
}
);

Expand All @@ -169,6 +173,90 @@ fn test_struct_deserialization_loose_ordering() {
MyRow::type_check(specs).unwrap_err();
}

#[test]
fn test_struct_deserialization_loose_ordering_allow_missing() {
#[derive(DeserializeRow, PartialEq, Eq, Debug)]
#[scylla(crate = "crate")]
#[scylla(allow_missing)]
struct MyRow<'a> {
a: &'a str,
b: Option<i32>,
#[scylla(skip)]
c: String,
#[scylla(default_when_null)]
d: i32,
#[scylla(default_when_null)]
e: &'a str,
f: &'a str,
g: &'a str,
}

// Original order of columns
let specs = &[
spec("a", ColumnType::Native(NativeType::Text)),
spec("b", ColumnType::Native(NativeType::Int)),
spec("d", ColumnType::Native(NativeType::Int)),
spec("e", ColumnType::Native(NativeType::Text)),
];
let byts = serialize_cells([val_str("abc"), val_int(123), None, val_str("def")]);
let row = deserialize::<MyRow<'_>>(specs, &byts).unwrap();
assert_eq!(
row,
MyRow {
a: "abc",
b: Some(123),
c: String::new(),
d: 0,
e: "def",
f: "",
g: "",
}
);

// Different order of columns - should still work
let specs = &[
spec("e", ColumnType::Native(NativeType::Text)),
spec("b", ColumnType::Native(NativeType::Int)),
spec("d", ColumnType::Native(NativeType::Int)),
spec("a", ColumnType::Native(NativeType::Text)),
];
let byts = serialize_cells([None, val_int(123), None, val_str("abc")]);
let row = deserialize::<MyRow<'_>>(specs, &byts).unwrap();
assert_eq!(
row,
MyRow {
a: "abc",
b: Some(123),
c: String::new(),
d: 0,
e: "",
f: "",
g: "",
}
);

// Missing column
let specs = &[
spec("a", ColumnType::Native(NativeType::Text)),
spec("e", ColumnType::Native(NativeType::Text)),
];
MyRow::type_check(specs).unwrap();

// Missing both default_when_null column
let specs = &[
spec("a", ColumnType::Native(NativeType::Text)),
spec("b", ColumnType::Native(NativeType::Int)),
];
MyRow::type_check(specs).unwrap();

// Wrong column type
let specs = &[
spec("a", ColumnType::Native(NativeType::Int)),
spec("b", ColumnType::Native(NativeType::Int)),
];
MyRow::type_check(specs).unwrap_err();
}

#[test]
fn test_struct_deserialization_strict_ordering() {
#[derive(DeserializeRow, PartialEq, Eq, Debug)]
Expand Down
49 changes: 34 additions & 15 deletions scylla-macros/src/deserialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ struct StructAttrs {
// This annotation only works if `enforce_order` is specified.
#[darling(default)]
skip_name_checks: bool,

// If true, then - if this field is missing from the UDT fields metadata
// - it will be initialized to Default::default().
// currently only supported with Flavor::MatchByName
#[darling(default)]
#[darling(rename = "allow_missing")]
default_when_missing: bool,
}

impl DeserializeCommonStructAttrs for StructAttrs {
Expand Down Expand Up @@ -51,6 +58,13 @@ struct Field {
#[darling(default)]
default_when_null: bool,

// If true, then - if this field is missing from the UDT fields metadata
// - it will be initialized to Default::default().
// currently only supported with Flavor::MatchByName
#[darling(default)]
#[darling(rename = "allow_missing")]
default_when_missing: bool,

ident: Option<syn::Ident>,
ty: syn::Type,
}
Expand Down Expand Up @@ -135,7 +149,7 @@ fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling::
impl Field {
// Returns whether this field is mandatory for deserialization.
fn is_required(&self) -> bool {
!self.skip
!self.skip && !self.default_when_missing
}

// The name of the column corresponding to this Rust struct field
Expand Down Expand Up @@ -209,13 +223,7 @@ impl TypeCheckAssumeOrderGenerator<'_> {
let macro_internal = self.0.struct_attrs().macro_internal_path();
let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes();

let required_fields_iter = || {
self.0
.fields()
.iter()
.enumerate()
.filter(|(_, f)| f.is_required())
};
let required_fields_iter = || self.0.fields().iter().enumerate().filter(|(_, f)| !f.skip);
let required_fields_count = required_fields_iter().count();
let required_fields_idents: Vec<_> = (0..required_fields_count)
.map(|i| quote::format_ident!("f_{}", i))
Expand Down Expand Up @@ -394,7 +402,7 @@ impl TypeCheckUnorderedGenerator<'_> {
let visited_flag = Self::visited_flag_variable(field);
let typ = field.deserialize_target();
let cql_name_literal = field.cql_name_literal();
let decrement_if_required: Option::<syn::Stmt> = field.is_required().then(|| parse_quote! {
let decrement_if_required: Option::<syn::Stmt> = (!self.0.attrs.default_when_missing && field.is_required()).then(|| parse_quote! {
remaining_required_fields -= 1;
});

Expand Down Expand Up @@ -467,7 +475,11 @@ impl TypeCheckUnorderedGenerator<'_> {
.iter()
.filter(|f| !f.skip)
.map(|f| f.cql_name_literal());
let field_count_lit = fields.iter().filter(|f| f.is_required()).count();
let field_count_lit = if self.0.attrs.default_when_missing {
0
} else {
fields.iter().filter(|f| f.is_required()).count()
};

parse_quote! {
fn type_check(
Expand Down Expand Up @@ -541,11 +553,18 @@ impl DeserializeUnorderedGenerator<'_> {

let deserialize_field = Self::deserialize_field_variable(field);
let cql_name_literal = field.cql_name_literal();
parse_quote! {
#deserialize_field.unwrap_or_else(|| ::std::panic!(
"column {} missing in DB row - type check should have prevented this!",
#cql_name_literal
))
if self.0.attrs.default_when_missing || field.default_when_missing {
// Generate Default::default if the field was missing
parse_quote! {
#deserialize_field.unwrap_or_default()
}
} else {
parse_quote! {
#deserialize_field.unwrap_or_else(|| ::std::panic!(
"column {} missing in DB row - type check should have prevented this!",
#cql_name_literal
))
}
}
}

Expand Down
21 changes: 16 additions & 5 deletions scylla-macros/src/deserialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ struct StructAttrs {
// they will be ignored. With true, an error will be raised.
#[darling(default)]
forbid_excess_udt_fields: bool,

// If true, then - if this field is missing from the UDT fields metadata
// - it will be initialized to Default::default().
#[darling(default)]
#[darling(rename = "allow_missing")]
default_when_missing: bool,
}

impl DeserializeCommonStructAttrs for StructAttrs {
Expand Down Expand Up @@ -229,7 +235,7 @@ impl TypeCheckAssumeOrderGenerator<'_> {
let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes();
let rust_field_name = field.cql_name_literal();
let rust_field_typ = field.deserialize_target();
let default_when_missing = field.default_when_missing;
let default_when_missing = self.0.attrs.default_when_missing || field.default_when_missing;
let skip_name_checks = self.0.attrs.skip_name_checks;

// Action performed in case of field name mismatch.
Expand Down Expand Up @@ -593,8 +599,8 @@ impl TypeCheckUnorderedGenerator<'_> {
let visited_flag = Self::visited_flag_variable(field);
let typ = field.deserialize_target();
let cql_name_literal = field.cql_name_literal();
let decrement_if_required: Option<syn::Stmt> = field
.is_required()
let decrement_if_required: Option<syn::Stmt> = (!self.0.attrs.default_when_missing && field
.is_required())
.then(|| parse_quote! {remaining_required_cql_fields -= 1;});

parse_quote! {
Expand Down Expand Up @@ -659,7 +665,12 @@ impl TypeCheckUnorderedGenerator<'_> {
.iter()
.filter(|f| !f.skip)
.map(|f| f.cql_name_literal());
let required_cql_field_count = rust_fields.iter().filter(|f| f.is_required()).count();
let required_cql_field_count = if self.0.attrs.default_when_missing {
0
} else {
rust_fields.iter().filter(|f| f.is_required()).count()
};

let required_cql_field_count_lit =
syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site());
let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ));
Expand Down Expand Up @@ -746,7 +757,7 @@ impl DeserializeUnorderedGenerator<'_> {
}

let deserialize_field = Self::deserialize_field_variable(field);
if field.default_when_missing {
if self.0.attrs.default_when_missing || field.default_when_missing {
// Generate Default::default if the field was missing
parse_quote! {
#deserialize_field.unwrap_or_default()
Expand Down
15 changes: 15 additions & 0 deletions scylla-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,11 @@ mod deserialize;
/// column into the first field, second column into the second field and so on.
/// It will still still verify that the column types and field types match.
///
/// #[scylla(allow_missing)]
///
/// if set, implementation will not fail if some columns are missing.
/// Instead, it will initialize the field with `Default::default()`.
///
/// ## Field attributes
///
/// `#[scylla(skip)]`
Expand All @@ -395,6 +400,11 @@ mod deserialize;
/// By default, the generated implementation will try to match the Rust field
/// to a column with the same name. This attribute allows to match to a column
/// with provided name.
///
/// #[scylla(allow_missing)]
///
/// if set, implementation will not fail if some columns are missing.
/// Instead, it will initialize the field with `Default::default()`.
#[proc_macro_derive(DeserializeRow, attributes(scylla))]
pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream {
match deserialize::row::deserialize_row_derive(tokens_input) {
Expand Down Expand Up @@ -501,6 +511,11 @@ pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream {
/// If more strictness is desired, this flag makes sure that no excess fields
/// are present and forces error in case there are some.
///
/// `#[scylla(allow_missing)]`
///
/// If the value of the field received from DB is null, the field will be
/// initialized with `Default::default()`.
///
/// ## Field attributes
///
/// `#[scylla(skip)]`
Expand Down
10 changes: 10 additions & 0 deletions scylla-macros/src/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ struct Attributes {
// This annotation only works if `enforce_order` flavor is specified.
#[darling(default)]
skip_name_checks: bool,

// Used for deserialization only. Ignored in serialization.
#[darling(default)]
#[darling(rename = "allow_missing")]
_default_when_missing: bool,
}

impl Attributes {
Expand Down Expand Up @@ -70,6 +75,11 @@ struct FieldAttributes {
#[darling(default)]
#[darling(rename = "default_when_null")]
_default_when_null: bool,

// Used for deserialization only. Ignored in serialization.
#[darling(default)]
#[darling(rename = "allow_missing")]
_default_when_missing: bool,
}

struct Context {
Expand Down
5 changes: 5 additions & 0 deletions scylla-macros/src/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ struct Attributes {
// the DB will interpret them as NULLs anyway.
#[darling(default)]
forbid_excess_udt_fields: bool,

// Used for deserialization only. Ignored in serialization.
#[darling(default)]
#[darling(rename = "allow_missing")]
_default_when_missing: bool,
}

impl Attributes {
Expand Down
36 changes: 36 additions & 0 deletions scylla/tests/integration/macros/hygiene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,24 @@ macro_rules! test_crate {
g: ::core::primitive::i32,
}

// Test attributes for value struct with ordered flavor
#[derive(
_scylla::DeserializeValue, _scylla::SerializeValue, PartialEq, Debug,
)]
#[scylla(crate = _scylla, flavor = "enforce_order")]
#[scylla(allow_missing)]
struct TestStructOrderedAllowedMissing {
a: ::core::primitive::i32,
b: ::core::primitive::i32,
#[scylla(default_when_null)]
c: ::core::primitive::i32,
#[scylla(skip)]
d: ::core::primitive::i32,
#[scylla(rename = "f")]
e: ::core::primitive::i32,
g: ::core::primitive::i32,
}

// Test attributes for value struct with strict ordered flavor
#[derive(
_scylla::DeserializeValue, _scylla::SerializeValue, PartialEq, Debug,
Expand Down Expand Up @@ -304,6 +322,24 @@ macro_rules! test_crate {
c: ::core::primitive::i32,
#[scylla(default_when_null)]
d: ::core::primitive::i32,
#[scylla(allow_missing)]
e: ::core::primitive::i32,
}
// Test attributes for row struct with name flavor
#[derive(
_scylla::DeserializeRow, _scylla::SerializeRow, PartialEq, Debug,
)]
#[scylla(crate = _scylla)]
#[scylla(allow_missing)]
struct TestRowByNameWithMissing {
#[scylla(skip)]
a: ::core::primitive::i32,
#[scylla(rename = "f")]
b: ::core::primitive::i32,
c: ::core::primitive::i32,
#[scylla(default_when_null)]
d: ::core::primitive::i32,
e: ::core::primitive::i32,
}

// Test attributes for row struct with ordered flavor
Expand Down
Loading