# ducklm/duck_core/tools/command_policy.py
#
# 181 lines
# 5.9 KiB
# Python

import shlex
from dataclasses import dataclass
from pydantic import BaseModel
@dataclass(eq=True, frozen=True)
class CommandPattern:
    """Immutable allowlist entry matched against the leading words of a command.

    ``CommandPolicy._match`` lower-cases the command's words and compares the
    first ``len(words)`` of them with ``words``; equal prefixes match.
    """

    # Leading words to compare, e.g. ("git", "status"). They must be
    # lower-case because they are compared against lower-cased command words.
    words: tuple[str, ...]
    # Category copied into CommandPolicyResult.action_type on a match.
    action_type: str
    # Risk label copied into CommandPolicyResult.risk_level, e.g. "low"/"high".
    risk_level: str
class CommandPolicyResult(BaseModel):
    """Decision produced by ``CommandPolicy.classify`` for one command.

    When ``blocked`` is True the command must not run at all. Otherwise
    ``requires_approval`` says whether a user has to confirm it first.
    """

    # Command text that was classified (the classifier strips surrounding
    # whitespace first).
    command: str
    # Command re-quoted with a leading ``sudo`` and its options removed; the
    # classifier falls back to ``command`` when nothing remains or parsing fails.
    normalized_command: str
    # Category such as "file_read", "package_install" or "destructive".
    action_type: str
    # One of "low", "medium", "high" or "critical".
    risk_level: str
    # True when a user must confirm the command before it runs.
    requires_approval: bool
    # True when the command is run through ``sudo``.
    requires_password: bool
    # True when the policy refuses to run the command.
    blocked: bool = False
    # Human-readable explanation of the decision, or None when none is needed.
    reason: str | None = None
