44use crate :: {
55 error:: { Error , Result } ,
66 format:: { ContainerFormat , ContainerFormatEntry , Format , FormatHolder , Named , VariantFormat } ,
7- trace:: { Samples , Tracer } ,
7+ trace:: { EnumProgress , Samples , Tracer , VariantId } ,
88 value:: IntoSeqDeserializer ,
99} ;
10- use serde:: de:: { self , DeserializeSeed , IntoDeserializer , Visitor } ;
11- use std:: collections:: BTreeMap ;
10+ use erased_discriminant:: Discriminant ;
11+ use serde:: de:: {
12+ self ,
13+ value:: { BorrowedStrDeserializer , U32Deserializer } ,
14+ DeserializeSeed , IntoDeserializer , Visitor ,
15+ } ;
16+ use std:: collections:: btree_map:: { BTreeMap , Entry } ;
1217
1318/// Deserialize a single value.
1419/// * The lifetime 'a is set by the deserialization call site and the
@@ -391,55 +396,151 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {
391396
392397 // Assumption: The first variant(s) should be "base cases", i.e. not cause infinite recursion
393398 // while constructing sample values.
399+ #[ allow( clippy:: map_entry) ] // false positive https://github.com/rust-lang/rust-clippy/issues/9470
394400 fn deserialize_enum < V > (
395401 self ,
396- name : & ' static str ,
402+ enum_name : & ' static str ,
397403 variants : & ' static [ & ' static str ] ,
398404 visitor : V ,
399405 ) -> Result < V :: Value >
400406 where
401407 V : Visitor < ' de > ,
402408 {
403- self . format . unify ( Format :: TypeName ( name. into ( ) ) ) ?;
409+ if variants. is_empty ( ) {
410+ return Err ( Error :: NotSupported ( "deserialize_enum with 0 variants" ) ) ;
411+ }
412+
413+ let enum_type_id = typeid:: of :: < V :: Value > ( ) ;
414+ self . format . unify ( Format :: TypeName ( enum_name. into ( ) ) ) ?;
404415 // Pre-update the registry.
405416 self . tracer
406417 . registry
407- . entry ( name . to_string ( ) )
418+ . entry ( enum_name . to_string ( ) )
408419 . unify ( ContainerFormat :: Enum ( BTreeMap :: new ( ) ) ) ?;
409- let known_variants = match self . tracer . registry . get_mut ( name ) {
420+ let known_variants = match self . tracer . registry . get_mut ( enum_name ) {
410421 Some ( ContainerFormat :: Enum ( x) ) => x,
411422 _ => unreachable ! ( ) ,
412423 } ;
413- // If we have found all the variants OR if the enum is marked as
414- // incomplete already, pick the first index.
415- let index = if known_variants. len ( ) == variants. len ( )
416- || self . tracer . incomplete_enums . contains ( name)
417- {
418- 0
419- } else {
420- let mut index = known_variants. len ( ) as u32 ;
421- // Scan the range 0..=known_variants.len() downwards to find the next
422- // variant index to explore.
423- while known_variants. contains_key ( & index) {
424- index -= 1 ;
424+
425+ // If the enum is marked as incomplete, just visit the first index
426+ // because we presume it avoids recursion.
427+ if self . tracer . incomplete_enums . contains_key ( enum_name) {
428+ return visitor. visit_enum ( EnumDeserializer :: new (
429+ self . tracer ,
430+ self . samples ,
431+ VariantId :: Index ( 0 ) ,
432+ & mut VariantFormat :: unknown ( ) ,
433+ ) ) ;
434+ }
435+
436+ // First visit each of the variants by name according to `variants`.
437+ // Later revisit them by u32 index until an index matching each of the
438+ // named variants has been determined.
439+ let provisional_min = u32:: MAX - ( variants. len ( ) - 1 ) as u32 ;
440+ for ( i, & variant_name) in variants. iter ( ) . enumerate ( ) {
441+ if !self
442+ . tracer
443+ . discriminants
444+ . contains_key ( & ( enum_type_id, VariantId :: Name ( variant_name) ) )
445+ {
446+ // Insert into known_variants with a provisional index.
447+ let provisional_index = provisional_min + i as u32 ;
448+ let variant = known_variants
449+ . entry ( provisional_index)
450+ . or_insert_with ( || Named {
451+ name : variant_name. to_owned ( ) ,
452+ value : VariantFormat :: unknown ( ) ,
453+ } ) ;
454+ self . tracer
455+ . incomplete_enums
456+ . insert ( enum_name. into ( ) , EnumProgress :: NamedVariantsRemaining ) ;
457+ // Compute the discriminant and format for this variant.
458+ let mut value = variant. value . clone ( ) ;
459+ let enum_value = visitor. visit_enum ( EnumDeserializer :: new (
460+ self . tracer ,
461+ self . samples ,
462+ VariantId :: Name ( variant_name) ,
463+ & mut value,
464+ ) ) ?;
465+ let discriminant = Discriminant :: of ( & enum_value) ;
466+ self . tracer
467+ . discriminants
468+ . insert ( ( enum_type_id, VariantId :: Name ( variant_name) ) , discriminant) ;
469+ return Ok ( enum_value) ;
425470 }
426- index
471+ }
472+
473+ // We know the discriminant for every variant name. Now visit them again
474+ // by index to find the u32 id that goes with each name.
475+ //
476+ // If there are no provisional entries waiting for an index, just go
477+ // with index 0.
478+ let mut index = 0 ;
479+ if known_variants. range ( provisional_min..) . next ( ) . is_some ( ) {
480+ self . tracer
481+ . incomplete_enums
482+ . insert ( enum_name. into ( ) , EnumProgress :: IndexedVariantsRemaining ) ;
483+ while known_variants. contains_key ( & index)
484+ && self
485+ . tracer
486+ . discriminants
487+ . contains_key ( & ( enum_type_id, VariantId :: Index ( index) ) )
488+ {
489+ index += 1 ;
490+ }
491+ }
492+
493+ // Compute the discriminant and format for this variant.
494+ let mut value = VariantFormat :: unknown ( ) ;
495+ let enum_value = visitor. visit_enum ( EnumDeserializer :: new (
496+ self . tracer ,
497+ self . samples ,
498+ VariantId :: Index ( index) ,
499+ & mut value,
500+ ) ) ?;
501+ let discriminant = Discriminant :: of ( & enum_value) ;
502+ self . tracer . discriminants . insert (
503+ ( enum_type_id, VariantId :: Index ( index) ) ,
504+ discriminant. clone ( ) ,
505+ ) ;
506+ self . tracer . incomplete_enums . remove ( enum_name) ;
507+
508+ // Rewrite provisional entries for which we now know a u32 index.
509+ let known_variants = match self . tracer . registry . get_mut ( enum_name) {
510+ Some ( ContainerFormat :: Enum ( x) ) => x,
511+ _ => unreachable ! ( ) ,
427512 } ;
428- let variant = known_variants. entry ( index) . or_insert_with ( || Named {
429- name : ( * variants
430- . get ( index as usize )
431- . expect ( "variant indexes must be a non-empty range 0..variants.len()" ) )
432- . to_string ( ) ,
433- value : VariantFormat :: unknown ( ) ,
434- } ) ;
435- let mut value = variant. value . clone ( ) ;
436- // Mark the enum as incomplete if this was not the last variant to explore.
437- if known_variants. len ( ) != variants. len ( ) {
438- self . tracer . incomplete_enums . insert ( name. into ( ) ) ;
513+ for provisional_index in provisional_min..=u32:: MAX {
514+ if let Entry :: Occupied ( provisional_entry) = known_variants. entry ( provisional_index) {
515+ if self . tracer . discriminants
516+ [ & ( enum_type_id, VariantId :: Name ( & provisional_entry. get ( ) . name ) ) ]
517+ == discriminant
518+ {
519+ let provisional_entry = provisional_entry. remove ( ) ;
520+ match known_variants. entry ( index) {
521+ Entry :: Vacant ( vacant) => {
522+ vacant. insert ( provisional_entry) ;
523+ }
524+ Entry :: Occupied ( mut existing_entry) => {
525+ // Discard the provisional entry's name and just
526+ // keep the existing one.
527+ existing_entry
528+ . get_mut ( )
529+ . value
530+ . unify ( provisional_entry. value ) ?;
531+ }
532+ }
533+ } else {
534+ self . tracer
535+ . incomplete_enums
536+ . insert ( enum_name. into ( ) , EnumProgress :: IndexedVariantsRemaining ) ;
537+ }
538+ }
539+ }
540+ if let Some ( existing_entry) = known_variants. get_mut ( & index) {
541+ existing_entry. value . unify ( value) ?;
439542 }
440- // Compute the format for this variant.
441- let inner = EnumDeserializer :: new ( self . tracer , self . samples , index, & mut value) ;
442- visitor. visit_enum ( inner)
543+ Ok ( enum_value)
443544 }
444545
445546 fn deserialize_identifier < V > ( self , _visitor : V ) -> Result < V :: Value >
@@ -539,21 +640,21 @@ where
539640struct EnumDeserializer < ' de , ' a > {
540641 tracer : & ' a mut Tracer ,
541642 samples : & ' de Samples ,
542- index : u32 ,
643+ variant_id : VariantId < ' static > ,
543644 format : & ' a mut VariantFormat ,
544645}
545646
546647impl < ' de , ' a > EnumDeserializer < ' de , ' a > {
547648 fn new (
548649 tracer : & ' a mut Tracer ,
549650 samples : & ' de Samples ,
550- index : u32 ,
651+ variant_id : VariantId < ' static > ,
551652 format : & ' a mut VariantFormat ,
552653 ) -> Self {
553654 Self {
554655 tracer,
555656 samples,
556- index ,
657+ variant_id ,
557658 format,
558659 }
559660 }
@@ -567,8 +668,10 @@ impl<'de, 'a> de::EnumAccess<'de> for EnumDeserializer<'de, 'a> {
567668 where
568669 V : DeserializeSeed < ' de > ,
569670 {
570- let index = self . index ;
571- let value = seed. deserialize ( index. into_deserializer ( ) ) ?;
671+ let value = match self . variant_id {
672+ VariantId :: Index ( index) => seed. deserialize ( U32Deserializer :: new ( index) ) ,
673+ VariantId :: Name ( name) => seed. deserialize ( BorrowedStrDeserializer :: new ( name) ) ,
674+ } ?;
572675 Ok ( ( value, self ) )
573676 }
574677}
0 commit comments