aboutsummaryrefslogtreecommitdiff
path: root/src/lua/serde/de.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/serde/de.rs')
-rw-r--r--src/lua/serde/de.rs436
1 files changed, 436 insertions, 0 deletions
diff --git a/src/lua/serde/de.rs b/src/lua/serde/de.rs
new file mode 100644
index 0000000..4d78a83
--- /dev/null
+++ b/src/lua/serde/de.rs
@@ -0,0 +1,436 @@
+// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version
+// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License
+
+use serde::de::IntoDeserializer;
+use serde::{self, forward_to_deserialize_any};
+
+use rlua::{TablePairs, TableSequence, Value};
+
+use super::error::{Error, Result};
+
+pub struct Deserializer<'lua> {
+ pub value: Value<'lua>,
+}
+
+impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> {
+ type Error = Error;
+
+ #[inline]
+ fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ match self.value {
+ Value::Nil => visitor.visit_unit(),
+ Value::Boolean(v) => visitor.visit_bool(v),
+ Value::Integer(v) => visitor.visit_i64(v),
+ Value::Number(v) => visitor.visit_f64(v),
+ Value::String(v) => visitor.visit_str(v.to_str()?),
+ Value::Table(v) => {
+ let len = v.len()? as usize;
+ let mut deserializer = MapDeserializer(v.pairs(), None);
+ let map = visitor.visit_map(&mut deserializer)?;
+ let remaining = deserializer.0.count();
+ if remaining == 0 {
+ Ok(map)
+ } else {
+ Err(serde::de::Error::invalid_length(
+ len,
+ &"fewer elements in array",
+ ))
+ }
+ }
+ _ => Err(serde::de::Error::custom("invalid value type")),
+ }
+ }
+
+ #[inline]
+ fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ match self.value {
+ Value::Nil => visitor.visit_none(),
+ _ => visitor.visit_some(self),
+ }
+ }
+
+ #[inline]
+ fn deserialize_enum<V>(
+ self,
+ _name: &str,
+ _variants: &'static [&'static str],
+ visitor: V,
+ ) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ let (variant, value) = match self.value {
+ Value::Table(value) => {
+ let mut iter = value.pairs::<String, Value>();
+ let (variant, value) = match iter.next() {
+ Some(v) => v?,
+ None => {
+ return Err(serde::de::Error::invalid_value(
+ serde::de::Unexpected::Map,
+ &"map with a single key",
+ ))
+ }
+ };
+
+ if iter.next().is_some() {
+ return Err(serde::de::Error::invalid_value(
+ serde::de::Unexpected::Map,
+ &"map with a single key",
+ ));
+ }
+ (variant, Some(value))
+ }
+ Value::String(variant) => (variant.to_str()?.to_owned(), None),
+ _ => return Err(serde::de::Error::custom("bad enum value")),
+ };
+
+ visitor.visit_enum(EnumDeserializer { variant, value })
+ }
+
+ #[inline]
+ fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ match self.value {
+ Value::Table(v) => {
+ let len = v.len()? as usize;
+ let mut deserializer = SeqDeserializer(v.sequence_values());
+ let seq = visitor.visit_seq(&mut deserializer)?;
+ let remaining = deserializer.0.count();
+ if remaining == 0 {
+ Ok(seq)
+ } else {
+ Err(serde::de::Error::invalid_length(
+ len,
+ &"fewer elements in array",
+ ))
+ }
+ }
+ _ => Err(serde::de::Error::custom("invalid value type")),
+ }
+ }
+
+ #[inline]
+ fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ self.deserialize_seq(visitor)
+ }
+
+ #[inline]
+ fn deserialize_tuple_struct<V>(
+ self,
+ _name: &'static str,
+ _len: usize,
+ visitor: V,
+ ) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ self.deserialize_seq(visitor)
+ }
+
+ forward_to_deserialize_any! {
+ bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
+ byte_buf unit unit_struct newtype_struct
+ map struct identifier ignored_any
+ }
+}
+
+struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>);
+
+impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> {
+ type Error = Error;
+
+ fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
+ where
+ T: serde::de::DeserializeSeed<'de>,
+ {
+ match self.0.next() {
+ Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some),
+ None => Ok(None),
+ }
+ }
+
+ fn size_hint(&self) -> Option<usize> {
+ match self.0.size_hint() {
+ (lower, Some(upper)) if lower == upper => Some(upper),
+ _ => None,
+ }
+ }
+}
+
+struct MapDeserializer<'lua>(
+ TablePairs<'lua, Value<'lua>, Value<'lua>>,
+ Option<Value<'lua>>,
+);
+
+impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> {
+ type Error = Error;
+
+ fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
+ where
+ T: serde::de::DeserializeSeed<'de>,
+ {
+ match self.0.next() {
+ Some(item) => {
+ let (key, value) = item?;
+ self.1 = Some(value);
+ let key_de = Deserializer { value: key };
+ seed.deserialize(key_de).map(Some)
+ }
+ None => Ok(None),
+ }
+ }
+
+ fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value>
+ where
+ T: serde::de::DeserializeSeed<'de>,
+ {
+ match self.1.take() {
+ Some(value) => seed.deserialize(Deserializer { value }),
+ None => Err(serde::de::Error::custom("value is missing")),
+ }
+ }
+
+ fn size_hint(&self) -> Option<usize> {
+ match self.0.size_hint() {
+ (lower, Some(upper)) if lower == upper => Some(upper),
+ _ => None,
+ }
+ }
+}
+
+struct EnumDeserializer<'lua> {
+ variant: String,
+ value: Option<Value<'lua>>,
+}
+
+impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> {
+ type Error = Error;
+ type Variant = VariantDeserializer<'lua>;
+
+ fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)>
+ where
+ T: serde::de::DeserializeSeed<'de>,
+ {
+ let variant = self.variant.into_deserializer();
+ let variant_access = VariantDeserializer { value: self.value };
+ seed.deserialize(variant).map(|v| (v, variant_access))
+ }
+}
+
+struct VariantDeserializer<'lua> {
+ value: Option<Value<'lua>>,
+}
+
+impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> {
+ type Error = Error;
+
+ fn unit_variant(self) -> Result<()> {
+ match self.value {
+ Some(_) => Err(serde::de::Error::invalid_type(
+ serde::de::Unexpected::NewtypeVariant,
+ &"unit variant",
+ )),
+ None => Ok(()),
+ }
+ }
+
+ fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
+ where
+ T: serde::de::DeserializeSeed<'de>,
+ {
+ match self.value {
+ Some(value) => seed.deserialize(Deserializer { value }),
+ None => Err(serde::de::Error::invalid_type(
+ serde::de::Unexpected::UnitVariant,
+ &"newtype variant",
+ )),
+ }
+ }
+
+ fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ match self.value {
+ Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor),
+ None => Err(serde::de::Error::invalid_type(
+ serde::de::Unexpected::UnitVariant,
+ &"tuple variant",
+ )),
+ }
+ }
+
+ fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
+ where
+ V: serde::de::Visitor<'de>,
+ {
+ match self.value {
+ Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor),
+ None => Err(serde::de::Error::invalid_type(
+ serde::de::Unexpected::UnitVariant,
+ &"struct variant",
+ )),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use rlua::Lua;
+
+ use super::super::from_value;
+ use serde::Deserialize;
+
+ #[test]
+ fn test_struct() {
+ #[derive(Deserialize, PartialEq, Debug)]
+ struct Test {
+ int: u32,
+ seq: Vec<String>,
+ map: std::collections::HashMap<i32, i32>,
+ empty: Vec<()>,
+ }
+
+ let expected = Test {
+ int: 1,
+ seq: vec!["a".to_owned(), "b".to_owned()],
+ map: vec![(1, 2), (4, 1)].into_iter().collect(),
+ empty: vec![],
+ };
+
+ println!("{:?}", expected);
+ let lua = Lua::new();
+ lua.context(|lua| {
+ let value = lua
+ .load(
+ r#"
+ a = {}
+ a.int = 1
+ a.seq = {"a", "b"}
+ a.map = {2, [4]=1}
+ a.empty = {}
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+ });
+ }
+
+ #[test]
+ fn test_tuple() {
+ #[derive(Deserialize, PartialEq, Debug)]
+ struct Rgb(u8, u8, u8);
+
+ let lua = Lua::new();
+ lua.context(|lua| {
+ let expected = Rgb(1, 2, 3);
+ let value = lua
+ .load(
+ r#"
+ a = {1, 2, 3}
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+
+ let expected = (1, 2, 3);
+ let value = lua
+ .load(
+ r#"
+ a = {1, 2, 3}
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+ });
+ }
+
+ #[test]
+ fn test_enum() {
+ #[derive(Deserialize, PartialEq, Debug)]
+ enum E {
+ Unit,
+ Newtype(u32),
+ Tuple(u32, u32),
+ Struct { a: u32 },
+ }
+
+ let lua = Lua::new();
+ lua.context(|lua| {
+ let expected = E::Unit;
+ let value = lua
+ .load(
+ r#"
+ return "Unit"
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+
+ let expected = E::Newtype(1);
+ let value = lua
+ .load(
+ r#"
+ a = {}
+ a["Newtype"] = 1
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+
+ let expected = E::Tuple(1, 2);
+ let value = lua
+ .load(
+ r#"
+ a = {}
+ a["Tuple"] = {1, 2}
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+
+ let expected = E::Struct { a: 1 };
+ let value = lua
+ .load(
+ r#"
+ a = {}
+ a["Struct"] = {}
+ a["Struct"]["a"] = 1
+ return a
+ "#,
+ )
+ .eval()
+ .unwrap();
+ let got = from_value(value).unwrap();
+ assert_eq!(expected, got);
+ });
+ }
+}