verstak/internal/core/plugins/runtime.go

299 lines
7.2 KiB
Go

package plugins
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"verstak/internal/core/activity"
"verstak/internal/core/files"
"verstak/internal/core/nodes"
"verstak/internal/core/storage"
"verstak/internal/core/worklog"
lua "github.com/yuin/gopher-lua"
)
// CoreServices bundles core Verstak services for use by the Lua plugin API.
// A LuaVM receives it through SetServices and exposes it as vm.Services.
type CoreServices struct {
	// NodeRepo is the node repository.
	NodeRepo *nodes.Repository
	// DB is the storage handle.
	DB *storage.DB
	// ActivitySvc is the activity service.
	ActivitySvc *activity.Service
	// WorklogSvc is the worklog service.
	WorklogSvc *worklog.Service
	// FilesSvc is the files service.
	FilesSvc *files.Service
	// VaultPath is presumably the root directory of the user's vault — TODO confirm with callers.
	VaultPath string
}
// LuaVM wraps a gopher-lua state for a single plugin.
// Each plugin gets its own isolated state; API functions use vm.Services to call core services.
type LuaVM struct {
	// L is the plugin's Lua state, created by NewLuaVM and set to nil by Close.
	L *lua.LState
	// Plugin is the owning plugin; LoadScript reads scripts from Plugin.Dir.
	Plugin *Plugin
	// Services is not set by NewLuaVM; it is attached later via SetServices.
	Services *CoreServices
	// mu is held by LoadScript, CallHookWithResult, DoString and CallFunction
	// while they use L, so Lua code never runs on two goroutines at once.
	mu sync.Mutex
	// done is closed by Close. Nothing in this file receives from it;
	// presumably other code in the package watches it — confirm.
	done chan struct{}
	// sandbox limits
	// callTimeout bounds each call made through callWithTimeout; values <= 0 fall back to 30s.
	callTimeout time.Duration
}
// NewLuaVM creates a sandboxed Lua VM for plugin p.
//
// Only the package, base, table, string, math and os libraries are opened, and
// each of them is then trimmed:
//   - base: everything that loads or compiles code (dofile, loadfile, load,
//     loadstring), module loading (require, module), raw table access, the
//     environment accessors (setfenv/getfenv) and the _printregs debug helper
//     are removed;
//   - package: every loader/search-path field is removed;
//   - os: only clock, date, difftime, time and tmpname are kept.
//
// print is redirected to the Go logger with a "[lua] " prefix. The VM starts
// with a 30s per-call timeout and has the plugin API registered via
// registerAPI. The error result is currently always nil.
func NewLuaVM(p *Plugin) (*LuaVM, error) {
	vm := &LuaVM{
		Plugin:      p,
		callTimeout: 30 * time.Second,
		done:        make(chan struct{}),
	}
	L := lua.NewState(lua.Options{
		SkipOpenLibs: true, // libraries are opened selectively below
	})

	// Open only the libraries the sandbox allows.
	for _, lib := range []struct {
		name string
		open lua.LGFunction
	}{
		{lua.LoadLibName, lua.OpenPackage},
		{lua.BaseLibName, lua.OpenBase},
		{lua.TabLibName, lua.OpenTable},
		{lua.StringLibName, lua.OpenString},
		{lua.MathLibName, lua.OpenMath},
		{lua.OsLibName, lua.OpenOs},
	} {
		L.Push(L.NewFunction(lib.open))
		L.Push(lua.LString(lib.name))
		L.Call(1, 0)
	}

	// Remove base functions that can load code or escape the sandbox.
	// loadstring must go along with load because it compiles arbitrary source
	// in exactly the same way. _printregs writes straight to stdout and would
	// bypass the print redirect below.
	for _, name := range []string{
		"dofile", "loadfile", "load", "loadstring", "require", "module",
		"rawequal", "rawget", "rawset", "rawlen", "setfenv", "getfenv",
		"_printregs",
	} {
		L.SetGlobal(name, lua.LNil)
	}

	// Strip the package table so nothing can be located or loaded from disk.
	if tbl, ok := L.GetGlobal("package").(*lua.LTable); ok {
		for _, name := range []string{"loadlib", "seeall", "preload", "loaders", "loaded", "path", "cpath", "config", "searchpath"} {
			tbl.RawSetString(name, lua.LNil)
		}
	}

	// Allow-list os.*: drop every string key that is not explicitly safe.
	// Keys are collected first so the table is not mutated while being iterated.
	// NOTE(review): gopher-lua's os.tmpname appears to create and remove a temp
	// file to obtain the name; confirm that side effect is acceptable.
	if tbl, ok := L.GetGlobal("os").(*lua.LTable); ok {
		var drop []lua.LValue
		tbl.ForEach(func(k, _ lua.LValue) {
			ks, isStr := k.(lua.LString)
			if !isStr {
				return
			}
			switch string(ks) {
			case "clock", "date", "difftime", "time", "tmpname":
				// kept
			default:
				drop = append(drop, k)
			}
		})
		for _, k := range drop {
			tbl.RawSet(k, lua.LNil)
		}
	}

	// Redirect print to the Go log.
	L.SetGlobal("print", L.NewFunction(func(L *lua.LState) int {
		n := L.GetTop()
		parts := make([]string, 0, n)
		for i := 1; i <= n; i++ {
			parts = append(parts, L.Get(i).String())
		}
		log.Printf("[lua] %s", strings.Join(parts, " "))
		return 0
	}))

	vm.L = L
	registerAPI(vm)
	return vm, nil
}
// LoadScript reads filename from the plugin directory and runs it as a Lua
// chunk in this VM, under vm.mu and the VM's call timeout. LoadScript ignores
// any values the chunk returns.
//
// filename is resolved relative to vm.Plugin.Dir and, after lexical cleaning,
// must stay inside it (for example "../x.lua" is rejected). Symlinks are not
// resolved. It returns an error if the VM has already been closed.
func (vm *LuaVM) LoadScript(filename string) error {
	path := filepath.Join(vm.Plugin.Dir, filename)
	// Refuse names that climb out of the plugin directory, so a plugin cannot
	// point its loader at code elsewhere on disk.
	rel, err := filepath.Rel(vm.Plugin.Dir, path)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("script %q is outside the plugin directory", filename)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read %s: %w", filename, err)
	}

	vm.mu.Lock()
	defer vm.mu.Unlock()
	// Close sets vm.L to nil; without this check a late load would dereference nil.
	if vm.L == nil || vm.L.IsClosed() {
		return fmt.Errorf("Lua VM is closed")
	}
	fn, err := vm.L.Load(strings.NewReader(string(data)), filename)
	if err != nil {
		return fmt.Errorf("load %s: %w", filename, err)
	}
	vm.L.Push(fn)
	_, err = vm.callWithTimeout(0)
	return err
}
// CallHook invokes the global Lua function called name with args and discards
// its results. Only the error from CallHookWithResult is reported; a hook that
// is not defined yields a nil error.
func (vm *LuaVM) CallHook(name string, args ...lua.LValue) error {
	if _, err := vm.CallHookWithResult(name, args...); err != nil {
		return err
	}
	return nil
}
// CallHookWithResult calls the global Lua function name with args and returns
// its first return value (lua.LNil when it returns nothing).
//
// A hook that is not defined is not an error: (lua.LNil, nil) is returned.
// It is an error if the global exists but is not a function, or if the VM has
// been closed. The call runs under vm.mu with the VM's call timeout.
func (vm *LuaVM) CallHookWithResult(name string, args ...lua.LValue) (lua.LValue, error) {
	vm.mu.Lock()
	defer vm.mu.Unlock()
	// Close sets vm.L to nil; guard like DoString and CallFunction do.
	if vm.L == nil || vm.L.IsClosed() {
		return lua.LNil, fmt.Errorf("Lua VM is closed")
	}
	fn := vm.L.GetGlobal(name)
	if fn == lua.LNil {
		return lua.LNil, nil
	}
	if _, ok := fn.(*lua.LFunction); !ok {
		return lua.LNil, fmt.Errorf("%q is not a function", name)
	}
	vm.L.Push(fn)
	for _, arg := range args {
		vm.L.Push(arg)
	}
	return vm.callWithTimeout(len(args))
}
// Close shuts down the Lua VM: it closes vm.done, closes the Lua state and sets
// vm.L to nil. Calling it more than once is safe.
//
// Close takes vm.mu, so it waits for an in-flight call (bounded by the call
// timeout) to finish instead of tearing the state down underneath it. For the
// same reason it must not be called from inside Lua code running on this VM,
// because those calls already hold vm.mu and the mutex is not reentrant.
func (vm *LuaVM) Close() {
	vm.mu.Lock()
	defer vm.mu.Unlock()
	if vm.done != nil {
		select {
		case <-vm.done:
			// Already closed by an earlier Close; closing again would panic.
		default:
			close(vm.done)
		}
	}
	if vm.L != nil {
		if !vm.L.IsClosed() {
			vm.L.Close()
		}
		vm.L = nil
	}
}
// SetServices attaches the core services used by the Lua API bindings,
// replacing whatever was attached before. It does not take vm.mu.
func (vm *LuaVM) SetServices(services *CoreServices) {
	vm.Services = services
}
// DoString compiles and runs src as a Lua chunk and returns the chunk's first
// return value converted with LValue.String ("nil" when it returns nothing).
//
// Like the other entry points it runs under vm.mu with the VM's call timeout.
// It leaves the Lua stack at the height it found it, so extra return values
// cannot leak into later calls. It returns an error if the VM is closed, the
// source does not compile, or the chunk fails or times out.
func (vm *LuaVM) DoString(src string) (string, error) {
	vm.mu.Lock()
	defer vm.mu.Unlock()
	if vm.L == nil || vm.L.IsClosed() {
		return "", fmt.Errorf("Lua VM is closed")
	}
	base := vm.L.GetTop()
	fn, err := vm.L.LoadString(src)
	if err != nil {
		return "", err
	}
	vm.L.Push(fn)
	ret, err := vm.callWithTimeout(0)
	// Drop anything the chunk left above our starting height. Popping
	// unconditionally, as before, raised "register underflow" outside a
	// protected call when the chunk returned nothing.
	if vm.L.GetTop() > base {
		vm.L.SetTop(base)
	}
	if err != nil {
		return "", err
	}
	return ret.String(), nil
}
// CallFunction resolves a dotted function name, calls it with at most one
// argument under vm.mu and the VM's call timeout, and returns its first return
// value converted with LValue.String.
//
// segments are pre-validated identifier segments (e.g. ["calendar",
// "create_event"]). The first one is looked up with GetGlobal; each later one is
// a raw lookup in the table found so far. luaArg is the pre-converted Lua
// argument, or nil for a no-argument call.
//
// It returns "Lua VM is closed" if the VM has been closed. It returns
// "function not found" when nothing resolves, when segments is empty, or when
// an intermediate value is not a table. It returns "not a function" when the
// resolved value is not a Lua function. Otherwise it returns the error from the
// call itself.
func (vm *LuaVM) CallFunction(segments []string, luaArg lua.LValue) (string, error) {
	vm.mu.Lock()
	defer vm.mu.Unlock()
	if vm.L == nil || vm.L.IsClosed() {
		return "", fmt.Errorf("Lua VM is closed")
	}
	// Indexing segments[0] on an empty slice would panic.
	if len(segments) == 0 {
		return "", fmt.Errorf("function not found")
	}

	// Walk the dotted path starting from the global.
	fn := vm.L.GetGlobal(segments[0])
	for _, seg := range segments[1:] {
		tbl, ok := fn.(*lua.LTable)
		if !ok {
			fn = lua.LNil
			break
		}
		fn = tbl.RawGetString(seg)
	}
	if fn == lua.LNil {
		return "", fmt.Errorf("function not found")
	}
	if _, ok := fn.(*lua.LFunction); !ok {
		return "", fmt.Errorf("not a function")
	}

	vm.L.Push(fn)
	nargs := 0
	if luaArg != nil {
		vm.L.Push(luaArg)
		nargs = 1
	}
	ret, err := vm.callWithTimeout(nargs)
	if err != nil {
		return "", err
	}
	return ret.String(), nil
}
// LState returns the underlying lua.LState (for table creation).
// NOTE(review): unlike the methods that run Lua code, this does not take vm.mu.
// It may also return nil once Close has run. Callers that use the state while
// other goroutines make calls need their own synchronization — confirm how the
// bindings use it.
func (vm *LuaVM) LState() *lua.LState {
	return vm.L
}
// VM returns the LuaVM for external use (bindings).
// It returns the p.vm field as-is. Nothing in this file assigns that field, so
// it is presumably nil until the plugin's VM is created elsewhere — confirm.
func (p *Plugin) VM() *LuaVM {
	return p.vm
}
// callWithTimeout calls the function the caller has already pushed onto vm.L,
// followed by its nargs arguments, in protected mode. The deadline is
// vm.callTimeout, or 30s when that is not positive. It returns the function's
// first result, or lua.LNil if it returned none.
//
// Whatever the outcome, the stack is restored to its height from before the
// function was pushed. Surplus return values therefore never accumulate and get
// mistaken for a later call's result. When the deadline fires during the call,
// the error reads "execution timeout (<d>)" and wraps the Lua error. A call that
// completes successfully is always reported as a success.
//
// Must be called with vm.mu held.
func (vm *LuaVM) callWithTimeout(nargs int) (lua.LValue, error) {
	timeout := vm.callTimeout
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	L := vm.L
	// Height of the stack below the function slot. PCall replaces the function
	// and its arguments with the results, which therefore start at base+1.
	base := L.GetTop() - nargs - 1
	if base < 0 {
		base = 0
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	// With a context set, gopher-lua's main loop watches ctx.Done() and raises an
	// error (surfaced by PCall) once the deadline passes.
	L.SetContext(ctx)
	err := L.PCall(nargs, lua.MultRet, nil)
	L.RemoveContext()

	// Read the FIRST result (not the stack bottom) and discard every result.
	// On error PCall has already reset the stack to base.
	ret := lua.LNil
	if L.GetTop() > base {
		ret = L.Get(base + 1)
		L.SetTop(base)
	}

	if err != nil {
		if ctx.Err() == context.DeadlineExceeded {
			return lua.LNil, fmt.Errorf("execution timeout (%s): %w", timeout, err)
		}
		return lua.LNil, err
	}
	return ret, nil
}