181 lines
5.9 KiB
Python
181 lines
5.9 KiB
Python
import shlex
|
|
from dataclasses import dataclass
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
@dataclass(frozen=True)
class CommandPattern:
    """Immutable rule that maps a leading run of command words to a classification.

    Frozen (and therefore hashable), so instances can safely be shared as
    class-level constants.
    """

    # Leading words a command must start with, e.g. ("git", "status").
    words: tuple[str, ...]
    # Category label reported for a matching command, e.g. "vcs_status".
    action_type: str
    # Risk tier reported for a matching command, e.g. "low" or "high".
    risk_level: str
|
|
|
|
|
|
class CommandPolicyResult(BaseModel):
    """Outcome of classifying one shell command.

    Holds the command as given, a normalized form of it, its classification
    (action type and risk tier) and the flags that gate execution.
    """

    # Command text as handed to the classifier.
    command: str
    # Normalized form of the command (e.g. without a leading sudo prefix).
    normalized_command: str
    # Category label such as "file_read" or "package_install".
    action_type: str
    # Risk tier such as "low", "medium", "high" or "critical".
    risk_level: str
    # Whether the command needs approval before it runs.
    requires_approval: bool
    # Whether running the command needs a password (e.g. because of sudo).
    requires_password: bool
    # Whether the command is refused outright.
    blocked: bool = False
    # Optional human-readable explanation of the decision.
    reason: str | None = None
