diff --git a/src/de.rs b/src/de.rs index 961f225..5e41e64 100644 --- a/src/de.rs +++ b/src/de.rs @@ -5,7 +5,10 @@ use rquickjs::{ object::ObjectIter, qjs::{JS_GetClassID, JS_GetProperty}, }; -use serde::{de, forward_to_deserialize_any}; +use serde::{ + de::{self, IntoDeserializer, Unexpected}, + forward_to_deserialize_any, +}; use crate::err::{Error, Result}; use crate::utils::{as_key, to_string_lossy}; @@ -213,14 +216,13 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { } // FIXME: Replace type_of when https://github.com/DelSkayn/rquickjs/pull/458 is merged. - if get_class_id(&self.value) == ClassId::BigInt as u32 - || self.value.type_of() == rquickjs::Type::BigInt + if (get_class_id(&self.value) == ClassId::BigInt as u32 + || self.value.type_of() == rquickjs::Type::BigInt) + && let Some(f) = get_to_json(&self.value) { - if let Some(f) = get_to_json(&self.value) { - let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?; - self.value = v; - return self.deserialize_any(visitor); - } + let v: Value = f.call((This(self.value.clone()),)).map_err(Error::new)?; + self.value = v; + return self.deserialize_any(visitor); } Err(Error::new(Exception::throw_type( @@ -255,12 +257,28 @@ impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> { self, _name: &'static str, _variants: &'static [&'static str], - _visitor: V, + visitor: V, ) -> Result where V: de::Visitor<'de>, { - unimplemented!() + if get_class_id(&self.value) == ClassId::String as u32 + && let Some(f) = get_to_string(&self.value) + { + let v = f.call((This(self.value.clone()),)).map_err(Error::new)?; + self.value = v; + } + + // Now require a primitive string. + let s = if let Some(s) = self.value.as_string() { + s.to_string() + .unwrap_or_else(|e| to_string_lossy(self.value.ctx(), s, e)) + } else { + return Err(Error::new("expected a string for enum unit variant")); + }; + + // Hand Serde an EnumAccess that only supports unit variants. + visitor.visit_enum(UnitEnumAccess { variant: s }) } forward_to_deserialize_any! { @@ -532,16 +550,75 @@ fn ensure_supported(value: &Value<'_>) -> Result { )) } +/// A helper struct for deserializing enums containing unit variants. +struct UnitEnumAccess { + variant: String, +} + +impl<'de> de::EnumAccess<'de> for UnitEnumAccess { + type Error = Error; + type Variant = UnitOnlyVariant; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: de::DeserializeSeed<'de>, + { + let v = seed.deserialize(self.variant.into_deserializer())?; + Ok((v, UnitOnlyVariant)) + } +} + +struct UnitOnlyVariant; + +impl<'de> de::VariantAccess<'de> for UnitOnlyVariant { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + fn newtype_variant_seed(self, _seed: T) -> Result + where + T: de::DeserializeSeed<'de>, + { + Err(de::Error::invalid_type( + Unexpected::NewtypeVariant, + &"unit variant", + )) + } + + fn tuple_variant(self, _len: usize, _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + Err(de::Error::invalid_type( + Unexpected::TupleVariant, + &"unit variant", + )) + } + + fn struct_variant(self, _fields: &'static [&'static str], _visitor: V) -> Result + where + V: de::Visitor<'de>, + { + Err(de::Error::invalid_type( + Unexpected::StructVariant, + &"unit variant", + )) + } +} + #[cfg(test)] mod tests { use std::collections::BTreeMap; use rquickjs::Value; use serde::de::DeserializeOwned; + use serde::{Deserialize, Serialize}; use super::Deserializer as ValueDeserializer; - use crate::MAX_SAFE_INTEGER; use crate::test::Runtime; + use crate::{MAX_SAFE_INTEGER, from_value, to_value}; fn deserialize_value(v: Value<'_>) -> T where @@ -759,4 +836,23 @@ mod tests { assert_eq!(vec![None; 5], val); }); } + + #[test] + fn test_enum() { + let rt = Runtime::default(); + + #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] + enum Test { + One, + Two, + Three, + } + + rt.context().with(|cx| { + let left = Test::Two; + let value = to_value(cx, left).unwrap(); + let right: Test = from_value(value).unwrap(); + assert_eq!(left, right); + }); + } }