// Lua scripting support: builds a restricted Lua state, exposes a small
// `gitpad` API table to scripts, and calls functions of Lua modules that are
// looked up under `modules/` via the request `Context`.
//
// NOTE(review): several generic argument lists in this file look like they were
// lost before this text was captured (`rlua::Result` without a type parameter,
// `.eval::()`, `impl Iterator` without an `Item`, a bare `Result` return type).
// The code below is left token-for-token as found — confirm against the
// canonical source before building.

use std::fmt::Display;
use std::path::Path;
use std::str::from_utf8;

use percent_encoding::utf8_percent_encode;
use rlua::Function;
use rlua::HookTriggers;
use rlua::Lua;
use rlua::StdLib;
use rlua::Table;

use crate::Context;

use self::template::ArgIndex;

pub mod template;

/// Errors that can occur while locating or running a Lua module (see [`call`]).
pub enum ScriptError {
    /// Resolving the module's path in the tree of the branch head failed.
    ModuleNotFound,
    /// The module's blob content is not valid UTF-8.
    ModuleNotUtf8,
    /// An error reported by `rlua` while loading or executing Lua code.
    LuaError(rlua::Error),
}

/// Error returned by the instruction hook installed in [`lua_context`].
#[derive(Debug)]
struct TimeOutError;

impl Display for TimeOutError {
    /// Writes the fixed message `execution took too long`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("execution took too long")
    }
}

// Implemented so that `TimeOutError` can be wrapped by `rlua::Error::external`
// in `lua_context`.
impl std::error::Error for TimeOutError {}

/// Creates a fresh, restricted Lua state, installs the `gitpad` API table as a
/// global, and runs `run` inside that state's context.
///
/// Returns whatever `run` returns.
///
/// # Errors
///
/// Propagates the `rlua` error if creating or registering the
/// `encode_uri_component` function fails, and any error returned by `run`.
///
/// # Panics
///
/// Panics if evaluating the embedded `static/api.lua` fails, or if storing the
/// API table in the globals fails.
///
/// NOTE(review): the signature appears to be missing its generic parameter
/// (`rlua::Result` is normally `rlua::Result<T>`) — presumably stripped during
/// extraction; confirm against the original.
pub fn lua_context(run: impl FnOnce(rlua::Context) -> rlua::Result) -> rlua::Result {
    // Standard libraries minus debug, io, os and package.
    let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE);
    // The hook is triggered on an instruction count of 10_000 and its callback
    // unconditionally returns `TimeOutError`; going by the error's name and
    // message this acts as an execution budget for the script.
    lua.set_hook(
        HookTriggers {
            every_nth_instruction: Some(10_000),
            ..Default::default()
        },
        |_ctx, _debug| Err(rlua::Error::external(TimeOutError)),
    );
    lua.context(|ctx| {
        // `api.lua` is embedded at compile time; its evaluation result is used
        // as the API table.
        // NOTE(review): `.eval::()` looks like a stripped turbofish — TODO confirm.
        let api: Table = ctx
            .load(include_str!("static/api.lua"))
            .eval::()
            .expect("error in api.lua");
        // Rust-backed helper: percent-encodes the given string using the
        // `NON_ALPHANUMERIC` set.
        api.set(
            "encode_uri_component",
            ctx.create_function(|_ctx, text: String| {
                Ok(utf8_percent_encode(&text, percent_encoding::NON_ALPHANUMERIC).to_string())
            })?,
        )?;
        // Scripts reach the API through the global named `gitpad`.
        ctx.globals().raw_set("gitpad", api).unwrap();
        run(ctx)
    })
}

/// Returns the path of the Lua module `module_name`: `modules/<module_name>.lua`.
fn module_path(module_name: &str) -> String {
    format!("modules/{}.lua", module_name)
}

/// Identifies a function inside a Lua module.
pub struct ModuleFunction<'a> {
    /// Module name; mapped to a file path by `module_path`.
    pub module_name: &'a str,
    /// Key of the function in the table that the module evaluates to.
    pub function_name: &'a str,
}

/// Loads the module named by `modfn` via `ctx`, and calls the requested
/// function with a single Lua table built from `args`.
///
/// `args` yields `(index, value)` pairs; each value is stored in the argument
/// table under a string key (`ArgIndex::Str`) or a numeric key (`ArgIndex::Num`).
///
/// # Errors
///
/// * `ScriptError::ModuleNotFound` — getting the tree of the branch head or
///   resolving the module path in it failed.
/// * `ScriptError::ModuleNotUtf8` — the module's content is not valid UTF-8.
/// * `ScriptError::LuaError` — any error returned from [`lua_context`]
///   (loading the module, looking up the function, building the argument
///   table, or the call itself).
///
/// # Panics
///
/// Panics if `ctx.branch_head()` or `ctx.repo.find_blob(..)` yield no value
/// (both are unwrapped).
///
/// NOTE(review): the iterator's `Item` type and the `Result` type parameters
/// are not visible here — presumably stripped during extraction; confirm
/// against the original.
pub fn call(
    modfn: &ModuleFunction,
    args: impl Iterator,
    ctx: &Context,
) -> Result {
    let filename = module_path(modfn.module_name);
    // Any failure of `tree()` or `get_path()` is reported as "module not found".
    let lua_entr = ctx
        .branch_head()
        .unwrap()
        .tree()
        .and_then(|tree| tree.get_path(Path::new(&filename)))
        .map_err(|_| ScriptError::ModuleNotFound)?;
    let lua_blob = ctx.repo.find_blob(lua_entr.id()).unwrap();
    let lua_code = from_utf8(lua_blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?;
    // Inside the closure `ctx` is the `rlua::Context`, shadowing the outer
    // `&Context` parameter.
    lua_context(|ctx| {
        // The module source must evaluate to a table containing the function.
        let module: Table = ctx.load(lua_code).eval()?;
        let view: Function = module.get(modfn.function_name)?;
        let lua_args = ctx.create_table()?;
        for (idx, val) in args {
            match idx {
                ArgIndex::Str(s) => lua_args.set(s, val)?,
                ArgIndex::Num(n) => lua_args.set(n, val)?,
            }
        }
        view.call(lua_args)
    })
    .map_err(ScriptError::LuaError)
}

impl Display for ScriptError {
    /// Writes a short description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ScriptError::ModuleNotFound => write!(f, "module not found"),
            ScriptError::ModuleNotUtf8 => write!(f, "module not valid UTF-8"),
            // For callback errors only the underlying `cause` is displayed;
            // the remaining fields of the variant are ignored.
            ScriptError::LuaError(rlua::Error::CallbackError { cause, .. }) => {
                write!(f, "{}", cause)
            }
            ScriptError::LuaError(err) => write!(f, "{}", err),
        }
    }
}