package plugins import ( "context" "encoding/json" "fmt" "log" "os" "path/filepath" "strings" "sync" "time" "verstak/internal/core/activity" "verstak/internal/core/files" "verstak/internal/core/nodes" "verstak/internal/core/storage" "verstak/internal/core/worklog" lua "github.com/yuin/gopher-lua" ) // CoreServices bundles core Verstak services for use by the Lua plugin API. type CoreServices struct { NodeRepo *nodes.Repository DB *storage.DB ActivitySvc *activity.Service WorklogSvc *worklog.Service FilesSvc *files.Service VaultPath string } // LuaVM wraps a gopher-lua state for a single plugin. // Each plugin gets its own isolated state; API functions use vm.Services to call core services. type LuaVM struct { L *lua.LState Plugin *Plugin Services *CoreServices mu sync.Mutex done chan struct{} // sandbox limits callTimeout time.Duration } // NewLuaVM creates a sandboxed Lua VM for a plugin. func NewLuaVM(p *Plugin) (*LuaVM, error) { vm := &LuaVM{ Plugin: p, callTimeout: 30 * time.Second, done: make(chan struct{}), } L := lua.NewState(lua.Options{ SkipOpenLibs: true, // we selectively open safe libs }) // Open only safe libraries for _, pair := range []struct { lib string fn lua.LGFunction }{ {lua.LoadLibName, lua.OpenPackage}, {lua.BaseLibName, lua.OpenBase}, {lua.TabLibName, lua.OpenTable}, {lua.StringLibName, lua.OpenString}, {lua.MathLibName, lua.OpenMath}, {lua.OsLibName, lua.OpenOs}, } { L.Push(L.NewFunction(pair.fn)) L.Push(lua.LString(pair.lib)) L.Call(1, 0) } // Disable dangerous functions for _, name := range []string{"dofile", "loadfile", "require", "module", "rawequal", "rawget", "rawset", "rawlen", "setfenv", "getfenv", "load"} { L.SetGlobal(name, lua.LNil) } // Restrict package table: remove dangerous fields if pkgTbl := L.GetGlobal("package"); pkgTbl != lua.LNil { if tbl, ok := pkgTbl.(*lua.LTable); ok { for _, name := range []string{"loadlib", "seeall", "preload", "loaders", "loaded", "path", "cpath", "config", "searchpath"} { tbl.RawSetString(name, lua.LNil) } } } // Restrict os.* to safe subset osSafe := map[string]bool{ "clock": true, "date": true, "difftime": true, "time": true, "tmpname": true, } if osTable := L.GetGlobal("os"); osTable != lua.LNil { if tbl, ok := osTable.(*lua.LTable); ok { for _, k := range []string{"execute", "exit", "remove", "rename", "setlocale", "getenv"} { tbl.RawSetString(k, lua.LNil) } // Only keep safe ones tbl.ForEach(func(k lua.LValue, v lua.LValue) { if ks, ok := k.(lua.LString); ok && !osSafe[string(ks)] { tbl.RawSet(k, lua.LNil) } }) } } // Redirect print to Go log L.SetGlobal("print", L.NewFunction(func(L *lua.LState) int { top := L.GetTop() var parts []string for i := 1; i <= top; i++ { parts = append(parts, L.Get(i).String()) } msg := strings.Join(parts, " ") log.Printf("[lua] %s", msg) return 0 })) vm.L = L registerAPI(vm) return vm, nil } // LoadScript loads and executes a Lua file from the plugin directory. func (vm *LuaVM) LoadScript(filename string) error { path := filepath.Join(vm.Plugin.Dir, filename) data, err := os.ReadFile(path) if err != nil { return fmt.Errorf("read %s: %w", filename, err) } vm.mu.Lock() defer vm.mu.Unlock() fn, err := vm.L.Load(strings.NewReader(string(data)), filename) if err != nil { return fmt.Errorf("load %s: %w", filename, err) } vm.L.Push(fn) _, err = vm.callWithTimeout(0) return err } // CallHook calls a Lua function by name with optional args (no return value expected). func (vm *LuaVM) CallHook(name string, args ...lua.LValue) error { _, err := vm.CallHookWithResult(name, args...) return err } // CallHookWithResult calls a Lua function by name and returns its first return value. func (vm *LuaVM) CallHookWithResult(name string, args ...lua.LValue) (lua.LValue, error) { vm.mu.Lock() defer vm.mu.Unlock() fn := vm.L.GetGlobal(name) if fn == lua.LNil { return lua.LNil, nil } if _, ok := fn.(*lua.LFunction); !ok { return lua.LNil, fmt.Errorf("%q is not a function", name) } vm.L.Push(fn) for _, arg := range args { vm.L.Push(arg) } return vm.callWithTimeout(len(args)) } // Close shuts down the Lua VM. func (vm *LuaVM) Close() { close(vm.done) if vm.L != nil && !vm.L.IsClosed() { vm.L.Close() vm.L = nil } } // SetServices sets the core services reference on the VM. func (vm *LuaVM) SetServices(svc *CoreServices) { vm.Services = svc } // DoString executes an arbitrary Lua script string and returns the first return value. func (vm *LuaVM) DoString(src string) (string, error) { vm.mu.Lock() defer vm.mu.Unlock() if vm.L == nil || vm.L.IsClosed() { return "", fmt.Errorf("Lua VM is closed") } if err := vm.L.DoString(src); err != nil { return "", err } // Get return value from stack ret := vm.L.Get(-1) vm.L.Pop(1) return ret.String(), nil } // CallFunctionJSON is a thread-safe, timeout-safe wrapper that accepts JSON params. // It converts JSON→Lua under vm.mu, so the Lua state is never touched outside the lock. // segments: pre-validated identifier segments (e.g. ["calendar", "create_event"]) // paramsJSON: JSON string or "" / "{}" for no-arg calls. func (vm *LuaVM) CallFunctionJSON(segments []string, paramsJSON string) (string, error) { vm.mu.Lock() defer vm.mu.Unlock() if vm.L == nil || vm.L.IsClosed() { return "", fmt.Errorf("Lua VM is closed") } // Resolve the function via _G var fn lua.LValue if len(segments) == 1 { fn = vm.L.GetGlobal(segments[0]) } else { tbl := vm.L.GetGlobal(segments[0]) for i := 1; i < len(segments); i++ { if t, ok := tbl.(*lua.LTable); ok { tbl = t.RawGetString(segments[i]) } else { tbl = lua.LNil break } } fn = tbl } if fn == lua.LNil { return "", fmt.Errorf("function not found") } if _, ok := fn.(*lua.LFunction); !ok { return "", fmt.Errorf("not a function") } // Convert JSON params to Lua value UNDER the lock var luaArg lua.LValue if paramsJSON != "" && paramsJSON != "{}" { var params interface{} if err := json.Unmarshal([]byte(paramsJSON), ¶ms); err != nil { return "", fmt.Errorf("invalid JSON params: %w", err) } luaArg = goToLua(vm.L, params) } // Push function and args vm.L.Push(fn) if luaArg != nil { vm.L.Push(luaArg) } nargs := 0 if luaArg != nil { nargs = 1 } // Call with timeout ret, err := vm.callWithTimeout(nargs) if err != nil { return "", err } // Convert Lua return value to JSON for the frontend. // Lua tables (from db.query, etc.) return "table: 0x..." via .String() — // the frontend always expects valid JSON for JSON.parse(). goVal := luaValueToGo(ret) jsonBytes, marshalErr := json.Marshal(goVal) if marshalErr != nil { return "", fmt.Errorf("failed to marshal return value: %w", marshalErr) } return string(jsonBytes), nil } // LState returns the underlying lua.LState (for table creation outside CallFunctionJSON). // WARNING: only use this for read-only operations or when NOT holding vm.mu. func (vm *LuaVM) LState() *lua.LState { return vm.L } // VM returns the LuaVM for external use (bindings). func (p *Plugin) VM() *LuaVM { return p.vm } // goToLua converts a Go interface{} to a lua.LValue. // Must be called with vm.mu held (uses vm.L). func goToLua(L *lua.LState, v interface{}) lua.LValue { switch val := v.(type) { case nil: return lua.LNil case string: return lua.LString(val) case float64: return lua.LNumber(val) case bool: return lua.LBool(val) case map[string]interface{}: tbl := L.NewTable() for k, v := range val { tbl.RawSetString(k, goToLua(L, v)) } return tbl case []interface{}: tbl := L.NewTable() for i, v := range val { tbl.RawSetInt(i+1, goToLua(L, v)) } return tbl default: return lua.LString(fmt.Sprintf("%v", v)) } } // callWithTimeout runs a PCall with a timeout and returns the first LValue. // nargs is the number of function arguments already on the stack. // Must be called with vm.mu held. func (vm *LuaVM) callWithTimeout(nargs int) (lua.LValue, error) { timeout := vm.callTimeout if timeout <= 0 { timeout = 30 * time.Second } // Create a cancellable context for timeout ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() // Set context so gopher-lua's main loop checks ctx.Done() vm.L.SetContext(ctx) err := vm.L.PCall(nargs, lua.MultRet, nil) // Remove context after call vm.L.RemoveContext() // Collect return value (if any) ret := lua.LNil if vm.L.GetTop() > 0 { ret = vm.L.Get(1) vm.L.Pop(1) } if err != nil { return ret, err } // Check if timeout occurred if ctx.Err() != nil { return ret, fmt.Errorf("execution timeout (%s)", timeout) } return ret, nil }