diff options
author | Martin Fischer <martin@push-f.com> | 2022-08-13 03:56:55 +0200 |
---|---|---|
committer | Martin Fischer <martin@push-f.com> | 2022-08-14 00:33:23 +0200 |
commit | 36e98c1b135f07ef9a5eec046b8f0fd6d93534d4 (patch) | |
tree | da9c93ceb0e147695e1992c4d1678964af3be4bf /src/lua/serde/de.rs | |
parent | dc20e1df60c1e4e81d1e16e8f177a1c6956966b7 (diff) |
add gitpad.decode_toml lua method
We are vendoring the rlua_serde crate because it currently depends on
rlua 0.17, which is outdated and my attempts to contact the crate author
were bounced by Yandex for somehow looking like spam.
Diffstat (limited to 'src/lua/serde/de.rs')
-rw-r--r-- | src/lua/serde/de.rs | 436 |
1 files changed, 436 insertions, 0 deletions
diff --git a/src/lua/serde/de.rs b/src/lua/serde/de.rs new file mode 100644 index 0000000..4d78a83 --- /dev/null +++ b/src/lua/serde/de.rs @@ -0,0 +1,436 @@ +// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version +// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License + +use serde::de::IntoDeserializer; +use serde::{self, forward_to_deserialize_any}; + +use rlua::{TablePairs, TableSequence, Value}; + +use super::error::{Error, Result}; + +pub struct Deserializer<'lua> { + pub value: Value<'lua>, +} + +impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { + type Error = Error; + + #[inline] + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Nil => visitor.visit_unit(), + Value::Boolean(v) => visitor.visit_bool(v), + Value::Integer(v) => visitor.visit_i64(v), + Value::Number(v) => visitor.visit_f64(v), + Value::String(v) => visitor.visit_str(v.to_str()?), + Value::Table(v) => { + let len = v.len()? as usize; + let mut deserializer = MapDeserializer(v.pairs(), None); + let map = visitor.visit_map(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(map) + } else { + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) + } + } + _ => Err(serde::de::Error::custom("invalid value type")), + } + } + + #[inline] + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Nil => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + #[inline] + fn deserialize_enum<V>( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + let (variant, value) = match self.value { + Value::Table(value) => { + let mut iter = value.pairs::<String, Value>(); + let (variant, value) = match iter.next() { + Some(v) => v?, + None => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )) + } + }; + + if iter.next().is_some() { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )); + } + (variant, Some(value)) + } + Value::String(variant) => (variant.to_str()?.to_owned(), None), + _ => return Err(serde::de::Error::custom("bad enum value")), + }; + + visitor.visit_enum(EnumDeserializer { variant, value }) + } + + #[inline] + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Table(v) => { + let len = v.len()? as usize; + let mut deserializer = SeqDeserializer(v.sequence_values()); + let seq = visitor.visit_seq(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(seq) + } else { + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) + } + } + _ => Err(serde::de::Error::custom("invalid value type")), + } + } + + #[inline] + fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + #[inline] + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf unit unit_struct newtype_struct + map struct identifier ignored_any + } +} + +struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>); + +impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> { + type Error = Error; + + fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.0.next() { + Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some), + None => Ok(None), + } + } + + fn size_hint(&self) -> Option<usize> { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct MapDeserializer<'lua>( + TablePairs<'lua, Value<'lua>, Value<'lua>>, + Option<Value<'lua>>, +); + +impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { + type Error = Error; + + fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.0.next() { + Some(item) => { + let (key, value) = item?; + self.1 = Some(value); + let key_de = Deserializer { value: key }; + seed.deserialize(key_de).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.1.take() { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::custom("value is missing")), + } + } + + fn size_hint(&self) -> Option<usize> { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct EnumDeserializer<'lua> { + variant: String, + value: Option<Value<'lua>>, +} + +impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> { + type Error = Error; + type Variant = VariantDeserializer<'lua>; + + fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)> + where + T: serde::de::DeserializeSeed<'de>, + { + let variant = self.variant.into_deserializer(); + let variant_access = VariantDeserializer { value: self.value }; + seed.deserialize(variant).map(|v| (v, variant_access)) + } +} + +struct VariantDeserializer<'lua> { + value: Option<Value<'lua>>, +} + +impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + match self.value { + Some(_) => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::NewtypeVariant, + &"unit variant", + )), + None => Ok(()), + } + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.value { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"newtype variant", + )), + } + } + + fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"tuple variant", + )), + } + } + + fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"struct variant", + )), + } + } +} + +#[cfg(test)] +mod tests { + use rlua::Lua; + + use super::super::from_value; + use serde::Deserialize; + + #[test] + fn test_struct() { + #[derive(Deserialize, PartialEq, Debug)] + struct Test { + int: u32, + seq: Vec<String>, + map: std::collections::HashMap<i32, i32>, + empty: Vec<()>, + } + + let expected = Test { + int: 1, + seq: vec!["a".to_owned(), "b".to_owned()], + map: vec![(1, 2), (4, 1)].into_iter().collect(), + empty: vec![], + }; + + println!("{:?}", expected); + let lua = Lua::new(); + lua.context(|lua| { + let value = lua + .load( + r#" + a = {} + a.int = 1 + a.seq = {"a", "b"} + a.map = {2, [4]=1} + a.empty = {} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } + + #[test] + fn test_tuple() { + #[derive(Deserialize, PartialEq, Debug)] + struct Rgb(u8, u8, u8); + + let lua = Lua::new(); + lua.context(|lua| { + let expected = Rgb(1, 2, 3); + let value = lua + .load( + r#" + a = {1, 2, 3} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = (1, 2, 3); + let value = lua + .load( + r#" + a = {1, 2, 3} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } + + #[test] + fn test_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let lua = Lua::new(); + lua.context(|lua| { + let expected = E::Unit; + let value = lua + .load( + r#" + return "Unit" + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Newtype(1); + let value = lua + .load( + r#" + a = {} + a["Newtype"] = 1 + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Tuple(1, 2); + let value = lua + .load( + r#" + a = {} + a["Tuple"] = {1, 2} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Struct { a: 1 }; + let value = lua + .load( + r#" + a = {} + a["Struct"] = {} + a["Struct"]["a"] = 1 + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } +} |