ducklm/duck_core/tools/shell_exec_safe.py

110 lines
3.1 KiB
Python

import asyncio
import shlex
import subprocess
from typing import Any

from duck_core.tools.base import ToolResult
ALLOWLIST = {
"pwd",
"ls",
"cat",
"head",
"tail",
"grep",
"find",
"pytest",
"python -m pytest",
"python3 -m pytest",
"git status",
"git diff",
"git log",
}
BLOCKLIST = {
"rm",
"sudo",
"su",
"dd",
"mkfs",
"mount",
"umount",
"shutdown",
"reboot",
"poweroff",
"systemctl",
"service",
"apt install",
"apt remove",
"pacman -S",
"pacman -R",
"pip install",
"npm install -g",
"chmod -R",
"chown -R",
"curl | sh",
"wget | sh",
}
class ShellExecSafeTool:
name = "shell_exec_safe"
risk_level = "medium"
def __init__(self, workspace: str, timeout_seconds: int = 30):
self.workspace = workspace
self.timeout_seconds = timeout_seconds
async def run(self, args: dict[str, Any]) -> ToolResult:
command = str(args.get("command", "")).strip()
approved = bool(args.get("_approved"))
allowed, reason, blocked = self._is_allowed(command, approved=approved)
if not allowed:
metadata = {"blocked": True} if blocked else {"requires_approval": True}
return ToolResult(ok=False, error=reason, metadata=metadata)
try:
completed = subprocess.run(
command,
cwd=self.workspace,
shell=True,
text=True,
capture_output=True,
timeout=self.timeout_seconds,
check=False,
)
except subprocess.SubprocessError as exc:
return ToolResult(ok=False, error=str(exc))
return ToolResult(
ok=completed.returncode == 0,
output=completed.stdout,
error=completed.stderr if completed.returncode else None,
metadata={"returncode": completed.returncode, "command": command},
)
def _is_allowed(
self, command: str, approved: bool = False
) -> tuple[bool, str | None, bool]:
if not command:
return False, "Empty command", False
lowered = command.lower()
parts = shlex.split(command)
for blocked in BLOCKLIST:
if self._matches_blocked_command(lowered, parts, blocked):
return False, f"Command is blocked: {blocked}", True
if approved:
return True, None, False
prefix1 = parts[0] if parts else ""
prefix2 = " ".join(parts[:2])
prefix3 = " ".join(parts[:3])
if prefix1 in ALLOWLIST or prefix2 in ALLOWLIST or prefix3 in ALLOWLIST:
return True, None, False
return False, "Command is outside allowlist and requires approval", False
def _matches_blocked_command(
self, lowered_command: str, parts: list[str], blocked: str
) -> bool:
lowered_blocked = blocked.lower()
if " " in lowered_blocked or "|" in lowered_blocked:
return lowered_command.startswith(lowered_blocked) or lowered_blocked in lowered_command
return bool(parts) and parts[0].lower() == lowered_blocked