diff options
Diffstat (limited to 'src/lua.rs')
-rw-r--r-- | src/lua.rs | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/src/lua.rs b/src/lua.rs new file mode 100644 index 0000000..d9f1511 --- /dev/null +++ b/src/lua.rs @@ -0,0 +1,103 @@ +use std::fmt::Display; +use std::path::Path; +use std::str::from_utf8; + +use rlua::Function; +use rlua::HookTriggers; +use rlua::Lua; +use rlua::StdLib; +use rlua::Table; + +use crate::Context; + +pub struct Script<'a> { + pub lua_module_name: &'a str, + input: &'a str, +} + +pub fn parse_shebang(text: &str) -> Option<Script> { + if let Some(rest) = text.strip_prefix("#!") { + if let Some((lua_module_name, input)) = rest.split_once('\n') { + return Some(Script { + lua_module_name, + input, + }); + } + } + None +} + +pub enum ScriptError { + ModuleNotFound, + ModuleNotUtf8, + LuaError(rlua::Error), +} + +#[derive(Debug)] +struct TimeOutError; + +impl Display for TimeOutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("execution took too long") + } +} + +impl std::error::Error for TimeOutError {} + +impl<'a> Script<'a> { + pub fn module_path(&self) -> String { + format!("bin/{}.lua", self.lua_module_name) + } + + pub fn run(&self, ctx: &Context) -> Result<String, ScriptError> { + let filename = self.module_path(); + + let lua_entr = ctx + .branch_head() + .unwrap() + .tree() + .and_then(|tree| tree.get_path(Path::new(&filename))) + .map_err(|_| ScriptError::ModuleNotFound)?; + + let lua_blob = ctx.repo.find_blob(lua_entr.id()).unwrap(); + let lua_code = from_utf8(lua_blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?; + + let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE); + lua.set_hook( + HookTriggers { + every_nth_instruction: Some(10_000), + ..Default::default() + }, + |_ctx, _debug| Err(rlua::Error::external(TimeOutError)), + ); + lua.context(|ctx| { + ctx.globals() + .raw_set( + "gitpad", + ctx.load(include_str!("static/api.lua")) + .eval::<Table>() + .expect("error in api.lua"), + ) + .unwrap(); + + let module: Table = ctx.load(lua_code).eval().map_err(ScriptError::LuaError)?; + let view: Function = module.get("view").map_err(ScriptError::LuaError)?; + + view.call::<_, String>(self.input) + .map_err(ScriptError::LuaError) + }) + } +} + +impl Display for ScriptError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ScriptError::ModuleNotFound => write!(f, "module not found"), + ScriptError::ModuleNotUtf8 => write!(f, "module not valid UTF-8"), + ScriptError::LuaError(rlua::Error::CallbackError { cause, .. }) => { + write!(f, "{}", cause) + } + ScriptError::LuaError(err) => write!(f, "{}", err), + } + } +} |