@@ -5,7 +5,10 @@ use rquickjs::{
55 object:: ObjectIter ,
66 qjs:: { JS_GetClassID , JS_GetProperty } ,
77} ;
8- use serde:: { de, forward_to_deserialize_any} ;
8+ use serde:: {
9+ de:: { self , IntoDeserializer , Unexpected } ,
10+ forward_to_deserialize_any,
11+ } ;
912
1013use crate :: err:: { Error , Result } ;
1114use crate :: utils:: { as_key, to_string_lossy} ;
@@ -213,14 +216,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
213216 }
214217
215218 // FIXME: Replace type_of when https://github.com/DelSkayn/rquickjs/pull/458 is merged.
216- if get_class_id ( & self . value ) == ClassId :: BigInt as u32
217- || self . value . type_of ( ) == rquickjs:: Type :: BigInt
219+ if ( get_class_id ( & self . value ) == ClassId :: BigInt as u32
220+ || self . value . type_of ( ) == rquickjs:: Type :: BigInt )
221+ && let Some ( f) = get_to_json ( & self . value )
218222 {
219- if let Some ( f) = get_to_json ( & self . value ) {
220- let v: Value = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
221- self . value = v;
222- return self . deserialize_any ( visitor) ;
223- }
223+ let v: Value = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
224+ self . value = v;
225+ return self . deserialize_any ( visitor) ;
224226 }
225227
226228 Err ( Error :: new ( Exception :: throw_type (
@@ -255,12 +257,28 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
255257 self ,
256258 _name : & ' static str ,
257259 _variants : & ' static [ & ' static str ] ,
258- _visitor : V ,
260+ visitor : V ,
259261 ) -> Result < V :: Value >
260262 where
261263 V : de:: Visitor < ' de > ,
262264 {
263- unimplemented ! ( )
265+ if get_class_id ( & self . value ) == ClassId :: String as u32
266+ && let Some ( f) = get_to_string ( & self . value )
267+ {
268+ let v = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
269+ self . value = v;
270+ }
271+
272+ // Now require a primitive string.
273+ let s = if let Some ( s) = self . value . as_string ( ) {
274+ s. to_string ( )
275+ . unwrap_or_else ( |e| to_string_lossy ( self . value . ctx ( ) , s, e) )
276+ } else {
277+ return Err ( Error :: new ( "expected a string for enum unit variant" ) ) ;
278+ } ;
279+
280+ // Hand Serde an EnumAccess that only supports unit variants.
281+ visitor. visit_enum ( UnitEnumAccess { variant : s } )
264282 }
265283
266284 forward_to_deserialize_any ! {
@@ -532,16 +550,75 @@ fn ensure_supported(value: &Value<'_>) -> Result<bool> {
532550 ) )
533551}
534552
553+ /// A helper struct for deserializing enums containing unit variants.
554+ struct UnitEnumAccess {
555+ variant : String ,
556+ }
557+
558+ impl < ' de > de:: EnumAccess < ' de > for UnitEnumAccess {
559+ type Error = Error ;
560+ type Variant = UnitOnlyVariant ;
561+
562+ fn variant_seed < V > ( self , seed : V ) -> Result < ( V :: Value , Self :: Variant ) >
563+ where
564+ V : de:: DeserializeSeed < ' de > ,
565+ {
566+ let v = seed. deserialize ( self . variant . into_deserializer ( ) ) ?;
567+ Ok ( ( v, UnitOnlyVariant ) )
568+ }
569+ }
570+
571+ struct UnitOnlyVariant ;
572+
573+ impl < ' de > de:: VariantAccess < ' de > for UnitOnlyVariant {
574+ type Error = Error ;
575+
576+ fn unit_variant ( self ) -> Result < ( ) > {
577+ Ok ( ( ) )
578+ }
579+
580+ fn newtype_variant_seed < T > ( self , _seed : T ) -> Result < T :: Value >
581+ where
582+ T : de:: DeserializeSeed < ' de > ,
583+ {
584+ Err ( de:: Error :: invalid_type (
585+ Unexpected :: NewtypeVariant ,
586+ & "unit variant" ,
587+ ) )
588+ }
589+
590+ fn tuple_variant < V > ( self , _len : usize , _visitor : V ) -> Result < V :: Value >
591+ where
592+ V : de:: Visitor < ' de > ,
593+ {
594+ Err ( de:: Error :: invalid_type (
595+ Unexpected :: TupleVariant ,
596+ & "unit variant" ,
597+ ) )
598+ }
599+
600+ fn struct_variant < V > ( self , _fields : & ' static [ & ' static str ] , _visitor : V ) -> Result < V :: Value >
601+ where
602+ V : de:: Visitor < ' de > ,
603+ {
604+ Err ( de:: Error :: invalid_type (
605+ Unexpected :: StructVariant ,
606+ & "unit variant" ,
607+ ) )
608+ }
609+ }
610+
535611#[ cfg( test) ]
536612mod tests {
537613 use std:: collections:: BTreeMap ;
538614
539615 use rquickjs:: Value ;
540616 use serde:: de:: DeserializeOwned ;
617+ use serde:: { Deserialize , Serialize } ;
541618
542619 use super :: Deserializer as ValueDeserializer ;
543- use crate :: MAX_SAFE_INTEGER ;
544620 use crate :: test:: Runtime ;
621+ use crate :: { MAX_SAFE_INTEGER , from_value, to_value} ;
545622
546623 fn deserialize_value < T > ( v : Value < ' _ > ) -> T
547624 where
@@ -759,4 +836,23 @@ mod tests {
759836 assert_eq ! ( vec![ None ; 5 ] , val) ;
760837 } ) ;
761838 }
839+
840+ #[ test]
841+ fn test_enum ( ) {
842+ let rt = Runtime :: default ( ) ;
843+
844+ #[ derive( Debug , Clone , Copy , PartialEq , Serialize , Deserialize ) ]
845+ enum Test {
846+ One ,
847+ Two ,
848+ Three ,
849+ }
850+
851+ rt. context ( ) . with ( |cx| {
852+ let left = Test :: Two ;
853+ let value = to_value ( cx, left) . unwrap ( ) ;
854+ let right: Test = from_value ( value) . unwrap ( ) ;
855+ assert_eq ! ( left, right) ;
856+ } ) ;
857+ }
762858}
0 commit comments