class CommandPolicy:
    """Best-effort risk classifier for shell commands proposed for execution.

    :meth:`classify` is the public entry point. It returns a
    :class:`CommandPolicyResult` saying whether the command is blocked, whether
    it needs user approval and whether it runs through ``sudo``.

    Evaluation order:

    1. Empty input is blocked; input whose quoting cannot be parsed needs
       approval.
    2. The line is split at unquoted shell operators (``;``, ``&&``, ``|``,
       ``$(``, backticks, line breaks, ...). The whole line is blocked if any
       part runs a program from ``DESTRUCTIVE_COMMANDS``, if it contains a
       ``DANGEROUS_FRAGMENTS`` entry, or if it pipes into a shell interpreter.
    3. A line made only of ``READONLY_PATTERNS`` commands runs without
       approval unless ``sudo`` is involved. Those commands may be chained with
       ``;``, ``|``, ``||`` or ``&&``, but not with redirection or command
       substitution, and must not use ``READONLY_UNSAFE_OPTIONS``.
    4. A line containing a ``SYSTEM_PATTERNS`` command needs approval (high
       risk).
    5. Anything else needs approval (medium risk).

    NOTE: this is a pattern-based blocklist, not a sandbox. Wrappers such as
    ``env``, ``xargs`` or ``bash -c`` are not unwrapped. Destructive commands
    hidden behind them are not recognised and fall through to approval.
    """

    READONLY_PATTERNS = (
        CommandPattern(("pwd",), "working_directory", "low"),
        CommandPattern(("ls",), "directory_list", "low"),
        CommandPattern(("cat",), "file_read", "low"),
        CommandPattern(("head",), "file_read", "low"),
        CommandPattern(("tail",), "file_read", "low"),
        CommandPattern(("grep",), "text_search", "low"),
        CommandPattern(("find",), "file_search", "low"),
        CommandPattern(("apt", "list"), "package_check", "low"),
        CommandPattern(("apt-cache", "policy"), "package_check", "low"),
        CommandPattern(("git", "status"), "vcs_status", "low"),
        CommandPattern(("git", "diff"), "vcs_inspect", "low"),
        CommandPattern(("git", "log"), "vcs_inspect", "low"),
        CommandPattern(("pytest",), "test_run", "low"),
        CommandPattern(("python", "-m", "pytest"), "test_run", "low"),
        CommandPattern(("python3", "-m", "pytest"), "test_run", "low"),
    )
    SYSTEM_PATTERNS = (
        CommandPattern(("apt", "update"), "package_cache_update", "high"),
        CommandPattern(("apt", "install"), "package_install", "high"),
        CommandPattern(("apt", "remove"), "package_remove", "high"),
        CommandPattern(("apt", "upgrade"), "package_upgrade", "high"),
        CommandPattern(("systemctl",), "service_control", "high"),
        CommandPattern(("service",), "service_control", "high"),
    )
    DESTRUCTIVE_COMMANDS = {
        "rm",
        "dd",
        "mkfs",
        "mount",
        "umount",
        "shutdown",
        "reboot",
        "poweroff",
        "su",
    }
    DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r")
    # Options that make an otherwise read-only program delete or write files,
    # or run other programs. A READONLY_PATTERNS match using one of them is
    # not auto-approved.
    READONLY_UNSAFE_OPTIONS = {
        "find": frozenset(
            {"-delete", "-exec", "-execdir", "-ok", "-okdir", "-fls", "-fprint", "-fprint0", "-fprintf"}
        ),
        "git": frozenset({"--output"}),
    }
    # Programs that run whatever is piped into them as shell code.
    SHELL_INTERPRETERS = frozenset({"sh", "bash", "dash", "zsh", "ksh"})
    # Operators that may join read-only commands without losing that status.
    SAFE_CHAIN_OPERATORS = frozenset({";", "|", "||", "&&"})

    # Unquoted characters with shell control meaning: separators, pipes,
    # redirection, background jobs, subshells and line breaks.
    _SHELL_METACHARACTERS = frozenset(";&|<>()\n\r")
    # Characters from which SAFE_CHAIN_OPERATORS are built.
    _SAFE_CHAIN_CHARACTERS = frozenset(";&|")
    # shlex treats backticks as word characters and line breaks as blanks, but
    # a shell starts a new command at both, so map them to ';' before splitting.
    _SEPARATOR_TRANSLATION = str.maketrans({"`": ";", "\n": ";", "\r": ";"})
    # Shell words that may precede the program name of a simple command.
    _COMMAND_PREFIX_WORDS = frozenset({"!", "{", "if", "then", "else", "elif", "do", "while", "until", "time"})
    # sudo short options whose value may be given as the next word ("-u root").
    _SUDO_SHORT_OPTIONS_WITH_ARG = frozenset("aCcDgpRrTtUu")
    # sudo long options whose value may be given as the next word ("--user root").
    _SUDO_LONG_OPTIONS_WITH_ARG = frozenset(
        {
            "--auth-type",
            "--login-class",
            "--close-from",
            "--chdir",
            "--group",
            "--prompt",
            "--chroot",
            "--role",
            "--command-timeout",
            "--type",
            "--other-user",
            "--user",
        }
    )

    @classmethod
    def classify(cls, command: str) -> CommandPolicyResult:
        """Classify *command* and return the policy decision for it.

        Args:
            command: Command line as it would be handed to the shell. Surrounding
                whitespace is ignored.

        Returns:
            A CommandPolicyResult. ``normalized_command`` is the command
            re-quoted with a leading ``sudo`` and its options removed. When
            nothing remains, or the input cannot be parsed, it is the stripped
            input instead.
        """
        command = command.strip()
        parts = cls._split(command)
        normalized_parts = cls._strip_sudo(parts)
        normalized_command = shlex.join(normalized_parts) if normalized_parts else command
        if not command:
            return cls._result(command, normalized_command, "empty", "low", False, False, True, "Empty command")
        if not parts:
            return cls._result(
                command, normalized_command, "invalid", "medium", True, False, False, "Command could not be parsed"
            )

        syntax = cls._shell_syntax(command)
        # A plain command is one segment; only lines containing unquoted shell
        # syntax need the shell-aware split.
        segments = cls._segments(command) if syntax else [("", parts)]
        commands = [cls._strip_sudo(words) for _, words in segments if words]
        operators = [operator for operator, _ in segments[1:]]
        uses_sudo = parts[0] == "sudo" or any(words[0] == "sudo" for _, words in segments if words)
        lowered = command.lower()

        if (
            cls._is_destructive(normalized_parts, lowered)
            or any(cls._is_destructive(words, lowered) for words in commands)
            or cls._pipes_into_shell(segments)
        ):
            return cls._result(
                command,
                normalized_command,
                "destructive",
                "critical",
                False,
                uses_sudo,
                True,
                "Command is blocked by safety policy",
            )

        readonly = cls._readonly_pattern(commands, operators, syntax)
        if readonly:
            return cls._result(
                command,
                normalized_command,
                readonly.action_type,
                "medium" if uses_sudo else readonly.risk_level,
                uses_sudo,  # read-only commands need approval only under sudo
                uses_sudo,
                False,
                "Sudo command requires approval" if uses_sudo else None,
            )

        for words in commands:
            system = cls._match(words, cls.SYSTEM_PATTERNS)
            if system:
                return cls._result(
                    command,
                    normalized_command,
                    system.action_type,
                    system.risk_level,
                    True,
                    uses_sudo,
                    False,
                    "System command requires approval",
                )

        return cls._result(
            command,
            normalized_command,
            "shell_command",
            "medium",
            True,
            uses_sudo,
            False,
            "Command uses shell operators and requires approval"
            if syntax
            else "Command is outside allowlist and requires approval",
        )

    @classmethod
    def _split(cls, command: str) -> list[str]:
        """Split *command* into words with ``shlex.split``; return [] if quoting is invalid."""
        try:
            return shlex.split(command)
        except ValueError:
            return []

    @classmethod
    def _strip_sudo(cls, parts: list[str]) -> list[str]:
        """Return *parts* without a leading ``sudo`` and its options.

        An option value given as a separate word (``sudo -u root cmd``,
        ``sudo --user root cmd``) is skipped together with its option, so the
        result starts at the program sudo would run. Parts that do not start
        with ``sudo`` are returned unchanged.
        """
        if not parts or parts[0] != "sudo":
            return parts
        index = 1
        while index < len(parts) and parts[index].startswith("-"):
            option = parts[index]
            index += 1
            if option == "--":  # explicit end of sudo's own options
                break
            if option.startswith("--"):
                if option in cls._SUDO_LONG_OPTIONS_WITH_ARG:
                    index += 1  # "--user root"; "--user=root" is a single word
                continue
            for position, letter in enumerate(option[1:], start=1):
                if letter == "h":
                    # -h is ambiguous (--help vs. --host); never consume the
                    # next word for it.
                    break
                if letter in cls._SUDO_SHORT_OPTIONS_WITH_ARG:
                    if position == len(option) - 1:
                        index += 1  # value is the next word: "-u root"
                    break  # otherwise the rest of this word is the value: "-uroot"
        return parts[index:]

    @classmethod
    def _shell_syntax(cls, command: str) -> set[str]:
        """Return the unquoted shell control syntax found in *command*.

        The result holds each metacharacter from ``_SHELL_METACHARACTERS``
        that appears outside quotes, plus "`" and "$(" for command
        substitution. Command substitution is also reported inside double
        quotes, where the shell still performs it. Single-quoted text and
        backslash-escaped characters are ignored. An empty set means the
        command is a single simple command.
        """
        found: set[str] = set()
        quote = ""
        escaped = False
        for index, char in enumerate(command):
            if escaped:
                escaped = False
            elif quote == "'":
                if char == "'":
                    quote = ""
            elif char == "\\":
                escaped = True
            elif char == "`":
                found.add(char)
            elif command.startswith("$(", index):
                found.add("$(")
            elif quote == '"':
                if char == '"':
                    quote = ""
            elif char in "'\"":
                quote = char
            elif char in cls._SHELL_METACHARACTERS:
                found.add(char)
        return found

    @classmethod
    def _segments(cls, command: str) -> list[tuple[str, list[str]]]:
        """Split a shell command line into ``(operator, words)`` pairs.

        ``operator`` is the unquoted operator token before the segment, and
        ``""`` for the first segment. For example, ``"ls | sh"`` yields
        ``[("", ["ls"]), ("|", ["sh"])]``. Empty segments are kept so that
        every operator is reported. Returns [] if the line cannot be
        tokenised.
        """
        text = command.translate(cls._SEPARATOR_TRANSLATION)
        lexer = shlex.shlex(text, posix=True, punctuation_chars=True)
        lexer.whitespace_split = True
        lexer.commenters = ""  # never let '#' hide the rest of the line
        operator_chars = set(lexer.punctuation_chars)
        segments: list[tuple[str, list[str]]] = [("", [])]
        try:
            for token in lexer:
                # Quoted empty strings ('') are words, not operators.
                if token and set(token) <= operator_chars:
                    segments.append((token, []))
                else:
                    segments[-1][1].append(token)
        except ValueError:
            return []
        return segments

    @classmethod
    def _is_assignment(cls, word: str) -> bool:
        """Return True for a shell variable assignment word such as ``FOO=bar``."""
        name, equals, _ = word.partition("=")
        return bool(equals) and name.isidentifier()

    @classmethod
    def _command_name(cls, parts: list[str]) -> str | None:
        """Return the lower-cased basename of the program that *parts* runs.

        Leading ``NAME=value`` assignments and shell words such as ``!``,
        ``{`` or ``do`` are skipped, so ``FOO=1 /bin/rm -rf x`` yields
        ``"rm"``. Returns None when no program word is present.
        """
        for word in parts:
            if word in cls._COMMAND_PREFIX_WORDS or cls._is_assignment(word):
                continue
            return word.lower().rsplit("/", 1)[-1]
        return None

    @classmethod
    def _match(
        cls, parts: list[str], patterns: tuple[CommandPattern, ...]
    ) -> CommandPattern | None:
        """Return the first pattern whose words equal the lower-cased prefix of *parts*."""
        lowered = tuple(part.lower() for part in parts)
        for pattern in patterns:
            if lowered[: len(pattern.words)] == pattern.words:
                return pattern
        return None

    @classmethod
    def _has_unsafe_option(cls, parts: list[str]) -> bool:
        """Return True if *parts* passes one of READONLY_UNSAFE_OPTIONS to its program."""
        if not parts:
            return False
        unsafe = cls.READONLY_UNSAFE_OPTIONS.get(parts[0].lower(), frozenset())
        # "--output=file" is judged by its option name only.
        return any(word.split("=", 1)[0] in unsafe for word in parts[1:])

    @classmethod
    def _readonly_pattern(
        cls, commands: list[list[str]], operators: list[str], syntax: set[str]
    ) -> CommandPattern | None:
        """Return the first command's READONLY pattern if the whole line is read-only.

        Every command must match READONLY_PATTERNS without unsafe options.
        Commands may only be joined by SAFE_CHAIN_OPERATORS: no redirection,
        command substitution, subshell or background job.
        """
        if not commands or not syntax.issubset(cls._SAFE_CHAIN_CHARACTERS):
            return None
        if any(operator not in cls.SAFE_CHAIN_OPERATORS for operator in operators):
            return None
        matches = [cls._match(words, cls.READONLY_PATTERNS) for words in commands]
        if not all(matches) or any(cls._has_unsafe_option(words) for words in commands):
            return None
        return matches[0]

    @classmethod
    def _pipes_into_shell(cls, segments: list[tuple[str, list[str]]]) -> bool:
        """Return True if any segment receives piped input and is a shell interpreter."""
        return any(
            "|" in operator and cls._command_name(cls._strip_sudo(words)) in cls.SHELL_INTERPRETERS
            for operator, words in segments
        )

    @classmethod
    def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool:
        """Return True if the command must be blocked outright.

        Args:
            parts: Words of one simple command, with any ``sudo`` prefix
                already removed.
            lowered_command: The whole command line, lower-cased; it is
                searched for ``DANGEROUS_FRAGMENTS``.
        """
        # Collapse whitespace runs so "curl  |  sh" still matches "curl | sh".
        squeezed = " ".join(lowered_command.split())
        if any(fragment in squeezed for fragment in cls.DANGEROUS_FRAGMENTS):
            return True
        name = cls._command_name(parts)
        if name is None:
            return False
        # Also catch family members named "<cmd>.<suffix>", e.g. "mkfs.ext4".
        return name in cls.DESTRUCTIVE_COMMANDS or name.partition(".")[0] in cls.DESTRUCTIVE_COMMANDS

    @classmethod
    def _result(
        cls,
        command: str,
        normalized_command: str,
        action_type: str,
        risk_level: str,
        requires_approval: bool,
        requires_password: bool,
        blocked: bool,
        reason: str | None,
    ) -> CommandPolicyResult:
        """Build a CommandPolicyResult from positional decision fields."""
        return CommandPolicyResult(
            command=command,
            normalized_command=normalized_command,
            action_type=action_type,
            risk_level=risk_level,
            requires_approval=requires_approval,
            requires_password=requires_password,
            blocked=blocked,
            reason=reason,
        )