|
|
|
|
|
|
class CommandPolicy:
    """Classify shell command strings into risk tiers for an approval workflow.

    ``classify`` returns a ``CommandPolicyResult`` that falls in one of four groups:

    * blocked (``critical``) when any command in the string is destructive, or
      when a download is piped straight into a shell;
    * allow-listed read-only commands (``low``; ``medium`` plus approval under
      sudo), but only when *every* command of a chain/pipeline is allow-listed
      and there is no output redirection or command substitution;
    * allow-listed system-administration commands, which always need approval;
    * everything else, which needs approval.

    Shell syntax is analysed with :mod:`shlex`, not a full shell parser. The
    analysis is deliberately conservative: quoted or escaped operator
    characters (for example a quoted ``"|"``) are treated as real operators,
    which can only make a classification stricter.
    """

    READONLY_PATTERNS = (
        CommandPattern(("pwd",), "working_directory", "low"),
        CommandPattern(("ls",), "directory_list", "low"),
        CommandPattern(("cat",), "file_read", "low"),
        CommandPattern(("head",), "file_read", "low"),
        CommandPattern(("tail",), "file_read", "low"),
        CommandPattern(("grep",), "text_search", "low"),
        CommandPattern(("find",), "file_search", "low"),
        CommandPattern(("apt", "list"), "package_check", "low"),
        CommandPattern(("apt-cache", "policy"), "package_check", "low"),
        CommandPattern(("git", "status"), "vcs_status", "low"),
        CommandPattern(("git", "diff"), "vcs_inspect", "low"),
        CommandPattern(("git", "log"), "vcs_inspect", "low"),
        CommandPattern(("pytest",), "test_run", "low"),
        CommandPattern(("python", "-m", "pytest"), "test_run", "low"),
        CommandPattern(("python3", "-m", "pytest"), "test_run", "low"),
    )

    SYSTEM_PATTERNS = (
        CommandPattern(("apt", "update"), "package_cache_update", "high"),
        CommandPattern(("apt", "install"), "package_install", "high"),
        CommandPattern(("apt", "remove"), "package_remove", "high"),
        CommandPattern(("apt", "upgrade"), "package_upgrade", "high"),
        CommandPattern(("systemctl",), "service_control", "high"),
        CommandPattern(("service",), "service_control", "high"),
    )

    # Programs refused outright, wherever they appear in a chain or pipeline.
    DESTRUCTIVE_COMMANDS = {
        "rm",
        "dd",
        "mkfs",
        "mount",
        "umount",
        "shutdown",
        "reboot",
        "poweroff",
        "su",
    }

    # Program-name prefixes refused outright (``mkfs.ext4``, ``mkfs.xfs``, ...).
    DESTRUCTIVE_PREFIXES = ("mkfs.",)

    # Literal substrings, matched against the lower-cased command, that are refused.
    DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r")

    # Piping a downloader's output into a shell executes untrusted remote code.
    DOWNLOAD_COMMANDS = frozenset({"curl", "wget"})
    SHELL_INTERPRETERS = frozenset({"bash", "dash", "fish", "ksh", "sh", "zsh"})

    # Prefix words that execute the command following them (``env rm ...``).
    COMMAND_WRAPPERS = frozenset({"builtin", "command", "env", "exec", "nohup", "time"})

    # Argument prefixes that make an allow-listed read-only program write files
    # or execute other programs, so the command no longer counts as read-only.
    READONLY_UNSAFE_ARGUMENTS = {
        "find": ("-delete", "-exec", "-ok", "-fprint", "-fls"),
        "git": ("--output",),
    }

    # sudo options whose value may be the *next* word. Without this,
    # ``sudo -u root rm`` would resolve to a program named ``root``.
    _SUDO_SHORT_OPTIONS_WITH_ARGUMENT = frozenset("aCcDgpRrTtUu")
    _SUDO_LONG_OPTIONS_WITH_ARGUMENT = frozenset(
        {
            "--auth-type",
            "--chdir",
            "--chroot",
            "--close-from",
            "--command-timeout",
            "--group",
            "--login-class",
            "--other-user",
            "--prompt",
            "--role",
            "--type",
            "--user",
        }
    )

    # Characters the analysis lexer emits as standalone punctuation tokens.
    _SHELL_PUNCTUATION = "();<>|&"
    # Shell operators built from those characters, longest first so that a
    # greedy scan of a punctuation run picks the longest operator.
    _SHELL_OPERATORS = (
        "&>>", ";;&", "<<<",
        "&&", "||", ";;", ";&", "|&", ">>", ">|", ">&", "&>", "<&", "<<", "<>",
        "&", "|", ";", "(", ")", "<", ">",
    )
    # Operators that keep the previous command's output flowing into the next
    # one ("" marks the first segment; parentheses only group commands).
    _PIPELINE_OPERATORS = frozenset({"", "|", "|&", "(", ")"})

    @classmethod
    def classify(cls, command: str) -> CommandPolicyResult:
        """Classify *command* and report whether it may run, needs approval, or is blocked.

        Args:
            command: Shell command line; surrounding whitespace is ignored.

        Returns:
            A ``CommandPolicyResult``. Checks run in this order: empty input,
            unparseable input, destructive (blocked), read-only allow-list,
            system allow-list, then the approval-required fallback.
        """
        command = command.strip()
        parts = cls._split(command)
        normalized_parts = cls._strip_sudo(parts)
        normalized_command = shlex.join(normalized_parts) if normalized_parts else command
        lowered = command.lower()

        if not command:
            return cls._result(command, normalized_command, "empty", "low", False, False, True, "Empty command")

        if not parts:
            return cls._result(
                command,
                normalized_command,
                "invalid",
                "medium",
                True,
                False,
                False,
                "Command could not be parsed",
            )

        segments, writes_output = cls._segments(command)
        # sudo on any command of a chain or pipeline (``ls | sudo tee f``) will prompt.
        uses_sudo = any(words[:1] == ["sudo"] for _, words in segments)

        if cls._is_destructive(normalized_parts, lowered):
            return cls._result(
                command,
                normalized_command,
                "destructive",
                "critical",
                False,
                uses_sudo,
                True,
                "Command is blocked by safety policy",
            )

        readonly = cls._readonly_pattern(command, segments, writes_output)
        if readonly:
            return cls._result(
                command,
                normalized_command,
                readonly.action_type,
                "medium" if uses_sudo else readonly.risk_level,
                uses_sudo,
                uses_sudo,
                False,
                "Sudo command requires approval" if uses_sudo else None,
            )

        system = cls._match(normalized_parts, cls.SYSTEM_PATTERNS)
        if system:
            return cls._result(
                command,
                normalized_command,
                system.action_type,
                system.risk_level,
                True,
                uses_sudo,
                False,
                "System command requires approval",
            )

        return cls._result(
            command,
            normalized_command,
            "shell_command",
            "medium",
            True,
            uses_sudo,
            False,
            "Command is outside allowlist and requires approval",
        )

    @classmethod
    def _split(cls, command: str) -> list[str]:
        """Split *command* POSIX-style; return ``[]`` when quoting is unbalanced."""
        try:
            return shlex.split(command)
        except ValueError:
            return []

    @classmethod
    def _strip_sudo(cls, parts: list[str]) -> list[str]:
        """Drop a leading ``sudo`` and its options, returning the command sudo would run.

        Options that take a value consume it: ``sudo -u root ls`` yields
        ``["ls"]``. That covers a separate word, an attached value
        (``-uroot``), a clustered flag (``-iu root``) and ``--user root`` /
        ``--user=root``. ``--`` ends option processing. Word lists that do not
        start with ``sudo`` are returned unchanged.
        """
        if not parts or parts[0] != "sudo":
            return parts
        index = 1
        while index < len(parts):
            option = parts[index]
            if option == "--":
                index += 1
                break
            if not option.startswith("-"):
                break
            index += 1
            if option.startswith("--"):
                if option in cls._SUDO_LONG_OPTIONS_WITH_ARGUMENT:
                    index += 1  # ``--user root``; ``--user=root`` is a single word
                continue
            # getopt-style cluster: the first letter that takes a value uses the
            # rest of the word, or the next word when it is the last letter.
            for position, letter in enumerate(option[1:], start=1):
                if letter in cls._SUDO_SHORT_OPTIONS_WITH_ARGUMENT:
                    if position == len(option) - 1:
                        index += 1
                    break
        return parts[index:]

    @classmethod
    def _segments(cls, command: str) -> tuple[list[tuple[str, list[str]]], bool]:
        """Split *command* into simple commands at unquoted shell control operators.

        Returns ``(segments, writes_output)``.

        * Each segment is ``(operator, words)``. *operator* is the control
          operator that introduced the segment (``""`` for the first one).
          *words* are the segment's words with redirection operators and their
          targets removed. Segments may be empty, e.g. around parentheses.
        * *writes_output* is True when some redirection may write to a file.

        If the text cannot be lexed, one segment built from ``_split`` is
        returned together with ``writes_output=True``, so callers stay
        conservative.
        """
        # Unquoted newlines separate commands like ";"; plain shlex would treat them as spaces.
        text = command.replace("\r", ";").replace("\n", ";")
        lexer = shlex.shlex(text, posix=True, punctuation_chars=cls._SHELL_PUNCTUATION)
        lexer.whitespace_split = True
        lexer.commenters = ""  # same as shlex.split: "#" does not start a comment
        try:
            tokens = list(lexer)
        except ValueError:
            return [("", cls._split(command))], True

        segments: list[tuple[str, list[str]]] = [("", [])]
        writes_output = False
        pending_redirect: str | None = None
        for token in tokens:
            if not token or any(char not in cls._SHELL_PUNCTUATION for char in token):
                if pending_redirect is None:
                    segments[-1][1].append(token)
                else:
                    # This word is the redirection target, not an argument.
                    writes_output |= cls._redirect_writes(pending_redirect, token)
                    pending_redirect = None
                continue
            # The lexer merges adjacent punctuation (";>", ")|"), so split it into operators.
            for operator in cls._split_operators(token):
                if pending_redirect is not None:
                    writes_output = True  # redirection without a target word: assume the worst
                    pending_redirect = None
                if "<" in operator or ">" in operator:
                    pending_redirect = operator
                else:
                    segments.append((operator, []))
        if pending_redirect is not None:
            writes_output = True
        return segments, writes_output

    @classmethod
    def _split_operators(cls, run: str) -> list[str]:
        """Greedily split a run of punctuation characters into shell operators."""
        operators = []
        position = 0
        while position < len(run):
            operator = next(
                (op for op in cls._SHELL_OPERATORS if run.startswith(op, position)),
                run[position],
            )
            operators.append(operator)
            position += len(operator)
        return operators

    @classmethod
    def _redirect_writes(cls, operator: str, target: str) -> bool:
        """Return True if redirecting with *operator* to *target* may write to a file."""
        if target == "/dev/null":
            return False
        if operator.endswith("&") and (target.isdigit() or target == "-"):
            return False  # duplicating or closing a descriptor (2>&1, >&-), not a file
        return ">" in operator

    @classmethod
    def _is_assignment(cls, word: str) -> bool:
        """Return True for a variable-assignment prefix such as ``LANG=C``."""
        name, separator, _ = word.partition("=")
        return bool(separator) and name.isidentifier()

    @classmethod
    def _program_name(cls, words: list[str]) -> str | None:
        """Return the lower-cased basename of the program *words* would run, or None.

        The following are skipped before the program name is taken:

        * leading ``NAME=value`` assignments;
        * ``sudo`` together with its options;
        * the wrappers in ``COMMAND_WRAPPERS`` together with their dash-options.

        For example, ``X=1 env /bin/rm`` resolves to ``"rm"``.
        """
        remaining = list(words)
        while remaining:
            head = remaining[0]
            if cls._is_assignment(head):
                remaining = remaining[1:]
                continue
            # A leading backtick starts command substitution: `rm -rf /`.
            name = head.lstrip("`").rsplit("/", 1)[-1].lower()
            if name == "sudo":
                remaining = cls._strip_sudo(["sudo", *remaining[1:]])
            elif name in cls.COMMAND_WRAPPERS:
                remaining = remaining[1:]
                while remaining and remaining[0].startswith("-"):
                    remaining = remaining[1:]
            else:
                return name
        return None

    @classmethod
    def _match(
        cls, parts: list[str], patterns: tuple[CommandPattern, ...]
    ) -> CommandPattern | None:
        """Return the first pattern whose words prefix *parts*, compared case-insensitively."""
        lowered = tuple(part.lower() for part in parts)
        for pattern in patterns:
            if lowered[: len(pattern.words)] == pattern.words:
                return pattern
        return None

    @classmethod
    def _readonly_arguments_safe(cls, words: list[str]) -> bool:
        """Return False if *words* use an argument listed in ``READONLY_UNSAFE_ARGUMENTS``."""
        unsafe = cls.READONLY_UNSAFE_ARGUMENTS.get(words[0].lower(), ())
        return not any(word.lower().startswith(unsafe) for word in words[1:])

    @classmethod
    def _readonly_pattern(
        cls,
        command: str,
        segments: list[tuple[str, list[str]]],
        writes_output: bool,
    ) -> CommandPattern | None:
        """Return the first command's read-only pattern if every command is read-only.

        Returns None in any of these cases:

        * a command in the chain or pipeline is not allow-listed;
        * a command uses an argument that writes or executes (``find -exec``);
        * output is redirected to a file;
        * command substitution could run arbitrary code.
        """
        if writes_output or "`" in command or "$(" in command:
            return None
        first = None
        for _, words in segments:
            if not words:
                continue  # empty slots around parentheses or trailing operators
            program_words = cls._strip_sudo(words)
            pattern = cls._match(program_words, cls.READONLY_PATTERNS)
            if pattern is None or not cls._readonly_arguments_safe(program_words):
                return None
            first = first or pattern
        return first

    @classmethod
    def _is_destructive_program(cls, name: str | None) -> bool:
        """Return True if program *name* is refused outright."""
        return name is not None and (
            name in cls.DESTRUCTIVE_COMMANDS or name.startswith(cls.DESTRUCTIVE_PREFIXES)
        )

    @classmethod
    def _pipes_download_into_shell(cls, segments: list[tuple[str, list[str]]]) -> bool:
        """Return True if a downloader's output is piped (possibly via others) into a shell."""
        downloading = False
        for operator, words in segments:
            if operator not in cls._PIPELINE_OPERATORS:
                downloading = False  # ";", "&&", "||", "&" start a new pipeline
            name = cls._program_name(words)
            if downloading and name in cls.SHELL_INTERPRETERS:
                return True
            if name in cls.DOWNLOAD_COMMANDS:
                downloading = True
        return False

    @classmethod
    def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool:
        """Return True if the command must be refused.

        The command is refused when any of these holds:

        * it contains one of the literal ``DANGEROUS_FRAGMENTS``;
        * the program run by *parts* is destructive;
        * the program run by any command chained, piped or substituted within
          *lowered_command* is destructive;
        * a downloader's output is piped into a shell.
        """
        if any(fragment in lowered_command for fragment in cls.DANGEROUS_FRAGMENTS):
            return True
        segments, _ = cls._segments(lowered_command)
        candidates = [parts, *(words for _, words in segments)]
        if any(cls._is_destructive_program(cls._program_name(words)) for words in candidates):
            return True
        return cls._pipes_download_into_shell(segments)

    @classmethod
    def _result(
        cls,
        command: str,
        normalized_command: str,
        action_type: str,
        risk_level: str,
        requires_approval: bool,
        requires_password: bool,
        blocked: bool,
        reason: str | None,
    ) -> CommandPolicyResult:
        """Build a ``CommandPolicyResult`` from positional classification fields."""
        return CommandPolicyResult(
            command=command,
            normalized_command=normalized_command,
            action_type=action_type,
            risk_level=risk_level,
            requires_approval=requires_approval,
            requires_password=requires_password,
            blocked=blocked,
            reason=reason,
        )
|