340 lines
8.3 KiB
Go
340 lines
8.3 KiB
Go
package plugins
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"verstak/internal/core/activity"
|
|
"verstak/internal/core/files"
|
|
"verstak/internal/core/nodes"
|
|
"verstak/internal/core/storage"
|
|
"verstak/internal/core/worklog"
|
|
|
|
lua "github.com/yuin/gopher-lua"
|
|
)
|
|
|
|
// CoreServices bundles core Verstak services for use by the Lua plugin API.
|
|
type CoreServices struct {
|
|
NodeRepo *nodes.Repository
|
|
DB *storage.DB
|
|
ActivitySvc *activity.Service
|
|
WorklogSvc *worklog.Service
|
|
FilesSvc *files.Service
|
|
VaultPath string
|
|
}
|
|
|
|
// LuaVM wraps a gopher-lua state for a single plugin.
|
|
// Each plugin gets its own isolated state; API functions use vm.Services to call core services.
|
|
type LuaVM struct {
|
|
L *lua.LState
|
|
Plugin *Plugin
|
|
Services *CoreServices
|
|
mu sync.Mutex
|
|
done chan struct{}
|
|
|
|
// sandbox limits
|
|
callTimeout time.Duration
|
|
}
|
|
|
|
// NewLuaVM creates a sandboxed Lua VM for a plugin.
|
|
func NewLuaVM(p *Plugin) (*LuaVM, error) {
|
|
vm := &LuaVM{
|
|
Plugin: p,
|
|
callTimeout: 30 * time.Second,
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
L := lua.NewState(lua.Options{
|
|
SkipOpenLibs: true, // we selectively open safe libs
|
|
})
|
|
|
|
// Open only safe libraries
|
|
for _, pair := range []struct {
|
|
lib string
|
|
fn lua.LGFunction
|
|
}{
|
|
{lua.LoadLibName, lua.OpenPackage},
|
|
{lua.BaseLibName, lua.OpenBase},
|
|
{lua.TabLibName, lua.OpenTable},
|
|
{lua.StringLibName, lua.OpenString},
|
|
{lua.MathLibName, lua.OpenMath},
|
|
{lua.OsLibName, lua.OpenOs},
|
|
} {
|
|
L.Push(L.NewFunction(pair.fn))
|
|
L.Push(lua.LString(pair.lib))
|
|
L.Call(1, 0)
|
|
}
|
|
|
|
// Disable dangerous functions
|
|
for _, name := range []string{"dofile", "loadfile", "require", "module", "rawequal", "rawget", "rawset", "rawlen", "setfenv", "getfenv", "load"} {
|
|
L.SetGlobal(name, lua.LNil)
|
|
}
|
|
|
|
// Restrict package table: remove dangerous fields
|
|
if pkgTbl := L.GetGlobal("package"); pkgTbl != lua.LNil {
|
|
if tbl, ok := pkgTbl.(*lua.LTable); ok {
|
|
for _, name := range []string{"loadlib", "seeall", "preload", "loaders", "loaded", "path", "cpath", "config", "searchpath"} {
|
|
tbl.RawSetString(name, lua.LNil)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Restrict os.* to safe subset
|
|
osSafe := map[string]bool{
|
|
"clock": true, "date": true, "difftime": true, "time": true,
|
|
"tmpname": true,
|
|
}
|
|
if osTable := L.GetGlobal("os"); osTable != lua.LNil {
|
|
if tbl, ok := osTable.(*lua.LTable); ok {
|
|
for _, k := range []string{"execute", "exit", "remove", "rename", "setlocale", "getenv"} {
|
|
tbl.RawSetString(k, lua.LNil)
|
|
}
|
|
// Only keep safe ones
|
|
tbl.ForEach(func(k lua.LValue, v lua.LValue) {
|
|
if ks, ok := k.(lua.LString); ok && !osSafe[string(ks)] {
|
|
tbl.RawSet(k, lua.LNil)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Redirect print to Go log
|
|
L.SetGlobal("print", L.NewFunction(func(L *lua.LState) int {
|
|
top := L.GetTop()
|
|
var parts []string
|
|
for i := 1; i <= top; i++ {
|
|
parts = append(parts, L.Get(i).String())
|
|
}
|
|
msg := strings.Join(parts, " ")
|
|
log.Printf("[lua] %s", msg)
|
|
return 0
|
|
}))
|
|
|
|
vm.L = L
|
|
registerAPI(vm)
|
|
return vm, nil
|
|
}
|
|
|
|
// LoadScript loads and executes a Lua file from the plugin directory.
|
|
func (vm *LuaVM) LoadScript(filename string) error {
|
|
path := filepath.Join(vm.Plugin.Dir, filename)
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return fmt.Errorf("read %s: %w", filename, err)
|
|
}
|
|
vm.mu.Lock()
|
|
defer vm.mu.Unlock()
|
|
fn, err := vm.L.Load(strings.NewReader(string(data)), filename)
|
|
if err != nil {
|
|
return fmt.Errorf("load %s: %w", filename, err)
|
|
}
|
|
vm.L.Push(fn)
|
|
_, err = vm.callWithTimeout(0)
|
|
return err
|
|
}
|
|
|
|
// CallHook calls a Lua function by name with optional args (no return value expected).
|
|
func (vm *LuaVM) CallHook(name string, args ...lua.LValue) error {
|
|
_, err := vm.CallHookWithResult(name, args...)
|
|
return err
|
|
}
|
|
|
|
// CallHookWithResult calls a Lua function by name and returns its first return value.
|
|
func (vm *LuaVM) CallHookWithResult(name string, args ...lua.LValue) (lua.LValue, error) {
|
|
vm.mu.Lock()
|
|
defer vm.mu.Unlock()
|
|
fn := vm.L.GetGlobal(name)
|
|
if fn == lua.LNil {
|
|
return lua.LNil, nil
|
|
}
|
|
if _, ok := fn.(*lua.LFunction); !ok {
|
|
return lua.LNil, fmt.Errorf("%q is not a function", name)
|
|
}
|
|
vm.L.Push(fn)
|
|
for _, arg := range args {
|
|
vm.L.Push(arg)
|
|
}
|
|
return vm.callWithTimeout(len(args))
|
|
}
|
|
|
|
// Close shuts down the Lua VM.
|
|
func (vm *LuaVM) Close() {
|
|
close(vm.done)
|
|
if vm.L != nil && !vm.L.IsClosed() {
|
|
vm.L.Close()
|
|
vm.L = nil
|
|
}
|
|
}
|
|
|
|
// SetServices sets the core services reference on the VM.
|
|
func (vm *LuaVM) SetServices(svc *CoreServices) {
|
|
vm.Services = svc
|
|
}
|
|
|
|
// DoString executes an arbitrary Lua script string and returns the first return value.
|
|
func (vm *LuaVM) DoString(src string) (string, error) {
|
|
vm.mu.Lock()
|
|
defer vm.mu.Unlock()
|
|
if vm.L == nil || vm.L.IsClosed() {
|
|
return "", fmt.Errorf("Lua VM is closed")
|
|
}
|
|
if err := vm.L.DoString(src); err != nil {
|
|
return "", err
|
|
}
|
|
// Get return value from stack
|
|
ret := vm.L.Get(-1)
|
|
vm.L.Pop(1)
|
|
return ret.String(), nil
|
|
}
|
|
|
|
// CallFunctionJSON is a thread-safe, timeout-safe wrapper that accepts JSON params.
|
|
// It converts JSON→Lua under vm.mu, so the Lua state is never touched outside the lock.
|
|
// segments: pre-validated identifier segments (e.g. ["calendar", "create_event"])
|
|
// paramsJSON: JSON string or "" / "{}" for no-arg calls.
|
|
func (vm *LuaVM) CallFunctionJSON(segments []string, paramsJSON string) (string, error) {
|
|
vm.mu.Lock()
|
|
defer vm.mu.Unlock()
|
|
|
|
if vm.L == nil || vm.L.IsClosed() {
|
|
return "", fmt.Errorf("Lua VM is closed")
|
|
}
|
|
|
|
// Resolve the function via _G
|
|
var fn lua.LValue
|
|
if len(segments) == 1 {
|
|
fn = vm.L.GetGlobal(segments[0])
|
|
} else {
|
|
tbl := vm.L.GetGlobal(segments[0])
|
|
for i := 1; i < len(segments); i++ {
|
|
if t, ok := tbl.(*lua.LTable); ok {
|
|
tbl = t.RawGetString(segments[i])
|
|
} else {
|
|
tbl = lua.LNil
|
|
break
|
|
}
|
|
}
|
|
fn = tbl
|
|
}
|
|
|
|
if fn == lua.LNil {
|
|
return "", fmt.Errorf("function not found")
|
|
}
|
|
if _, ok := fn.(*lua.LFunction); !ok {
|
|
return "", fmt.Errorf("not a function")
|
|
}
|
|
|
|
// Convert JSON params to Lua value UNDER the lock
|
|
var luaArg lua.LValue
|
|
if paramsJSON != "" && paramsJSON != "{}" {
|
|
var params interface{}
|
|
if err := json.Unmarshal([]byte(paramsJSON), ¶ms); err != nil {
|
|
return "", fmt.Errorf("invalid JSON params: %w", err)
|
|
}
|
|
luaArg = goToLua(vm.L, params)
|
|
}
|
|
|
|
// Push function and args
|
|
vm.L.Push(fn)
|
|
if luaArg != nil {
|
|
vm.L.Push(luaArg)
|
|
}
|
|
nargs := 0
|
|
if luaArg != nil {
|
|
nargs = 1
|
|
}
|
|
|
|
// Call with timeout
|
|
ret, err := vm.callWithTimeout(nargs)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return ret.String(), nil
|
|
}
|
|
|
|
// LState returns the underlying lua.LState (for table creation outside CallFunctionJSON).
|
|
// WARNING: only use this for read-only operations or when NOT holding vm.mu.
|
|
func (vm *LuaVM) LState() *lua.LState {
|
|
return vm.L
|
|
}
|
|
|
|
// VM returns the LuaVM for external use (bindings).
|
|
func (p *Plugin) VM() *LuaVM {
|
|
return p.vm
|
|
}
|
|
|
|
// goToLua converts a Go interface{} to a lua.LValue.
|
|
// Must be called with vm.mu held (uses vm.L).
|
|
func goToLua(L *lua.LState, v interface{}) lua.LValue {
|
|
switch val := v.(type) {
|
|
case nil:
|
|
return lua.LNil
|
|
case string:
|
|
return lua.LString(val)
|
|
case float64:
|
|
return lua.LNumber(val)
|
|
case bool:
|
|
return lua.LBool(val)
|
|
case map[string]interface{}:
|
|
tbl := L.NewTable()
|
|
for k, v := range val {
|
|
tbl.RawSetString(k, goToLua(L, v))
|
|
}
|
|
return tbl
|
|
case []interface{}:
|
|
tbl := L.NewTable()
|
|
for i, v := range val {
|
|
tbl.RawSetInt(i+1, goToLua(L, v))
|
|
}
|
|
return tbl
|
|
default:
|
|
return lua.LString(fmt.Sprintf("%v", v))
|
|
}
|
|
}
|
|
|
|
// callWithTimeout runs a PCall with a timeout and returns the first LValue.
|
|
// nargs is the number of function arguments already on the stack.
|
|
// Must be called with vm.mu held.
|
|
func (vm *LuaVM) callWithTimeout(nargs int) (lua.LValue, error) {
|
|
timeout := vm.callTimeout
|
|
if timeout <= 0 {
|
|
timeout = 30 * time.Second
|
|
}
|
|
|
|
// Create a cancellable context for timeout
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
// Set context so gopher-lua's main loop checks ctx.Done()
|
|
vm.L.SetContext(ctx)
|
|
|
|
err := vm.L.PCall(nargs, lua.MultRet, nil)
|
|
|
|
// Remove context after call
|
|
vm.L.RemoveContext()
|
|
|
|
// Collect return value (if any)
|
|
ret := lua.LNil
|
|
if vm.L.GetTop() > 0 {
|
|
ret = vm.L.Get(1)
|
|
vm.L.Pop(1)
|
|
}
|
|
|
|
if err != nil {
|
|
return ret, err
|
|
}
|
|
|
|
// Check if timeout occurred
|
|
if ctx.Err() != nil {
|
|
return ret, fmt.Errorf("execution timeout (%s)", timeout)
|
|
}
|
|
|
|
return ret, nil
|
|
}
|