Skip to content

Commit 598ff85

Browse files
authored
Add support for enums containing unit variants (#2)
* Add support for enums containing unit variants * Fixes after review * Fix after review
1 parent 669edf9 commit 598ff85

File tree

1 file changed

+107
-11
lines changed

1 file changed

+107
-11
lines changed

src/de.rs

Lines changed: 107 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use rquickjs::{
55
object::ObjectIter,
66
qjs::{JS_GetClassID, JS_GetProperty},
77
};
8-
use serde::{de, forward_to_deserialize_any};
8+
use serde::{
9+
de::{self, IntoDeserializer, Unexpected},
10+
forward_to_deserialize_any,
11+
};
912

1013
use crate::err::{Error, Result};
1114
use crate::utils::{as_key, to_string_lossy};
@@ -213,14 +216,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
213216
}
214217

215218
// FIXME: Replace type_of when https://github.com/DelSkayn/rquickjs/pull/458 is merged.
216-
if get_class_id(&self.value) == ClassId::BigInt as u32
217-
|| self.value.type_of() == rquickjs::Type::BigInt
219+
if (get_class_id(&self.value) == ClassId::BigInt as u32
220+
|| self.value.type_of() == rquickjs::Type::BigInt)
221+
&& let Some(f) = get_to_json(&self.value)
218222
{
219-
if let Some(f) = get_to_json(&self.value) {
220-
let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?;
221-
self.value = v;
222-
return self.deserialize_any(visitor);
223-
}
223+
let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?;
224+
self.value = v;
225+
return self.deserialize_any(visitor);
224226
}
225227

226228
Err(Error::new(Exception::throw_type(
@@ -255,12 +257,28 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
255257
self,
256258
_name: &'static str,
257259
_variants: &'static [&'static str],
258-
_visitor: V,
260+
visitor: V,
259261
) -> Result<V::Value>
260262
where
261263
V: de::Visitor<'de>,
262264
{
263-
unimplemented!()
265+
if get_class_id(&self.value) == ClassId::String as u32
266+
&& let Some(f) = get_to_string(&self.value)
267+
{
268+
let v = f.call((This(self.value.clone()),)).map_err(Error::new)?;
269+
self.value = v;
270+
}
271+
272+
// Now require a primitive string.
273+
let s = if let Some(s) = self.value.as_string() {
274+
s.to_string()
275+
.unwrap_or_else(|e| to_string_lossy(self.value.ctx(), s, e))
276+
} else {
277+
return Err(Error::new("expected a string for enum unit variant"));
278+
};
279+
280+
// Hand Serde an EnumAccess that only supports unit variants.
281+
visitor.visit_enum(UnitEnumAccess { variant: s })
264282
}
265283

266284
forward_to_deserialize_any! {
@@ -532,16 +550,75 @@ fn ensure_supported(value: &Value<'_>) -> Result<bool> {
532550
))
533551
}
534552

553+
/// A helper struct for deserializing enums containing unit variants.
554+
struct UnitEnumAccess {
555+
variant: String,
556+
}
557+
558+
impl<'de> de::EnumAccess<'de> for UnitEnumAccess {
559+
type Error = Error;
560+
type Variant = UnitOnlyVariant;
561+
562+
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
563+
where
564+
V: de::DeserializeSeed<'de>,
565+
{
566+
let v = seed.deserialize(self.variant.into_deserializer())?;
567+
Ok((v, UnitOnlyVariant))
568+
}
569+
}
570+
571+
struct UnitOnlyVariant;
572+
573+
impl<'de> de::VariantAccess<'de> for UnitOnlyVariant {
574+
type Error = Error;
575+
576+
fn unit_variant(self) -> Result<()> {
577+
Ok(())
578+
}
579+
580+
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value>
581+
where
582+
T: de::DeserializeSeed<'de>,
583+
{
584+
Err(de::Error::invalid_type(
585+
Unexpected::NewtypeVariant,
586+
&"unit variant",
587+
))
588+
}
589+
590+
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
591+
where
592+
V: de::Visitor<'de>,
593+
{
594+
Err(de::Error::invalid_type(
595+
Unexpected::TupleVariant,
596+
&"unit variant",
597+
))
598+
}
599+
600+
fn struct_variant<V>(self, _fields: &'static [&'static str], _visitor: V) -> Result<V::Value>
601+
where
602+
V: de::Visitor<'de>,
603+
{
604+
Err(de::Error::invalid_type(
605+
Unexpected::StructVariant,
606+
&"unit variant",
607+
))
608+
}
609+
}
610+
535611
#[cfg(test)]
536612
mod tests {
537613
use std::collections::BTreeMap;
538614

539615
use rquickjs::Value;
540616
use serde::de::DeserializeOwned;
617+
use serde::{Deserialize, Serialize};
541618

542619
use super::Deserializer as ValueDeserializer;
543-
use crate::MAX_SAFE_INTEGER;
544620
use crate::test::Runtime;
621+
use crate::{MAX_SAFE_INTEGER, from_value, to_value};
545622

546623
fn deserialize_value<T>(v: Value<'_>) -> T
547624
where
@@ -759,4 +836,23 @@ mod tests {
759836
assert_eq!(vec![None; 5], val);
760837
});
761838
}
839+
840+
#[test]
841+
fn test_enum() {
842+
let rt = Runtime::default();
843+
844+
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
845+
enum Test {
846+
One,
847+
Two,
848+
Three,
849+
}
850+
851+
rt.context().with(|cx| {
852+
let left = Test::Two;
853+
let value = to_value(cx, left).unwrap();
854+
let right: Test = from_value(value).unwrap();
855+
assert_eq!(left, right);
856+
});
857+
}
762858
}

0 commit comments

Comments
 (0)