Skip to content

Commit 05da6e6

Browse files
authored
Support enum variants that have aliases (#53)
* Support enum variants that have aliases * Return error instead of panic on enum with 0 variants Before PR 53: thread 'main' panicked at /git/tmp/serde-reflection/serde-reflection/src/de.rs:431:18: variant indexes must be a non-empty range 0..variants.len() First draft of PR 53, debug mode (overflow checks): thread 'main' panicked at /git/tmp/serde-reflection/serde-reflection/src/de.rs:435:42: attempt to subtract with overflow First draft of PR 53, release mode: Failed to deserialize value: "invalid value: integer `0`, expected variant index 0 <= i < 0" This commit: Not supported: deserialize_enum with 0 variants
1 parent d1a1ff0 commit 05da6e6

File tree

5 files changed

+207
-56
lines changed

5 files changed

+207
-56
lines changed

Cargo.lock

Lines changed: 25 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

serde-reflection/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ exclude = [
1717
]
1818

1919
[dependencies]
20-
thiserror = "1.0.25"
21-
serde = { version = "1.0.126", features = ["derive"] }
20+
erased-discriminant = "1"
2221
once_cell = "1.7.2"
22+
serde = { version = "1.0.126", features = ["derive"] }
23+
thiserror = "1.0.25"
24+
typeid = "1"
2325

2426
[dev-dependencies]
2527
bincode = "1.3.3"

serde-reflection/src/de.rs

Lines changed: 142 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
use crate::{
55
error::{Error, Result},
66
format::{ContainerFormat, ContainerFormatEntry, Format, FormatHolder, Named, VariantFormat},
7-
trace::{Samples, Tracer},
7+
trace::{EnumProgress, Samples, Tracer, VariantId},
88
value::IntoSeqDeserializer,
99
};
10-
use serde::de::{self, DeserializeSeed, IntoDeserializer, Visitor};
11-
use std::collections::BTreeMap;
10+
use erased_discriminant::Discriminant;
11+
use serde::de::{
12+
self,
13+
value::{BorrowedStrDeserializer, U32Deserializer},
14+
DeserializeSeed, IntoDeserializer, Visitor,
15+
};
16+
use std::collections::btree_map::{BTreeMap, Entry};
1217

1318
/// Deserialize a single value.
1419
/// * The lifetime 'a is set by the deserialization call site and the
@@ -391,55 +396,151 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {
391396

392397
// Assumption: The first variant(s) should be "base cases", i.e. not cause infinite recursion
393398
// while constructing sample values.
399+
#[allow(clippy::map_entry)] // false positive https://github.com/rust-lang/rust-clippy/issues/9470
394400
fn deserialize_enum<V>(
395401
self,
396-
name: &'static str,
402+
enum_name: &'static str,
397403
variants: &'static [&'static str],
398404
visitor: V,
399405
) -> Result<V::Value>
400406
where
401407
V: Visitor<'de>,
402408
{
403-
self.format.unify(Format::TypeName(name.into()))?;
409+
if variants.is_empty() {
410+
return Err(Error::NotSupported("deserialize_enum with 0 variants"));
411+
}
412+
413+
let enum_type_id = typeid::of::<V::Value>();
414+
self.format.unify(Format::TypeName(enum_name.into()))?;
404415
// Pre-update the registry.
405416
self.tracer
406417
.registry
407-
.entry(name.to_string())
418+
.entry(enum_name.to_string())
408419
.unify(ContainerFormat::Enum(BTreeMap::new()))?;
409-
let known_variants = match self.tracer.registry.get_mut(name) {
420+
let known_variants = match self.tracer.registry.get_mut(enum_name) {
410421
Some(ContainerFormat::Enum(x)) => x,
411422
_ => unreachable!(),
412423
};
413-
// If we have found all the variants OR if the enum is marked as
414-
// incomplete already, pick the first index.
415-
let index = if known_variants.len() == variants.len()
416-
|| self.tracer.incomplete_enums.contains(name)
417-
{
418-
0
419-
} else {
420-
let mut index = known_variants.len() as u32;
421-
// Scan the range 0..=known_variants.len() downwards to find the next
422-
// variant index to explore.
423-
while known_variants.contains_key(&index) {
424-
index -= 1;
424+
425+
// If the enum is marked as incomplete, just visit the first index
426+
// because we presume it avoids recursion.
427+
if self.tracer.incomplete_enums.contains_key(enum_name) {
428+
return visitor.visit_enum(EnumDeserializer::new(
429+
self.tracer,
430+
self.samples,
431+
VariantId::Index(0),
432+
&mut VariantFormat::unknown(),
433+
));
434+
}
435+
436+
// First visit each of the variants by name according to `variants`.
437+
// Later revisit them by u32 index until an index matching each of the
438+
// named variants has been determined.
439+
let provisional_min = u32::MAX - (variants.len() - 1) as u32;
440+
for (i, &variant_name) in variants.iter().enumerate() {
441+
if !self
442+
.tracer
443+
.discriminants
444+
.contains_key(&(enum_type_id, VariantId::Name(variant_name)))
445+
{
446+
// Insert into known_variants with a provisional index.
447+
let provisional_index = provisional_min + i as u32;
448+
let variant = known_variants
449+
.entry(provisional_index)
450+
.or_insert_with(|| Named {
451+
name: variant_name.to_owned(),
452+
value: VariantFormat::unknown(),
453+
});
454+
self.tracer
455+
.incomplete_enums
456+
.insert(enum_name.into(), EnumProgress::NamedVariantsRemaining);
457+
// Compute the discriminant and format for this variant.
458+
let mut value = variant.value.clone();
459+
let enum_value = visitor.visit_enum(EnumDeserializer::new(
460+
self.tracer,
461+
self.samples,
462+
VariantId::Name(variant_name),
463+
&mut value,
464+
))?;
465+
let discriminant = Discriminant::of(&enum_value);
466+
self.tracer
467+
.discriminants
468+
.insert((enum_type_id, VariantId::Name(variant_name)), discriminant);
469+
return Ok(enum_value);
425470
}
426-
index
471+
}
472+
473+
// We know the discriminant for every variant name. Now visit them again
474+
// by index to find the u32 id that goes with each name.
475+
//
476+
// If there are no provisional entries waiting for an index, just go
477+
// with index 0.
478+
let mut index = 0;
479+
if known_variants.range(provisional_min..).next().is_some() {
480+
self.tracer
481+
.incomplete_enums
482+
.insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining);
483+
while known_variants.contains_key(&index)
484+
&& self
485+
.tracer
486+
.discriminants
487+
.contains_key(&(enum_type_id, VariantId::Index(index)))
488+
{
489+
index += 1;
490+
}
491+
}
492+
493+
// Compute the discriminant and format for this variant.
494+
let mut value = VariantFormat::unknown();
495+
let enum_value = visitor.visit_enum(EnumDeserializer::new(
496+
self.tracer,
497+
self.samples,
498+
VariantId::Index(index),
499+
&mut value,
500+
))?;
501+
let discriminant = Discriminant::of(&enum_value);
502+
self.tracer.discriminants.insert(
503+
(enum_type_id, VariantId::Index(index)),
504+
discriminant.clone(),
505+
);
506+
self.tracer.incomplete_enums.remove(enum_name);
507+
508+
// Rewrite provisional entries for which we now know a u32 index.
509+
let known_variants = match self.tracer.registry.get_mut(enum_name) {
510+
Some(ContainerFormat::Enum(x)) => x,
511+
_ => unreachable!(),
427512
};
428-
let variant = known_variants.entry(index).or_insert_with(|| Named {
429-
name: (*variants
430-
.get(index as usize)
431-
.expect("variant indexes must be a non-empty range 0..variants.len()"))
432-
.to_string(),
433-
value: VariantFormat::unknown(),
434-
});
435-
let mut value = variant.value.clone();
436-
// Mark the enum as incomplete if this was not the last variant to explore.
437-
if known_variants.len() != variants.len() {
438-
self.tracer.incomplete_enums.insert(name.into());
513+
for provisional_index in provisional_min..=u32::MAX {
514+
if let Entry::Occupied(provisional_entry) = known_variants.entry(provisional_index) {
515+
if self.tracer.discriminants
516+
[&(enum_type_id, VariantId::Name(&provisional_entry.get().name))]
517+
== discriminant
518+
{
519+
let provisional_entry = provisional_entry.remove();
520+
match known_variants.entry(index) {
521+
Entry::Vacant(vacant) => {
522+
vacant.insert(provisional_entry);
523+
}
524+
Entry::Occupied(mut existing_entry) => {
525+
// Discard the provisional entry's name and just
526+
// keep the existing one.
527+
existing_entry
528+
.get_mut()
529+
.value
530+
.unify(provisional_entry.value)?;
531+
}
532+
}
533+
} else {
534+
self.tracer
535+
.incomplete_enums
536+
.insert(enum_name.into(), EnumProgress::IndexedVariantsRemaining);
537+
}
538+
}
539+
}
540+
if let Some(existing_entry) = known_variants.get_mut(&index) {
541+
existing_entry.value.unify(value)?;
439542
}
440-
// Compute the format for this variant.
441-
let inner = EnumDeserializer::new(self.tracer, self.samples, index, &mut value);
442-
visitor.visit_enum(inner)
543+
Ok(enum_value)
443544
}
444545

445546
fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
@@ -539,21 +640,21 @@ where
539640
struct EnumDeserializer<'de, 'a> {
540641
tracer: &'a mut Tracer,
541642
samples: &'de Samples,
542-
index: u32,
643+
variant_id: VariantId<'static>,
543644
format: &'a mut VariantFormat,
544645
}
545646

546647
impl<'de, 'a> EnumDeserializer<'de, 'a> {
547648
fn new(
548649
tracer: &'a mut Tracer,
549650
samples: &'de Samples,
550-
index: u32,
651+
variant_id: VariantId<'static>,
551652
format: &'a mut VariantFormat,
552653
) -> Self {
553654
Self {
554655
tracer,
555656
samples,
556-
index,
657+
variant_id,
557658
format,
558659
}
559660
}
@@ -567,8 +668,10 @@ impl<'de, 'a> de::EnumAccess<'de> for EnumDeserializer<'de, 'a> {
567668
where
568669
V: DeserializeSeed<'de>,
569670
{
570-
let index = self.index;
571-
let value = seed.deserialize(index.into_deserializer())?;
671+
let value = match self.variant_id {
672+
VariantId::Index(index) => seed.deserialize(U32Deserializer::new(index)),
673+
VariantId::Name(name) => seed.deserialize(BorrowedStrDeserializer::new(name)),
674+
}?;
572675
Ok((value, self))
573676
}
574677
}

0 commit comments

Comments
 (0)