diff options
Diffstat (limited to 'src/lua.rs')
-rw-r--r-- | src/lua.rs | 88 |
1 files changed, 45 insertions, 43 deletions
@@ -1,5 +1,4 @@ use std::fmt::Display; -use std::ops::Index; use std::path::Path; use std::str::from_utf8; use std::sync::Arc; @@ -7,7 +6,6 @@ use std::sync::Arc; use rlua::Function; use rlua::HookTriggers; use rlua::Lua; -use rlua::MetaMethod; use rlua::StdLib; use rlua::Table; @@ -49,6 +47,46 @@ impl Display for TimeOutError { impl std::error::Error for TimeOutError {} +pub fn lua_context<A>(run: impl FnOnce(rlua::Context) -> rlua::Result<A>) -> rlua::Result<A> { + let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE); + lua.set_hook( + HookTriggers { + every_nth_instruction: Some(10_000), + ..Default::default() + }, + |_ctx, _debug| Err(rlua::Error::external(TimeOutError)), + ); + + lua.context(|ctx| { + ctx.globals() + .raw_set( + "gitpad", + ctx.load(include_str!("static/api.lua")) + .eval::<Table>() + .expect("error in api.lua"), + ) + .unwrap(); + + ctx.globals() + .get::<_, Table>("gitpad") + .unwrap() + .set( + "decode_toml", + ctx.create_function(|_ctx, text: String| { + toml::from_str::<toml::Value>(&text) + .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) + .and_then(|v| { + serde::to_value(_ctx, v) + .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) + }) + })?, + ) + .unwrap(); + + run(ctx) + }) +} + impl<'a> Script<'a> { pub fn module_path(&self) -> String { format!("bin/{}.lua", self.lua_module_name) @@ -66,48 +104,12 @@ impl<'a> Script<'a> { let lua_blob = ctx.repo.find_blob(lua_entr.id()).unwrap(); let lua_code = from_utf8(lua_blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?; - - let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE); - lua.set_hook( - HookTriggers { - every_nth_instruction: Some(10_000), - ..Default::default() - }, - |_ctx, _debug| Err(rlua::Error::external(TimeOutError)), - ); - lua.context(|ctx| { - ctx.globals() - .raw_set( - "gitpad", - ctx.load(include_str!("static/api.lua")) - .eval::<Table>() - .expect("error in api.lua"), - ) - .unwrap(); - - ctx.globals() - .get::<_, Table>("gitpad") - .unwrap() - .set( - "decode_toml", - ctx.create_function(|_ctx, text: String| { - toml::from_str::<toml::Value>(&text) - .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) - .and_then(|v| { - serde::to_value(_ctx, v) - .map_err(|e| rlua::Error::ExternalError(Arc::new(e))) - }) - }) - .map_err(ScriptError::LuaError)?, - ) - .unwrap(); - - let module: Table = ctx.load(lua_code).eval().map_err(ScriptError::LuaError)?; - let view: Function = module.get("view").map_err(ScriptError::LuaError)?; - - view.call::<_, String>(self.input) - .map_err(ScriptError::LuaError) + lua_context(|ctx| { + let module: Table = ctx.load(lua_code).eval()?; + let view: Function = module.get("view")?; + view.call(self.input) }) + .map_err(ScriptError::LuaError) } } |