diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index 4967e35d1..ad7508173 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -1,5 +1,5 @@ use crate::fragment::{Expr, Fragment, Match, Stmts}; -use crate::internals::ast::{Container, Data, Field, Style, Variant}; +use crate::internals::ast::{Container, Data, Discriminant, Field, Style, Variant}; use crate::internals::name::Name; use crate::internals::{attr, replace_receiver, ungroup, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; @@ -981,6 +981,7 @@ fn deserialize_struct( .map(|(i, field)| FieldWithAliases { ident: field_i(i), aliases: field.attrs.aliases(), + discriminant: None, }) .collect(); @@ -1143,6 +1144,7 @@ fn deserialize_struct_in_place( .map(|(i, field)| FieldWithAliases { ident: field_i(i), aliases: field.attrs.aliases(), + discriminant: None, }) .collect(); @@ -1239,7 +1241,10 @@ fn deserialize_homogeneous_enum( } } -fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { +fn prepare_enum_variant_enum( + variants: &[Variant], + cattrs: &attr::Container, +) -> (TokenStream, Stmts) { let deserialized_variants = variants .iter() .enumerate() @@ -1263,10 +1268,32 @@ fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { } }; + let mut discr = 0; let deserialized_variants: Vec<_> = deserialized_variants - .map(|(i, variant)| FieldWithAliases { - ident: field_i(i), - aliases: variant.attrs.aliases(), + .map(|(i, variant)| { + let discriminant = if cattrs.explicit_tags() { + match &variant.discriminant { + Discriminant::None => {} + Discriminant::Explicit(d) => { + discr = *d; + } + Discriminant::Other(_expr) => { + // TODO: handle error properly + panic!("unsupported discriminant expression"); + } + } + Some(discr) + } else { + None + }; + + discr += 1; + + FieldWithAliases { + ident: field_i(i), + aliases: variant.attrs.aliases(), + discriminant, + } }) .collect(); @@ -1295,7 +1322,7 @@ fn deserialize_externally_tagged_enum( let expecting = format!("enum {}", params.type_name()); let expecting = cattrs.expecting().unwrap_or(&expecting); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); // Match arms to extract a variant from a string let variant_arms = variants @@ -1381,7 +1408,7 @@ fn deserialize_internally_tagged_enum( cattrs: &attr::Container, tag: &str, ) -> Fragment { - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); // Match arms to extract a variant from a string let variant_arms = variants @@ -1435,7 +1462,7 @@ fn deserialize_adjacently_tagged_enum( split_with_de_lifetime(params); let delife = params.borrowed.de_lifetime(); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); let variant_arms: &Vec<_> = &variants .iter() @@ -2014,6 +2041,7 @@ fn deserialize_untagged_newtype_variant( struct FieldWithAliases<'a> { ident: Ident, aliases: &'a BTreeSet, + discriminant: Option, } fn deserialize_generated_identifier( @@ -2159,6 +2187,7 @@ fn deserialize_custom_identifier( .map(|variant| FieldWithAliases { ident: variant.ident.clone(), aliases: variant.attrs.aliases(), + discriminant: None, }) .collect(); @@ -2398,7 +2427,7 @@ fn deserialize_identifier( } } else { let u64_mapping = deserialized_fields.iter().enumerate().map(|(i, field)| { - let i = i as u64; + let i = field.discriminant.map_or(i as u64, |d| d as u64); let ident = &field.ident; quote!(#i => _serde::__private::Ok(#this_value::#ident)) }); diff --git a/serde_derive/src/internals/ast.rs b/serde_derive/src/internals/ast.rs index 3293823a7..38672c8c5 100644 --- a/serde_derive/src/internals/ast.rs +++ b/serde_derive/src/internals/ast.rs @@ -2,7 +2,7 @@ use crate::internals::{attr, check, Ctxt, Derive}; use syn::punctuated::Punctuated; -use syn::Token; +use syn::{Expr, Token}; /// A source data structure annotated with `#[derive(Serialize)]` and/or `#[derive(Deserialize)]`, /// parsed into an internal representation. @@ -31,11 +31,22 @@ pub enum Data<'a> { pub struct Variant<'a> { pub ident: syn::Ident, pub attrs: attr::Variant, + pub discriminant: Discriminant<'a>, pub style: Style, pub fields: Vec>, pub original: &'a syn::Variant, } +/// Optional discriminant of an enum variant, used to override `variant_index`. +pub enum Discriminant<'a> { + /// No explicit discriminant. + None, + /// An explicit integer discriminant. + Explicit(u32), + /// An explicit discriminant that cannot be used for `variant_index`. + Other(&'a Expr), +} + /// A field of a struct. pub struct Field<'a> { pub member: syn::Member, @@ -134,11 +145,29 @@ fn enum_from_ast<'a>( .iter() .map(|variant| { let attrs = attr::Variant::from_ast(cx, variant); + let discriminant = if let Some((_eq, expr)) = &variant.discriminant { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Int(int), + .. + }) = expr + { + if let Ok(n) = int.base10_parse() { + Discriminant::Explicit(n) + } else { + Discriminant::Other(expr) + } + } else { + Discriminant::Other(expr) + } + } else { + Discriminant::None + }; let (style, fields) = struct_from_ast(cx, &variant.fields, Some(&attrs), container_default); Variant { ident: variant.ident.clone(), attrs, + discriminant, style, fields, original: variant, diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index 6d846ed01..5c511b45d 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -173,6 +173,7 @@ pub struct Container { /// Error message generated when type can't be deserialized expecting: Option, non_exhaustive: bool, + explicit_tags: bool, } /// Styles of representing an enum. @@ -258,6 +259,7 @@ impl Container { let mut variant_identifier = BoolAttr::none(cx, VARIANT_IDENTIFIER); let mut serde_path = Attr::none(cx, CRATE); let mut expecting = Attr::none(cx, EXPECTING); + let mut explicit_tags = BoolAttr::none(cx, EXPLICIT_TAGS); let mut non_exhaustive = false; for attr in &item.attrs { @@ -491,6 +493,9 @@ impl Container { if let Some(s) = get_lit_str(cx, EXPECTING, &meta)? { expecting.set(&meta.path, s.value()); } + } else if meta.path == EXPLICIT_TAGS { + // #[serde(explicit_tags)] + explicit_tags.set_true(meta.path); } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( @@ -542,6 +547,7 @@ impl Container { is_packed, expecting: expecting.get(), non_exhaustive, + explicit_tags: explicit_tags.get(), } } @@ -623,6 +629,10 @@ impl Container { pub fn non_exhaustive(&self) -> bool { self.non_exhaustive } + + pub fn explicit_tags(&self) -> bool { + self.explicit_tags + } } fn decide_tag( diff --git a/serde_derive/src/internals/symbol.rs b/serde_derive/src/internals/symbol.rs index 59ef8de7c..175780632 100644 --- a/serde_derive/src/internals/symbol.rs +++ b/serde_derive/src/internals/symbol.rs @@ -14,6 +14,7 @@ pub const DENY_UNKNOWN_FIELDS: Symbol = Symbol("deny_unknown_fields"); pub const DESERIALIZE: Symbol = Symbol("deserialize"); pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with"); pub const EXPECTING: Symbol = Symbol("expecting"); +pub const EXPLICIT_TAGS: Symbol = Symbol("explicit_tags"); pub const FIELD_IDENTIFIER: Symbol = Symbol("field_identifier"); pub const FLATTEN: Symbol = Symbol("flatten"); pub const FROM: Symbol = Symbol("from"); diff --git a/serde_derive/src/ser.rs b/serde_derive/src/ser.rs index 46be736c4..31527412a 100644 --- a/serde_derive/src/ser.rs +++ b/serde_derive/src/ser.rs @@ -1,5 +1,5 @@ use crate::fragment::{Fragment, Match, Stmts}; -use crate::internals::ast::{Container, Data, Field, Style, Variant}; +use crate::internals::ast::{Container, Data, Discriminant, Field, Style, Variant}; use crate::internals::name::Name; use crate::internals::{attr, replace_receiver, Ctxt, Derive}; use crate::{bound, dummy, pretend, this}; @@ -17,14 +17,15 @@ pub fn expand_derive_serialize(input: &mut syn::DeriveInput) -> syn::Result return Err(ctxt.check().unwrap_err()), }; precondition(&ctxt, &cont); - ctxt.check()?; let ident = &cont.ident; let params = Parameters::new(&cont); let (impl_generics, ty_generics, where_clause) = params.generics.split_for_impl(); - let body = Stmts(serialize_body(&cont, ¶ms)); + let body = Stmts(serialize_body(&ctxt, &cont, ¶ms)); let serde = cont.attrs.serde_path(); + ctxt.check()?; + let impl_block = if let Some(remote) = cont.attrs.remote() { let vis = &input.vis; let used = pretend::pretend_used(&cont, params.is_packed); @@ -165,14 +166,14 @@ fn needs_serialize_bound(field: &attr::Field, variant: Option<&attr::Variant>) - }) } -fn serialize_body(cont: &Container, params: &Parameters) -> Fragment { +fn serialize_body(cx: &Ctxt, cont: &Container, params: &Parameters) -> Fragment { if cont.attrs.transparent() { serialize_transparent(cont, params) } else if let Some(type_into) = cont.attrs.type_into() { serialize_into(params, type_into) } else { match &cont.data { - Data::Enum(variants) => serialize_enum(params, variants, &cont.attrs), + Data::Enum(variants) => serialize_enum(cx, params, variants, &cont.attrs), Data::Struct(Style::Struct, fields) => serialize_struct(params, fields, &cont.attrs), Data::Struct(Style::Tuple, fields) => { serialize_tuple_struct(params, fields, &cont.attrs) @@ -389,16 +390,40 @@ fn serialize_struct_as_map( } } -fn serialize_enum(params: &Parameters, variants: &[Variant], cattrs: &attr::Container) -> Fragment { +fn serialize_enum( + cx: &Ctxt, + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { assert!(variants.len() as u64 <= u64::from(u32::MAX)); let self_var = ¶ms.self_var; + let mut discriminant = 0; + let mut arms: Vec<_> = variants .iter() .enumerate() .map(|(variant_index, variant)| { - serialize_variant(params, variant, variant_index as u32, cattrs) + let variant_index = if cattrs.explicit_tags() { + match &variant.discriminant { + Discriminant::None => {} + Discriminant::Explicit(d) => { + discriminant = *d; + } + Discriminant::Other(expr) => { + cx.error_spanned_by(expr, "unsupported expression for enum discriminant"); + } + } + discriminant + } else { + variant_index as u32 + }; + + discriminant += 1; + + serialize_variant(params, variant, variant_index, cattrs) }) .collect();