aboutsummaryrefslogtreecommitdiff
path: root/src/lua.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua.rs')
-rw-r--r--src/lua.rs103
1 files changed, 103 insertions, 0 deletions
diff --git a/src/lua.rs b/src/lua.rs
new file mode 100644
index 0000000..d9f1511
--- /dev/null
+++ b/src/lua.rs
@@ -0,0 +1,103 @@
+use std::fmt::Display;
+use std::path::Path;
+use std::str::from_utf8;
+
+use rlua::Function;
+use rlua::HookTriggers;
+use rlua::Lua;
+use rlua::StdLib;
+use rlua::Table;
+
+use crate::Context;
+
+pub struct Script<'a> {
+ pub lua_module_name: &'a str,
+ input: &'a str,
+}
+
+pub fn parse_shebang(text: &str) -> Option<Script> {
+ if let Some(rest) = text.strip_prefix("#!") {
+ if let Some((lua_module_name, input)) = rest.split_once('\n') {
+ return Some(Script {
+ lua_module_name,
+ input,
+ });
+ }
+ }
+ None
+}
+
+pub enum ScriptError {
+ ModuleNotFound,
+ ModuleNotUtf8,
+ LuaError(rlua::Error),
+}
+
+#[derive(Debug)]
+struct TimeOutError;
+
+impl Display for TimeOutError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.write_str("execution took too long")
+ }
+}
+
+impl std::error::Error for TimeOutError {}
+
+impl<'a> Script<'a> {
+ pub fn module_path(&self) -> String {
+ format!("bin/{}.lua", self.lua_module_name)
+ }
+
+ pub fn run(&self, ctx: &Context) -> Result<String, ScriptError> {
+ let filename = self.module_path();
+
+ let lua_entr = ctx
+ .branch_head()
+ .unwrap()
+ .tree()
+ .and_then(|tree| tree.get_path(Path::new(&filename)))
+ .map_err(|_| ScriptError::ModuleNotFound)?;
+
+ let lua_blob = ctx.repo.find_blob(lua_entr.id()).unwrap();
+ let lua_code = from_utf8(lua_blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?;
+
+ let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE);
+ lua.set_hook(
+ HookTriggers {
+ every_nth_instruction: Some(10_000),
+ ..Default::default()
+ },
+ |_ctx, _debug| Err(rlua::Error::external(TimeOutError)),
+ );
+ lua.context(|ctx| {
+ ctx.globals()
+ .raw_set(
+ "gitpad",
+ ctx.load(include_str!("static/api.lua"))
+ .eval::<Table>()
+ .expect("error in api.lua"),
+ )
+ .unwrap();
+
+ let module: Table = ctx.load(lua_code).eval().map_err(ScriptError::LuaError)?;
+ let view: Function = module.get("view").map_err(ScriptError::LuaError)?;
+
+ view.call::<_, String>(self.input)
+ .map_err(ScriptError::LuaError)
+ })
+ }
+}
+
+impl Display for ScriptError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ScriptError::ModuleNotFound => write!(f, "module not found"),
+ ScriptError::ModuleNotUtf8 => write!(f, "module not valid UTF-8"),
+ ScriptError::LuaError(rlua::Error::CallbackError { cause, .. }) => {
+ write!(f, "{}", cause)
+ }
+ ScriptError::LuaError(err) => write!(f, "{}", err),
+ }
+ }
+}