use std::fmt::Display;
use std::path::Path;
use std::str::from_utf8;
use std::sync::Arc;
use rlua::Function;
use rlua::HookTriggers;
use rlua::Lua;
use rlua::StdLib;
use rlua::Table;
use crate::Context;
use self::template::ArgIndex;
mod serde;
pub mod template;
/// Errors that can occur while loading or running a Lua module.
#[derive(Debug)]
pub enum ScriptError {
    /// The requested module could not be located.
    ModuleNotFound,
    /// The module source exists but is not valid UTF-8.
    ModuleNotUtf8,
    /// An error raised by the Lua runtime while evaluating or calling the module.
    LuaError(rlua::Error),
}
/// Error returned from the Lua execution hook to abort a script that ran too long.
#[derive(Debug)]
struct TimeOutError;

impl std::error::Error for TimeOutError {}

impl Display for TimeOutError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "execution took too long")
    }
}
pub fn lua_context(run: impl FnOnce(rlua::Context) -> rlua::Result) -> rlua::Result {
let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE);
lua.set_hook(
HookTriggers {
every_nth_instruction: Some(10_000),
..Default::default()
},
|_ctx, _debug| Err(rlua::Error::external(TimeOutError)),
);
lua.context(|ctx| {
let api: Table = ctx
.load(include_str!("static/api.lua"))
.eval::
()
.expect("error in api.lua");
api.set(
"decode_toml",
ctx.create_function(|_ctx, text: String| {
toml::from_str::(&text)
.map_err(|e| rlua::Error::ExternalError(Arc::new(e)))
.and_then(|v| {
serde::to_value(_ctx, v)
.map_err(|e| rlua::Error::ExternalError(Arc::new(e)))
})
})?,
)?;
ctx.globals().raw_set("gitpad", api).unwrap();
run(ctx)
})
}
/// Builds the repository-relative path of the Lua source for `module_name`.
fn module_path(module_name: &str) -> String {
    ["modules/", module_name, ".lua"].concat()
}
/// Identifies a function exported by a Lua module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModuleFunction<'a> {
    /// Module name; `call` loads it from `modules/<module_name>.lua`.
    pub module_name: &'a str,
    /// Key of the function in the table the module evaluates to.
    pub function_name: &'a str,
}
pub fn call(
modfn: &ModuleFunction,
args: impl Iterator- ,
ctx: &Context,
) -> Result {
let filename = module_path(modfn.module_name);
let lua_entr = ctx
.branch_head()
.unwrap()
.tree()
.and_then(|tree| tree.get_path(Path::new(&filename)))
.map_err(|_| ScriptError::ModuleNotFound)?;
let lua_blob = ctx.repo.find_blob(lua_entr.id()).unwrap();
let lua_code = from_utf8(lua_blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?;
lua_context(|ctx| {
let module: Table = ctx.load(lua_code).eval()?;
let view: Function = module.get(modfn.function_name)?;
let lua_args = ctx.create_table()?;
for (idx, val) in args {
match idx {
ArgIndex::Str(s) => lua_args.set(s, val)?,
ArgIndex::Num(n) => lua_args.set(n, val)?,
}
}
view.call(lua_args)
})
.map_err(ScriptError::LuaError)
}
impl Display for ScriptError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ScriptError::ModuleNotFound => write!(f, "module not found"),
ScriptError::ModuleNotUtf8 => write!(f, "module not valid UTF-8"),
ScriptError::LuaError(rlua::Error::CallbackError { cause, .. }) => {
write!(f, "{}", cause)
}
ScriptError::LuaError(err) => write!(f, "{}", err),
}
}
}