Skip to content

Commit c060873

Browse files
committed
Add support for enums containing unit variants
1 parent 669edf9 commit c060873

File tree

1 file changed

+100
-11
lines changed

1 file changed

+100
-11
lines changed

src/de.rs

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use rquickjs::{
55
object::ObjectIter,
66
qjs::{JS_GetClassID, JS_GetProperty},
77
};
8-
use serde::{de, forward_to_deserialize_any};
8+
use serde::{
9+
de::{self, value::StrDeserializer},
10+
forward_to_deserialize_any,
11+
};
912

1013
use crate::err::{Error, Result};
1114
use crate::utils::{as_key, to_string_lossy};
@@ -213,14 +216,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
213216
}
214217

215218
// FIXME: Replace type_of when https://github.com/DelSkayn/rquickjs/pull/458 is merged.
216-
if get_class_id(&self.value) == ClassId::BigInt as u32
217-
|| self.value.type_of() == rquickjs::Type::BigInt
219+
if (get_class_id(&self.value) == ClassId::BigInt as u32
220+
|| self.value.type_of() == rquickjs::Type::BigInt)
221+
&& let Some(f) = get_to_json(&self.value)
218222
{
219-
if let Some(f) = get_to_json(&self.value) {
220-
let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?;
221-
self.value = v;
222-
return self.deserialize_any(visitor);
223-
}
223+
let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?;
224+
self.value = v;
225+
return self.deserialize_any(visitor);
224226
}
225227

226228
Err(Error::new(Exception::throw_type(
@@ -255,12 +257,29 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
255257
self,
256258
_name: &'static str,
257259
_variants: &'static [&'static str],
258-
_visitor: V,
260+
visitor: V,
259261
) -> Result<V::Value>
260262
where
261263
V: de::Visitor<'de>,
262264
{
263-
unimplemented!()
265+
if get_class_id(&self.value) == ClassId::String as u32
266+
&& let Some(f) = get_to_string(&self.value)
267+
{
268+
let v = f.call((This(self.value.clone()),)).map_err(Error::new)?;
269+
self.value = v;
270+
}
271+
272+
// Now require a primitive string.
273+
let s = if self.value.is_string() {
274+
let js_s = self.value.as_string().unwrap();
275+
js_s.to_string()
276+
.unwrap_or_else(|e| to_string_lossy(self.value.ctx(), js_s, e))
277+
} else {
278+
return Err(Error::new("expected a string for enum unit variant"));
279+
};
280+
281+
// Hand Serde an EnumAccess that only supports unit variants.
282+
visitor.visit_enum(UnitEnumAccess { variant: s })
264283
}
265284

266285
forward_to_deserialize_any! {
@@ -532,16 +551,67 @@ fn ensure_supported(value: &Value<'_>) -> Result<bool> {
532551
))
533552
}
534553

554+
/// A helper struct for deserializing enums containing unit variants.
555+
struct UnitEnumAccess {
556+
variant: String,
557+
}
558+
559+
impl<'de> de::EnumAccess<'de> for UnitEnumAccess {
560+
type Error = Error;
561+
type Variant = UnitOnlyVariant;
562+
563+
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
564+
where
565+
V: de::DeserializeSeed<'de>,
566+
{
567+
let id = StrDeserializer::<Error>::new(&self.variant);
568+
let v = seed.deserialize(id)?;
569+
Ok((v, UnitOnlyVariant))
570+
}
571+
}
572+
573+
struct UnitOnlyVariant;
574+
575+
impl<'de> de::VariantAccess<'de> for UnitOnlyVariant {
576+
type Error = Error;
577+
578+
fn unit_variant(self) -> Result<()> {
579+
Ok(())
580+
}
581+
582+
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value>
583+
where
584+
T: de::DeserializeSeed<'de>,
585+
{
586+
Err(Error::new("only unit variants are supported"))
587+
}
588+
589+
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
590+
where
591+
V: de::Visitor<'de>,
592+
{
593+
Err(Error::new("only unit variants are supported"))
594+
}
595+
596+
fn struct_variant<V>(self, _fields: &'static [&'static str], _visitor: V) -> Result<V::Value>
597+
where
598+
V: de::Visitor<'de>,
599+
{
600+
Err(Error::new("only unit variants are supported"))
601+
}
602+
}
603+
535604
#[cfg(test)]
536605
mod tests {
537606
use std::collections::BTreeMap;
538607

539608
use rquickjs::Value;
540609
use serde::de::DeserializeOwned;
610+
use serde::{Deserialize, Serialize};
541611

542612
use super::Deserializer as ValueDeserializer;
543-
use crate::MAX_SAFE_INTEGER;
544613
use crate::test::Runtime;
614+
use crate::{MAX_SAFE_INTEGER, from_value, to_value};
545615

546616
fn deserialize_value<T>(v: Value<'_>) -> T
547617
where
@@ -759,4 +829,23 @@ mod tests {
759829
assert_eq!(vec![None; 5], val);
760830
});
761831
}
832+
833+
#[test]
834+
fn test_enum() {
835+
let rt = Runtime::default();
836+
837+
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
838+
enum Test {
839+
One,
840+
Two,
841+
Three,
842+
}
843+
844+
rt.context().with(|cx| {
845+
let left = Test::Two;
846+
let value = to_value(cx, left).unwrap();
847+
let right: Test = from_value(value).unwrap();
848+
assert_eq!(left, right);
849+
});
850+
}
762851
}

0 commit comments

Comments
 (0)