diff --git a/scylla-cql/src/deserialize/row_tests.rs b/scylla-cql/src/deserialize/row_tests.rs index 10267b9a34..89e8abfce7 100644 --- a/scylla-cql/src/deserialize/row_tests.rs +++ b/scylla-cql/src/deserialize/row_tests.rs @@ -105,6 +105,8 @@ fn test_struct_deserialization_loose_ordering() { d: i32, #[scylla(default_when_null)] e: &'a str, + #[scylla(allow_missing)] + f: &'a str, } // Original order of columns @@ -124,6 +126,7 @@ fn test_struct_deserialization_loose_ordering() { c: String::new(), d: 0, e: "def", + f: "", } ); @@ -144,6 +147,7 @@ fn test_struct_deserialization_loose_ordering() { c: String::new(), d: 0, e: "", + f: "", } ); @@ -169,6 +173,90 @@ fn test_struct_deserialization_loose_ordering() { MyRow::type_check(specs).unwrap_err(); } +#[test] +fn test_struct_deserialization_loose_ordering_allow_missing() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + #[scylla(allow_missing)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + #[scylla(default_when_null)] + d: i32, + #[scylla(default_when_null)] + e: &'a str, + f: &'a str, + g: &'a str, + } + + // Original order of columns + let specs = &[ + spec("a", ColumnType::Native(NativeType::Text)), + spec("b", ColumnType::Native(NativeType::Int)), + spec("d", ColumnType::Native(NativeType::Int)), + spec("e", ColumnType::Native(NativeType::Text)), + ]; + let byts = serialize_cells([val_str("abc"), val_int(123), None, val_str("def")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + d: 0, + e: "def", + f: "", + g: "", + } + ); + + // Different order of columns - should still work + let specs = &[ + spec("e", ColumnType::Native(NativeType::Text)), + spec("b", ColumnType::Native(NativeType::Int)), + spec("d", ColumnType::Native(NativeType::Int)), + spec("a", ColumnType::Native(NativeType::Text)), + ]; + let byts = serialize_cells([None, val_int(123), None, val_str("abc")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + d: 0, + e: "", + f: "", + g: "", + } + ); + + // Missing column + let specs = &[ + spec("a", ColumnType::Native(NativeType::Text)), + spec("e", ColumnType::Native(NativeType::Text)), + ]; + MyRow::type_check(specs).unwrap(); + + // Missing both default_when_null column + let specs = &[ + spec("a", ColumnType::Native(NativeType::Text)), + spec("b", ColumnType::Native(NativeType::Int)), + ]; + MyRow::type_check(specs).unwrap(); + + // Wrong column type + let specs = &[ + spec("a", ColumnType::Native(NativeType::Int)), + spec("b", ColumnType::Native(NativeType::Int)), + ]; + MyRow::type_check(specs).unwrap_err(); +} + #[test] fn test_struct_deserialization_strict_ordering() { #[derive(DeserializeRow, PartialEq, Eq, Debug)] diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index c86d423a49..88315594e8 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -24,6 +24,13 @@ struct StructAttrs { // This annotation only works if `enforce_order` is specified. #[darling(default)] skip_name_checks: bool, + + // If true, then - if this field is missing from the UDT fields metadata + // - it will be initialized to Default::default(). + // currently only supported with Flavor::MatchByName + #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -51,6 +58,13 @@ struct Field { #[darling(default)] default_when_null: bool, + // If true, then - if this field is missing from the UDT fields metadata + // - it will be initialized to Default::default(). + // currently only supported with Flavor::MatchByName + #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, + ident: Option, ty: syn::Type, } @@ -135,7 +149,7 @@ fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling:: impl Field { // Returns whether this field is mandatory for deserialization. fn is_required(&self) -> bool { - !self.skip + !self.skip && !self.default_when_missing } // The name of the column corresponding to this Rust struct field @@ -209,13 +223,7 @@ impl TypeCheckAssumeOrderGenerator<'_> { let macro_internal = self.0.struct_attrs().macro_internal_path(); let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes(); - let required_fields_iter = || { - self.0 - .fields() - .iter() - .enumerate() - .filter(|(_, f)| f.is_required()) - }; + let required_fields_iter = || self.0.fields().iter().enumerate().filter(|(_, f)| !f.skip); let required_fields_count = required_fields_iter().count(); let required_fields_idents: Vec<_> = (0..required_fields_count) .map(|i| quote::format_ident!("f_{}", i)) @@ -394,7 +402,7 @@ impl TypeCheckUnorderedGenerator<'_> { let visited_flag = Self::visited_flag_variable(field); let typ = field.deserialize_target(); let cql_name_literal = field.cql_name_literal(); - let decrement_if_required: Option:: = field.is_required().then(|| parse_quote! { + let decrement_if_required: Option:: = (!self.0.attrs.default_when_missing && field.is_required()).then(|| parse_quote! { remaining_required_fields -= 1; }); @@ -467,7 +475,11 @@ impl TypeCheckUnorderedGenerator<'_> { .iter() .filter(|f| !f.skip) .map(|f| f.cql_name_literal()); - let field_count_lit = fields.iter().filter(|f| f.is_required()).count(); + let field_count_lit = if self.0.attrs.default_when_missing { + 0 + } else { + fields.iter().filter(|f| f.is_required()).count() + }; parse_quote! { fn type_check( @@ -541,11 +553,18 @@ impl DeserializeUnorderedGenerator<'_> { let deserialize_field = Self::deserialize_field_variable(field); let cql_name_literal = field.cql_name_literal(); - parse_quote! { - #deserialize_field.unwrap_or_else(|| ::std::panic!( - "column {} missing in DB row - type check should have prevented this!", - #cql_name_literal - )) + if self.0.attrs.default_when_missing || field.default_when_missing { + // Generate Default::default if the field was missing + parse_quote! { + #deserialize_field.unwrap_or_default() + } + } else { + parse_quote! { + #deserialize_field.unwrap_or_else(|| ::std::panic!( + "column {} missing in DB row - type check should have prevented this!", + #cql_name_literal + )) + } } } diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index ceacd7bd2c..85615a2666 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -31,6 +31,12 @@ struct StructAttrs { // they will be ignored. With true, an error will be raised. #[darling(default)] forbid_excess_udt_fields: bool, + + // If true, then - if this field is missing from the UDT fields metadata + // - it will be initialized to Default::default(). + #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -229,7 +235,7 @@ impl TypeCheckAssumeOrderGenerator<'_> { let (frame_lifetime, metadata_lifetime) = self.0.constraint_lifetimes(); let rust_field_name = field.cql_name_literal(); let rust_field_typ = field.deserialize_target(); - let default_when_missing = field.default_when_missing; + let default_when_missing = self.0.attrs.default_when_missing || field.default_when_missing; let skip_name_checks = self.0.attrs.skip_name_checks; // Action performed in case of field name mismatch. @@ -593,8 +599,8 @@ impl TypeCheckUnorderedGenerator<'_> { let visited_flag = Self::visited_flag_variable(field); let typ = field.deserialize_target(); let cql_name_literal = field.cql_name_literal(); - let decrement_if_required: Option = field - .is_required() + let decrement_if_required: Option = (!self.0.attrs.default_when_missing && field + .is_required()) .then(|| parse_quote! {remaining_required_cql_fields -= 1;}); parse_quote! { @@ -659,7 +665,12 @@ impl TypeCheckUnorderedGenerator<'_> { .iter() .filter(|f| !f.skip) .map(|f| f.cql_name_literal()); - let required_cql_field_count = rust_fields.iter().filter(|f| f.is_required()).count(); + let required_cql_field_count = if self.0.attrs.default_when_missing { + 0 + } else { + rust_fields.iter().filter(|f| f.is_required()).count() + }; + let required_cql_field_count_lit = syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site()); let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); @@ -746,7 +757,7 @@ impl DeserializeUnorderedGenerator<'_> { } let deserialize_field = Self::deserialize_field_variable(field); - if field.default_when_missing { + if self.0.attrs.default_when_missing || field.default_when_missing { // Generate Default::default if the field was missing parse_quote! { #deserialize_field.unwrap_or_default() diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 05a8ddcf78..a1bdf04865 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -378,6 +378,11 @@ mod deserialize; /// column into the first field, second column into the second field and so on. /// It will still still verify that the column types and field types match. /// +/// #[scylla(allow_missing)] +/// +/// if set, implementation will not fail if some columns are missing. +/// Instead, it will initialize the field with `Default::default()`. +/// /// ## Field attributes /// /// `#[scylla(skip)]` @@ -395,6 +400,11 @@ mod deserialize; /// By default, the generated implementation will try to match the Rust field /// to a column with the same name. This attribute allows to match to a column /// with provided name. +/// +/// #[scylla(allow_missing)] +/// +/// if set, implementation will not fail if some columns are missing. +/// Instead, it will initialize the field with `Default::default()`. #[proc_macro_derive(DeserializeRow, attributes(scylla))] pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { match deserialize::row::deserialize_row_derive(tokens_input) { @@ -501,6 +511,11 @@ pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { /// If more strictness is desired, this flag makes sure that no excess fields /// are present and forces error in case there are some. /// +/// `#[scylla(allow_missing)]` +/// +/// If the value of the field received from DB is null, the field will be +/// initialized with `Default::default()`. +/// /// ## Field attributes /// /// `#[scylla(skip)]` diff --git a/scylla-macros/src/serialize/row.rs b/scylla-macros/src/serialize/row.rs index e871dcb6ee..dec6a42323 100644 --- a/scylla-macros/src/serialize/row.rs +++ b/scylla-macros/src/serialize/row.rs @@ -22,6 +22,11 @@ struct Attributes { // This annotation only works if `enforce_order` flavor is specified. #[darling(default)] skip_name_checks: bool, + + // Used for deserialization only. Ignored in serialization. + #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } impl Attributes { @@ -70,6 +75,11 @@ struct FieldAttributes { #[darling(default)] #[darling(rename = "default_when_null")] _default_when_null: bool, + + // Used for deserialization only. Ignored in serialization. + #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } struct Context { diff --git a/scylla-macros/src/serialize/value.rs b/scylla-macros/src/serialize/value.rs index 3d16e79bd6..4b0ac52da5 100644 --- a/scylla-macros/src/serialize/value.rs +++ b/scylla-macros/src/serialize/value.rs @@ -30,6 +30,11 @@ struct Attributes { // the DB will interpret them as NULLs anyway. #[darling(default)] forbid_excess_udt_fields: bool, + + // Used for deserialization only. Ignored in serialization. + #[darling(default)] + #[darling(rename = "allow_missing")] + _default_when_missing: bool, } impl Attributes { diff --git a/scylla/tests/integration/macros/hygiene.rs b/scylla/tests/integration/macros/hygiene.rs index d8ed80f112..a16e4d57ce 100644 --- a/scylla/tests/integration/macros/hygiene.rs +++ b/scylla/tests/integration/macros/hygiene.rs @@ -239,6 +239,24 @@ macro_rules! test_crate { g: ::core::primitive::i32, } + // Test attributes for value struct with ordered flavor + #[derive( + _scylla::DeserializeValue, _scylla::SerializeValue, PartialEq, Debug, + )] + #[scylla(crate = _scylla, flavor = "enforce_order")] + #[scylla(allow_missing)] + struct TestStructOrderedAllowedMissing { + a: ::core::primitive::i32, + b: ::core::primitive::i32, + #[scylla(default_when_null)] + c: ::core::primitive::i32, + #[scylla(skip)] + d: ::core::primitive::i32, + #[scylla(rename = "f")] + e: ::core::primitive::i32, + g: ::core::primitive::i32, + } + // Test attributes for value struct with strict ordered flavor #[derive( _scylla::DeserializeValue, _scylla::SerializeValue, PartialEq, Debug, @@ -304,6 +322,24 @@ macro_rules! test_crate { c: ::core::primitive::i32, #[scylla(default_when_null)] d: ::core::primitive::i32, + #[scylla(allow_missing)] + e: ::core::primitive::i32, + } + // Test attributes for row struct with name flavor + #[derive( + _scylla::DeserializeRow, _scylla::SerializeRow, PartialEq, Debug, + )] + #[scylla(crate = _scylla)] + #[scylla(allow_missing)] + struct TestRowByNameWithMissing { + #[scylla(skip)] + a: ::core::primitive::i32, + #[scylla(rename = "f")] + b: ::core::primitive::i32, + c: ::core::primitive::i32, + #[scylla(default_when_null)] + d: ::core::primitive::i32, + e: ::core::primitive::i32, } // Test attributes for row struct with ordered flavor