diff options
author | Martin Fischer <martin@push-f.com> | 2022-08-13 03:56:55 +0200 |
---|---|---|
committer | Martin Fischer <martin@push-f.com> | 2022-08-14 00:33:23 +0200 |
commit | 36e98c1b135f07ef9a5eec046b8f0fd6d93534d4 (patch) | |
tree | da9c93ceb0e147695e1992c4d1678964af3be4bf | |
parent | dc20e1df60c1e4e81d1e16e8f177a1c6956966b7 (diff) |
add gitpad.decode_toml lua method
We are vendoring the rlua_serde crate because it currently depends on
rlua 0.17, which is outdated and my attempts to contact the crate author
were bounced by Yandex for somehow looking like spam.
-rw-r--r-- | Cargo.lock | 1 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/lua.rs | 22 | ||||
-rw-r--r-- | src/lua/serde/de.rs | 436 | ||||
-rw-r--r-- | src/lua/serde/error.rs | 58 | ||||
-rw-r--r-- | src/lua/serde/mod.rs | 57 | ||||
-rw-r--r-- | src/lua/serde/ser.rs | 481 |
7 files changed, 1056 insertions, 0 deletions
@@ -236,6 +236,7 @@ dependencies = [ "pulldown-cmark", "rlua", "serde", + "serde_derive", "sputnik", "tempdir", "tokio", @@ -45,3 +45,4 @@ hyperlocal = { version = "0.8", features = ["server"] } [dev-dependencies] tempdir = "0.3.7" +serde_derive = "1.0" @@ -1,15 +1,20 @@ use std::fmt::Display; +use std::ops::Index; use std::path::Path; use std::str::from_utf8; +use std::sync::Arc; use rlua::Function; use rlua::HookTriggers; use rlua::Lua; +use rlua::MetaMethod; use rlua::StdLib; use rlua::Table; use crate::Context; +mod serde; + pub struct Script<'a> { pub lua_module_name: &'a str, input: &'a str, @@ -80,6 +85,23 @@ impl<'a> Script<'a> { ) .unwrap(); + ctx.globals() + .get::<_, Table>("gitpad") + .unwrap() + .set( + "decode_toml", + ctx.create_function(|_ctx, text: String| { + toml::from_str::<toml::Value>(&text) + .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) + .and_then(|v| { + serde::to_value(_ctx, v) + .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) + }) + }) + .map_err(ScriptError::LuaError)?, + ) + .unwrap(); + let module: Table = ctx.load(lua_code).eval().map_err(ScriptError::LuaError)?; let view: Function = module.get("view").map_err(ScriptError::LuaError)?; diff --git a/src/lua/serde/de.rs b/src/lua/serde/de.rs new file mode 100644 index 0000000..4d78a83 --- /dev/null +++ b/src/lua/serde/de.rs @@ -0,0 +1,436 @@ +// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version +// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License + +use serde::de::IntoDeserializer; +use serde::{self, forward_to_deserialize_any}; + +use rlua::{TablePairs, TableSequence, Value}; + +use super::error::{Error, Result}; + +pub struct Deserializer<'lua> { + pub value: Value<'lua>, +} + +impl<'lua, 'de> serde::Deserializer<'de> for Deserializer<'lua> { + type Error = Error; + + #[inline] + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Nil => visitor.visit_unit(), + Value::Boolean(v) => visitor.visit_bool(v), + Value::Integer(v) => visitor.visit_i64(v), + Value::Number(v) => visitor.visit_f64(v), + Value::String(v) => visitor.visit_str(v.to_str()?), + Value::Table(v) => { + let len = v.len()? as usize; + let mut deserializer = MapDeserializer(v.pairs(), None); + let map = visitor.visit_map(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(map) + } else { + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) + } + } + _ => Err(serde::de::Error::custom("invalid value type")), + } + } + + #[inline] + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Nil => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + #[inline] + fn deserialize_enum<V>( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + let (variant, value) = match self.value { + Value::Table(value) => { + let mut iter = value.pairs::<String, Value>(); + let (variant, value) = match iter.next() { + Some(v) => v?, + None => { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )) + } + }; + + if iter.next().is_some() { + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Map, + &"map with a single key", + )); + } + (variant, Some(value)) + } + Value::String(variant) => (variant.to_str()?.to_owned(), None), + _ => return Err(serde::de::Error::custom("bad enum value")), + }; + + visitor.visit_enum(EnumDeserializer { variant, value }) + } + + #[inline] + fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Value::Table(v) => { + let len = v.len()? as usize; + let mut deserializer = SeqDeserializer(v.sequence_values()); + let seq = visitor.visit_seq(&mut deserializer)?; + let remaining = deserializer.0.count(); + if remaining == 0 { + Ok(seq) + } else { + Err(serde::de::Error::invalid_length( + len, + &"fewer elements in array", + )) + } + } + _ => Err(serde::de::Error::custom("invalid value type")), + } + } + + #[inline] + fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + #[inline] + fn deserialize_tuple_struct<V>( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + forward_to_deserialize_any! { + bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes + byte_buf unit unit_struct newtype_struct + map struct identifier ignored_any + } +} + +struct SeqDeserializer<'lua>(TableSequence<'lua, Value<'lua>>); + +impl<'lua, 'de> serde::de::SeqAccess<'de> for SeqDeserializer<'lua> { + type Error = Error; + + fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.0.next() { + Some(value) => seed.deserialize(Deserializer { value: value? }).map(Some), + None => Ok(None), + } + } + + fn size_hint(&self) -> Option<usize> { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct MapDeserializer<'lua>( + TablePairs<'lua, Value<'lua>, Value<'lua>>, + Option<Value<'lua>>, +); + +impl<'lua, 'de> serde::de::MapAccess<'de> for MapDeserializer<'lua> { + type Error = Error; + + fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.0.next() { + Some(item) => { + let (key, value) = item?; + self.1 = Some(value); + let key_de = Deserializer { value: key }; + seed.deserialize(key_de).map(Some) + } + None => Ok(None), + } + } + + fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.1.take() { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::custom("value is missing")), + } + } + + fn size_hint(&self) -> Option<usize> { + match self.0.size_hint() { + (lower, Some(upper)) if lower == upper => Some(upper), + _ => None, + } + } +} + +struct EnumDeserializer<'lua> { + variant: String, + value: Option<Value<'lua>>, +} + +impl<'lua, 'de> serde::de::EnumAccess<'de> for EnumDeserializer<'lua> { + type Error = Error; + type Variant = VariantDeserializer<'lua>; + + fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant)> + where + T: serde::de::DeserializeSeed<'de>, + { + let variant = self.variant.into_deserializer(); + let variant_access = VariantDeserializer { value: self.value }; + seed.deserialize(variant).map(|v| (v, variant_access)) + } +} + +struct VariantDeserializer<'lua> { + value: Option<Value<'lua>>, +} + +impl<'lua, 'de> serde::de::VariantAccess<'de> for VariantDeserializer<'lua> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + match self.value { + Some(_) => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::NewtypeVariant, + &"unit variant", + )), + None => Ok(()), + } + } + + fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.value { + Some(value) => seed.deserialize(Deserializer { value }), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"newtype variant", + )), + } + } + + fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Some(value) => serde::Deserializer::deserialize_seq(Deserializer { value }, visitor), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"tuple variant", + )), + } + } + + fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value> + where + V: serde::de::Visitor<'de>, + { + match self.value { + Some(value) => serde::Deserializer::deserialize_map(Deserializer { value }, visitor), + None => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::UnitVariant, + &"struct variant", + )), + } + } +} + +#[cfg(test)] +mod tests { + use rlua::Lua; + + use super::super::from_value; + use serde::Deserialize; + + #[test] + fn test_struct() { + #[derive(Deserialize, PartialEq, Debug)] + struct Test { + int: u32, + seq: Vec<String>, + map: std::collections::HashMap<i32, i32>, + empty: Vec<()>, + } + + let expected = Test { + int: 1, + seq: vec!["a".to_owned(), "b".to_owned()], + map: vec![(1, 2), (4, 1)].into_iter().collect(), + empty: vec![], + }; + + println!("{:?}", expected); + let lua = Lua::new(); + lua.context(|lua| { + let value = lua + .load( + r#" + a = {} + a.int = 1 + a.seq = {"a", "b"} + a.map = {2, [4]=1} + a.empty = {} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } + + #[test] + fn test_tuple() { + #[derive(Deserialize, PartialEq, Debug)] + struct Rgb(u8, u8, u8); + + let lua = Lua::new(); + lua.context(|lua| { + let expected = Rgb(1, 2, 3); + let value = lua + .load( + r#" + a = {1, 2, 3} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = (1, 2, 3); + let value = lua + .load( + r#" + a = {1, 2, 3} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } + + #[test] + fn test_enum() { + #[derive(Deserialize, PartialEq, Debug)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let lua = Lua::new(); + lua.context(|lua| { + let expected = E::Unit; + let value = lua + .load( + r#" + return "Unit" + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Newtype(1); + let value = lua + .load( + r#" + a = {} + a["Newtype"] = 1 + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Tuple(1, 2); + let value = lua + .load( + r#" + a = {} + a["Tuple"] = {1, 2} + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + + let expected = E::Struct { a: 1 }; + let value = lua + .load( + r#" + a = {} + a["Struct"] = {} + a["Struct"]["a"] = 1 + return a + "#, + ) + .eval() + .unwrap(); + let got = from_value(value).unwrap(); + assert_eq!(expected, got); + }); + } +} diff --git a/src/lua/serde/error.rs b/src/lua/serde/error.rs new file mode 100644 index 0000000..8d5debf --- /dev/null +++ b/src/lua/serde/error.rs @@ -0,0 +1,58 @@ +// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version +// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License + +use std::error::Error as StdError; +use std::fmt; +use std::result::Result as StdResult; + +use rlua::Error as LuaError; +use serde; + +#[derive(Debug)] +pub struct Error(LuaError); + +pub type Result<T> = StdResult<T, Error>; + +impl From<LuaError> for Error { + fn from(err: LuaError) -> Error { + Error(err) + } +} + +impl From<Error> for LuaError { + fn from(err: Error) -> LuaError { + err.0 + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(fmt) + } +} + +impl StdError for Error { + fn description(&self) -> &'static str { + "Failed to serialize to Lua value" + } +} + +impl serde::ser::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Error(LuaError::ToLuaConversionError { + from: "serialize", + to: "value", + message: Some(format!("{}", msg)), + }) + } +} + +impl serde::de::Error for Error { + fn custom<T: fmt::Display>(msg: T) -> Self { + Error(LuaError::FromLuaConversionError { + from: "value", + to: "deserialize", + message: Some(format!("{}", msg)), + }) + } +} diff --git a/src/lua/serde/mod.rs b/src/lua/serde/mod.rs new file mode 100644 index 0000000..82a928a --- /dev/null +++ b/src/lua/serde/mod.rs @@ -0,0 +1,57 @@ +// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version +// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License + +//! This crate allows you to serialize and deserialize any type that implements +//! `serde::Serialize` and `serde::Deserialize` into/from `rlua::Value`. +//! +//! Implementation is similar to `serde_json::Value` +//! +//! Example usage: +//! +//! ```rust +//! extern crate serde; +//! #[macro_use] +//! extern crate serde_derive; +//! extern crate rlua; +//! extern crate rlua_serde; +//! +//! fn main() { +//! #[derive(Serialize, Deserialize)] +//! struct Foo { +//! bar: u32, +//! baz: Vec<String>, +//! } +//! +//! let lua = rlua::Lua::new(); +//! lua.context(|lua| { +//! let foo = Foo { +//! bar: 42, +//! baz: vec![String::from("fizz"), String::from("buzz")], +//! }; +//! +//! let value = rlua_serde::to_value(lua, &foo).unwrap(); +//! lua.globals().set("value", value).unwrap(); +//! lua.load( +//! r#" +//! assert(value["bar"] == 42) +//! assert(value["baz"][2] == "buzz") +//! "#).exec().unwrap(); +//! }); +//! } +//! ``` + +pub mod de; +pub mod error; +pub mod ser; + +use rlua::{Context, Error, Value}; + +pub fn to_value<T: serde::Serialize>(lua: Context, t: T) -> Result<Value, Error> { + let serializer = ser::Serializer { lua }; + Ok(t.serialize(serializer)?) +} + +pub fn from_value<'de, T: serde::Deserialize<'de>>(value: Value<'de>) -> Result<T, Error> { + let deserializer = de::Deserializer { value }; + Ok(T::deserialize(deserializer)?) +} diff --git a/src/lua/serde/ser.rs b/src/lua/serde/ser.rs new file mode 100644 index 0000000..8af978a --- /dev/null +++ b/src/lua/serde/ser.rs @@ -0,0 +1,481 @@ +// vendored from https://github.com/zrkn/rlua_serde because it dependend on an outdated rlua version +// Copyright (c) 2018 zrkn <zrkn@email.su>, licensed under the MIT License + +use serde; + +use rlua::{Context, String as LuaString, Table, Value}; + +use super::error::{Error, Result}; +use super::to_value; + +pub struct Serializer<'lua> { + pub lua: Context<'lua>, +} + +impl<'lua> serde::Serializer for Serializer<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + type SerializeSeq = SerializeVec<'lua>; + type SerializeTuple = SerializeVec<'lua>; + type SerializeTupleStruct = SerializeVec<'lua>; + type SerializeTupleVariant = SerializeTupleVariant<'lua>; + type SerializeMap = SerializeMap<'lua>; + type SerializeStruct = SerializeMap<'lua>; + type SerializeStructVariant = SerializeStructVariant<'lua>; + + #[inline] + fn serialize_bool(self, value: bool) -> Result<Value<'lua>> { + Ok(Value::Boolean(value)) + } + + #[inline] + fn serialize_i8(self, value: i8) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_i16(self, value: i16) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_i32(self, value: i32) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_i64(self, value: i64) -> Result<Value<'lua>> { + Ok(Value::Integer(value)) + } + + #[inline] + fn serialize_u8(self, value: u8) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_u16(self, value: u16) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_u32(self, value: u32) -> Result<Value<'lua>> { + self.serialize_i64(i64::from(value)) + } + + #[inline] + fn serialize_u64(self, value: u64) -> Result<Value<'lua>> { + self.serialize_i64(value as i64) + } + + #[inline] + fn serialize_f32(self, value: f32) -> Result<Value<'lua>> { + self.serialize_f64(f64::from(value)) + } + + #[inline] + fn serialize_f64(self, value: f64) -> Result<Value<'lua>> { + Ok(Value::Number(value)) + } + + #[inline] + fn serialize_char(self, value: char) -> Result<Value<'lua>> { + let mut s = String::new(); + s.push(value); + self.serialize_str(&s) + } + + #[inline] + fn serialize_str(self, value: &str) -> Result<Value<'lua>> { + Ok(Value::String(self.lua.create_string(value)?)) + } + + #[inline] + fn serialize_bytes(self, value: &[u8]) -> Result<Value<'lua>> { + Ok(Value::Table( + self.lua.create_sequence_from(value.iter().cloned())?, + )) + } + + #[inline] + fn serialize_unit(self) -> Result<Value<'lua>> { + Ok(Value::Nil) + } + + #[inline] + fn serialize_unit_struct(self, _name: &'static str) -> Result<Value<'lua>> { + self.serialize_unit() + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result<Value<'lua>> { + self.serialize_str(variant) + } + + #[inline] + fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<Value<'lua>> + where + T: ?Sized + serde::Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant<T>( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<Value<'lua>> + where + T: ?Sized + serde::Serialize, + { + let table = self.lua.create_table()?; + let variant = self.lua.create_string(variant)?; + let value = to_value(self.lua, value)?; + table.set(variant, value)?; + Ok(Value::Table(table)) + } + + #[inline] + fn serialize_none(self) -> Result<Value<'lua>> { + self.serialize_unit() + } + + #[inline] + fn serialize_some<T>(self, value: &T) -> Result<Value<'lua>> + where + T: ?Sized + serde::Serialize, + { + value.serialize(self) + } + + fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> { + let table = self.lua.create_table()?; + Ok(SerializeVec { + lua: self.lua, + idx: 1, + table, + }) + } + + fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result<Self::SerializeTupleStruct> { + self.serialize_seq(Some(len)) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeTupleVariant> { + let name = self.lua.create_string(variant)?; + let table = self.lua.create_table()?; + Ok(SerializeTupleVariant { + lua: self.lua, + idx: 1, + name, + table, + }) + } + + fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> { + let table = self.lua.create_table()?; + Ok(SerializeMap { + lua: self.lua, + next_key: None, + table, + }) + } + + fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> { + self.serialize_map(Some(len)) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result<Self::SerializeStructVariant> { + let name = self.lua.create_string(variant)?; + let table = self.lua.create_table()?; + Ok(SerializeStructVariant { + lua: self.lua, + name, + table, + }) + } +} + +pub struct SerializeVec<'lua> { + lua: Context<'lua>, + table: Table<'lua>, + idx: u64, +} + +impl<'lua> serde::ser::SerializeSeq for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + self.table.set(self.idx, to_value(self.lua, value)?)?; + self.idx += 1; + Ok(()) + } + + fn end(self) -> Result<Value<'lua>> { + Ok(Value::Table(self.table)) + } +} + +impl<'lua> serde::ser::SerializeTuple for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_element<T>(&mut self, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result<Value<'lua>> { + serde::ser::SerializeSeq::end(self) + } +} + +impl<'lua> serde::ser::SerializeTupleStruct for SerializeVec<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + serde::ser::SerializeSeq::serialize_element(self, value) + } + + fn end(self) -> Result<Value<'lua>> { + serde::ser::SerializeSeq::end(self) + } +} + +pub struct SerializeTupleVariant<'lua> { + lua: Context<'lua>, + name: LuaString<'lua>, + table: Table<'lua>, + idx: u64, +} + +impl<'lua> serde::ser::SerializeTupleVariant for SerializeTupleVariant<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field<T>(&mut self, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + self.table.set(self.idx, to_value(self.lua, value)?)?; + self.idx += 1; + Ok(()) + } + + fn end(self) -> Result<Value<'lua>> { + let table = self.lua.create_table()?; + table.set(self.name, self.table)?; + Ok(Value::Table(table)) + } +} + +pub struct SerializeMap<'lua> { + lua: Context<'lua>, + table: Table<'lua>, + next_key: Option<Value<'lua>>, +} + +impl<'lua> serde::ser::SerializeMap for SerializeMap<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_key<T>(&mut self, key: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + self.next_key = Some(to_value(self.lua, key)?); + Ok(()) + } + + fn serialize_value<T>(&mut self, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + let key = self.next_key.take(); + // Panic because this indicates a bug in the program rather than an + // expected failure. + let key = key.expect("serialize_value called before serialize_key"); + self.table.set(key, to_value(self.lua, value)?)?; + Ok(()) + } + + fn end(self) -> Result<Value<'lua>> { + Ok(Value::Table(self.table)) + } +} + +impl<'lua> serde::ser::SerializeStruct for SerializeMap<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + serde::ser::SerializeMap::serialize_key(self, key)?; + serde::ser::SerializeMap::serialize_value(self, value) + } + + fn end(self) -> Result<Value<'lua>> { + serde::ser::SerializeMap::end(self) + } +} + +pub struct SerializeStructVariant<'lua> { + lua: Context<'lua>, + name: LuaString<'lua>, + table: Table<'lua>, +} + +impl<'lua> serde::ser::SerializeStructVariant for SerializeStructVariant<'lua> { + type Ok = Value<'lua>; + type Error = Error; + + fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + serde::Serialize, + { + self.table.set(key, to_value(self.lua, value)?)?; + Ok(()) + } + + fn end(self) -> Result<Value<'lua>> { + let table = self.lua.create_table()?; + table.set(self.name, self.table)?; + Ok(Value::Table(table)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rlua::Lua; + use serde::Serialize; + + #[test] + fn test_struct() { + #[derive(Serialize)] + struct Test { + int: u32, + seq: Vec<&'static str>, + } + + let test = Test { + int: 1, + seq: vec!["a", "b"], + }; + + let lua = Lua::new(); + lua.context(|lua| { + let value = to_value(lua, &test).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.load( + r#" + assert(value["int"] == 1) + assert(value["seq"][1] == "a") + assert(value["seq"][2] == "b") + "#, + ) + .exec() + }) + .unwrap() + } + + #[test] + fn test_num() { + #[derive(Serialize)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let lua = Lua::new(); + + lua.context(|lua| { + let u = E::Unit; + let value = to_value(lua, &u).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.load( + r#" + assert(value == "Unit") + "#, + ) + .exec() + .unwrap(); + + let n = E::Newtype(1); + let value = to_value(lua, &n).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.load( + r#" + assert(value["Newtype"] == 1) + "#, + ) + .exec() + .unwrap(); + + let t = E::Tuple(1, 2); + let value = to_value(lua, &t).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.load( + r#" + assert(value["Tuple"][1] == 1) + assert(value["Tuple"][2] == 2) + "#, + ) + .exec() + .unwrap(); + + let s = E::Struct { a: 1 }; + let value = to_value(lua, &s).unwrap(); + lua.globals().set("value", value).unwrap(); + lua.load( + r#" + assert(value["Struct"]["a"] == 1) + "#, + ) + .exec() + }) + .unwrap(); + } +} |