aboutsummaryrefslogtreecommitdiff
path: root/src/lua.rs
blob: d8d82b2af326920811c047a4fe95e8c73144c7d5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std::fmt::Display;
use std::path::Path;
use std::str::from_utf8;
use std::sync::Arc;

use rlua::Function;
use rlua::HookTriggers;
use rlua::Lua;
use rlua::StdLib;
use rlua::Table;

use crate::Context;

use self::template::ArgIndex;

mod serde;
pub mod template;

/// Errors returned by [`call`] when running a function from a Lua module.
#[derive(Debug)]
pub enum ScriptError {
    /// The module file (`modules/<name>.lua`) could not be located in the
    /// branch head's tree.
    ModuleNotFound,
    /// The module file exists but its contents are not valid UTF-8.
    ModuleNotUtf8,
    /// Evaluating the module, building the arguments, or calling the
    /// requested function failed inside Lua.
    LuaError(rlua::Error),
}

/// Error raised from the instruction-count hook installed by
/// [`lua_context`] to stop scripts that run for too many VM instructions.
#[derive(Debug)]
struct TimeOutError;

impl std::error::Error for TimeOutError {}

impl Display for TimeOutError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "execution took too long")
    }
}

pub fn lua_context<A>(run: impl FnOnce(rlua::Context) -> rlua::Result<A>) -> rlua::Result<A> {
    let lua = Lua::new_with(StdLib::ALL_NO_DEBUG - StdLib::IO - StdLib::OS - StdLib::PACKAGE);
    lua.set_hook(
        HookTriggers {
            every_nth_instruction: Some(10_000),
            ..Default::default()
        },
        |_ctx, _debug| Err(rlua::Error::external(TimeOutError)),
    );

    lua.context(|ctx| {
        let api: Table = ctx
            .load(include_str!("static/api.lua"))
            .eval::<Table>()
            .expect("error in api.lua");

        api.set(
            "decode_toml",
            ctx.create_function(|_ctx, text: String| {
                toml::from_str::<toml::Value>(&text)
                    .map_err(|e| rlua::Error::ExternalError(Arc::new(e)))
                    .and_then(|v| {
                        serde::to_value(_ctx, v)
                            .map_err(|e| rlua::Error::ExternalError(Arc::new(e)))
                    })
            })?,
        )?;

        ctx.globals().raw_set("gitpad", api).unwrap();

        run(ctx)
    })
}

/// Builds the repository path of the Lua source file for `module_name`,
/// i.e. `modules/<module_name>.lua`; [`call`] looks this path up in the
/// branch head's tree.
fn module_path(module_name: &str) -> String {
    let mut path = String::from("modules/");
    path.push_str(module_name);
    path.push_str(".lua");
    path
}

/// Identifies a function exported by a Lua module, as consumed by [`call`].
#[derive(Debug, Clone, Copy)]
pub struct ModuleFunction<'a> {
    /// Module name; the source is read from `modules/<module_name>.lua`.
    pub module_name: &'a str,
    /// Key of the function in the table the module evaluates to.
    pub function_name: &'a str,
}

/// Calls `modfn.function_name` from the Lua module `modfn.module_name`.
///
/// The module source is read from `modules/<module_name>.lua` in the tree of
/// the branch head and evaluated inside [`lua_context`]; it must evaluate to
/// a table. The named function is looked up in that table and called with a
/// single Lua table argument in which every `(index, value)` pair from `args`
/// is stored under its [`ArgIndex`] key. The function's result is converted
/// to a `String`.
///
/// # Errors
///
/// * [`ScriptError::ModuleNotFound`] if the path cannot be found in the tree,
///   or if it does not refer to a file (blob), e.g. it names a directory.
/// * [`ScriptError::ModuleNotUtf8`] if the file contents are not valid UTF-8.
/// * [`ScriptError::LuaError`] for any error while evaluating the module,
///   building the argument table, or looking up and calling the function.
///
/// # Panics
///
/// Panics if `ctx.branch_head()` fails.
pub fn call(
    modfn: &ModuleFunction,
    args: impl Iterator<Item = (ArgIndex, String)>,
    ctx: &Context,
) -> Result<String, ScriptError> {
    let filename = module_path(modfn.module_name);

    let entry = ctx
        .branch_head()
        // NOTE(review): panics instead of reporting an error; consider mapping it.
        .unwrap()
        .tree()
        .and_then(|tree| tree.get_path(Path::new(&filename)))
        .map_err(|_| ScriptError::ModuleNotFound)?;

    // The entry may be a tree (e.g. a directory named `foo.lua`) rather than a
    // blob; report that as a missing module instead of panicking.
    let blob = ctx
        .repo
        .find_blob(entry.id())
        .map_err(|_| ScriptError::ModuleNotFound)?;
    let lua_code = from_utf8(blob.content()).map_err(|_| ScriptError::ModuleNotUtf8)?;

    lua_context(|lua| {
        let module: Table = lua.load(lua_code).eval()?;
        let function: Function = module.get(modfn.function_name)?;
        let lua_args = lua.create_table()?;
        for (idx, val) in args {
            match idx {
                ArgIndex::Str(key) => lua_args.set(key, val)?,
                ArgIndex::Num(key) => lua_args.set(key, val)?,
            }
        }
        function.call(lua_args)
    })
    .map_err(ScriptError::LuaError)
}

impl Display for ScriptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ScriptError::ModuleNotFound => write!(f, "module not found"),
            ScriptError::ModuleNotUtf8 => write!(f, "module not valid UTF-8"),
            ScriptError::LuaError(rlua::Error::CallbackError { cause, .. }) => {
                write!(f, "{}", cause)
            }
            ScriptError::LuaError(err) => write!(f, "{}", err),
        }
    }
}