@@ -5,7 +5,10 @@ use rquickjs::{
55 object:: ObjectIter ,
66 qjs:: { JS_GetClassID , JS_GetProperty } ,
77} ;
8- use serde:: { de, forward_to_deserialize_any} ;
8+ use serde:: {
9+ de:: { self , value:: StrDeserializer } ,
10+ forward_to_deserialize_any,
11+ } ;
912
1013use crate :: err:: { Error , Result } ;
1114use crate :: utils:: { as_key, to_string_lossy} ;
@@ -213,14 +216,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
213216 }
214217
215218 // FIXME: Replace type_of when https://github.com/DelSkayn/rquickjs/pull/458 is merged.
216- if get_class_id ( & self . value ) == ClassId :: BigInt as u32
217- || self . value . type_of ( ) == rquickjs:: Type :: BigInt
219+ if ( get_class_id ( & self . value ) == ClassId :: BigInt as u32
220+ || self . value . type_of ( ) == rquickjs:: Type :: BigInt )
221+ && let Some ( f) = get_to_json ( & self . value )
218222 {
219- if let Some ( f) = get_to_json ( & self . value ) {
220- let v: Value = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
221- self . value = v;
222- return self . deserialize_any ( visitor) ;
223- }
223+ let v: Value = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
224+ self . value = v;
225+ return self . deserialize_any ( visitor) ;
224226 }
225227
226228 Err ( Error :: new ( Exception :: throw_type (
@@ -255,12 +257,29 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
255257 self ,
256258 _name : & ' static str ,
257259 _variants : & ' static [ & ' static str ] ,
258- _visitor : V ,
260+ visitor : V ,
259261 ) -> Result < V :: Value >
260262 where
261263 V : de:: Visitor < ' de > ,
262264 {
263- unimplemented ! ( )
265+ if get_class_id ( & self . value ) == ClassId :: String as u32
266+ && let Some ( f) = get_to_string ( & self . value )
267+ {
268+ let v = f. call ( ( This ( self . value . clone ( ) ) , ) ) . map_err ( Error :: new) ?;
269+ self . value = v;
270+ }
271+
272+ // Now require a primitive string.
273+ let s = if self . value . is_string ( ) {
274+ let js_s = self . value . as_string ( ) . unwrap ( ) ;
275+ js_s. to_string ( )
276+ . unwrap_or_else ( |e| to_string_lossy ( self . value . ctx ( ) , js_s, e) )
277+ } else {
278+ return Err ( Error :: new ( "expected a string for enum unit variant" ) ) ;
279+ } ;
280+
281+ // Hand Serde an EnumAccess that only supports unit variants.
282+ visitor. visit_enum ( UnitEnumAccess { variant : s } )
264283 }
265284
266285 forward_to_deserialize_any ! {
@@ -532,16 +551,67 @@ fn ensure_supported(value: &Value<'_>) -> Result<bool> {
532551 ) )
533552}
534553
554+ /// A helper struct for deserializing enums containing unit variants.
555+ struct UnitEnumAccess {
556+ variant : String ,
557+ }
558+
559+ impl < ' de > de:: EnumAccess < ' de > for UnitEnumAccess {
560+ type Error = Error ;
561+ type Variant = UnitOnlyVariant ;
562+
563+ fn variant_seed < V > ( self , seed : V ) -> Result < ( V :: Value , Self :: Variant ) >
564+ where
565+ V : de:: DeserializeSeed < ' de > ,
566+ {
567+ let id = StrDeserializer :: < Error > :: new ( & self . variant ) ;
568+ let v = seed. deserialize ( id) ?;
569+ Ok ( ( v, UnitOnlyVariant ) )
570+ }
571+ }
572+
573+ struct UnitOnlyVariant ;
574+
575+ impl < ' de > de:: VariantAccess < ' de > for UnitOnlyVariant {
576+ type Error = Error ;
577+
578+ fn unit_variant ( self ) -> Result < ( ) > {
579+ Ok ( ( ) )
580+ }
581+
582+ fn newtype_variant_seed < T > ( self , _seed : T ) -> Result < T :: Value >
583+ where
584+ T : de:: DeserializeSeed < ' de > ,
585+ {
586+ Err ( Error :: new ( "only unit variants are supported" ) )
587+ }
588+
589+ fn tuple_variant < V > ( self , _len : usize , _visitor : V ) -> Result < V :: Value >
590+ where
591+ V : de:: Visitor < ' de > ,
592+ {
593+ Err ( Error :: new ( "only unit variants are supported" ) )
594+ }
595+
596+ fn struct_variant < V > ( self , _fields : & ' static [ & ' static str ] , _visitor : V ) -> Result < V :: Value >
597+ where
598+ V : de:: Visitor < ' de > ,
599+ {
600+ Err ( Error :: new ( "only unit variants are supported" ) )
601+ }
602+ }
603+
535604#[ cfg( test) ]
536605mod tests {
537606 use std:: collections:: BTreeMap ;
538607
539608 use rquickjs:: Value ;
540609 use serde:: de:: DeserializeOwned ;
610+ use serde:: { Deserialize , Serialize } ;
541611
542612 use super :: Deserializer as ValueDeserializer ;
543- use crate :: MAX_SAFE_INTEGER ;
544613 use crate :: test:: Runtime ;
614+ use crate :: { MAX_SAFE_INTEGER , from_value, to_value} ;
545615
546616 fn deserialize_value < T > ( v : Value < ' _ > ) -> T
547617 where
@@ -759,4 +829,23 @@ mod tests {
759829 assert_eq ! ( vec![ None ; 5 ] , val) ;
760830 } ) ;
761831 }
832+
833+ #[ test]
834+ fn test_enum ( ) {
835+ let rt = Runtime :: default ( ) ;
836+
837+ #[ derive( Debug , Clone , Copy , PartialEq , Serialize , Deserialize ) ]
838+ enum Test {
839+ One ,
840+ Two ,
841+ Three ,
842+ }
843+
844+ rt. context ( ) . with ( |cx| {
845+ let left = Test :: Two ;
846+ let value = to_value ( cx, left) . unwrap ( ) ;
847+ let right: Test = from_value ( value) . unwrap ( ) ;
848+ assert_eq ! ( left, right) ;
849+ } ) ;
850+ }
762851}
0 commit comments