"""Classify shell command lines by risk before they are executed.

:class:`CommandPolicy` sorts a command line into one of several buckets: an
allowlisted read-only command that may run without approval, an allowlisted
system-administration command that needs approval, a command that is blocked
outright, or anything else (which needs approval by default).
"""

import shlex
from dataclasses import dataclass

from pydantic import BaseModel


@dataclass(frozen=True)
class CommandPattern:
    """A leading-argv prefix and the classification it maps to."""

    # Leading argv words to match, lowercase (e.g. ``("git", "status")``).
    words: tuple[str, ...]
    # Category reported in the result when the pattern matches.
    action_type: str
    # Risk label reported when the pattern matches (e.g. ``"low"``, ``"high"``).
    risk_level: str


class CommandPolicyResult(BaseModel):
    """Outcome of classifying one command line."""

    # The command as given, with surrounding whitespace stripped.
    command: str
    # As set by CommandPolicy.classify: the shell-quoted argv with a leading
    # ``sudo`` and its options removed, or the stripped command itself when it
    # could not be tokenized.
    normalized_command: str
    action_type: str
    # One of "low", "medium", "high", "critical" when set by CommandPolicy.
    risk_level: str
    # Whether a human must approve the command before it runs.
    requires_approval: bool
    # Set by CommandPolicy.classify when the command was invoked via ``sudo``.
    requires_password: bool
    # Whether the command must not run at all.
    blocked: bool = False
    # Human-readable explanation of the decision, if any.
    reason: str | None = None


class CommandPolicy:
    """Rule-based risk classifier for shell command lines.

    Classification order (first match wins):

    1. empty input (blocked) or input that cannot be tokenized (approval);
    2. destructive input -- a command in ``DESTRUCTIVE_COMMANDS`` anywhere in
       the line, a ``DANGEROUS_FRAGMENTS`` substring, or a download piped into
       a shell -- blocked;
    3. ``READONLY_PATTERNS`` -- auto-approved unless run via ``sudo``, combined
       with shell operators, or given a side-effecting argument;
    4. ``SYSTEM_PATTERNS`` -- approval required;
    5. anything else -- medium risk, approval required.

    Destructive detection is a best-effort denylist; the safety net is that
    anything not recognised as plainly read-only requires approval.
    """

    READONLY_PATTERNS = (
        CommandPattern(("pwd",), "working_directory", "low"),
        CommandPattern(("ls",), "directory_list", "low"),
        CommandPattern(("cat",), "file_read", "low"),
        CommandPattern(("head",), "file_read", "low"),
        CommandPattern(("tail",), "file_read", "low"),
        CommandPattern(("grep",), "text_search", "low"),
        CommandPattern(("find",), "file_search", "low"),
        CommandPattern(("apt", "list"), "package_check", "low"),
        CommandPattern(("apt-cache", "policy"), "package_check", "low"),
        CommandPattern(("git", "status"), "vcs_status", "low"),
        CommandPattern(("git", "diff"), "vcs_inspect", "low"),
        CommandPattern(("git", "log"), "vcs_inspect", "low"),
        CommandPattern(("pytest",), "test_run", "low"),
        CommandPattern(("python", "-m", "pytest"), "test_run", "low"),
        CommandPattern(("python3", "-m", "pytest"), "test_run", "low"),
    )
    SYSTEM_PATTERNS = (
        CommandPattern(("apt", "update"), "package_cache_update", "high"),
        CommandPattern(("apt", "install"), "package_install", "high"),
        CommandPattern(("apt", "remove"), "package_remove", "high"),
        CommandPattern(("apt", "upgrade"), "package_upgrade", "high"),
        CommandPattern(("systemctl",), "service_control", "high"),
        CommandPattern(("service",), "service_control", "high"),
    )
    DESTRUCTIVE_COMMANDS = {
        "rm",
        "dd",
        "mkfs",
        "mount",
        "umount",
        "shutdown",
        "reboot",
        "poweroff",
        "su",
    }
    DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r")

    # Arguments that make an otherwise read-only tool write files, delete
    # files, or run other programs; keyed by the command's first word. A flag
    # also matches in its "--flag=value" form.
    READONLY_UNSAFE_ARGUMENTS = {
        "find": frozenset(
            {
                "-delete",
                "-exec",
                "-execdir",
                "-ok",
                "-okdir",
                "-fls",
                "-fprint",
                "-fprint0",
                "-fprintf",
            }
        ),
        "git": frozenset({"--output"}),
    }
    # Destructive tools that ship as families of suffixed binaries (mkfs.ext4).
    DESTRUCTIVE_PREFIXES = ("mkfs.",)
    # Piping a downloader's output into a shell runs unreviewed remote code.
    DOWNLOAD_COMMANDS = frozenset({"curl", "wget"})
    SHELL_INTERPRETERS = frozenset({"sh", "bash", "dash", "zsh", "ksh"})

    # Unquoted characters that chain commands, redirect I/O, or open subshells.
    _SHELL_METACHARACTERS = frozenset(";&|<>()\n")
    # Characters tokenized as operators when splitting compound commands.
    _SHELL_PUNCTUATION = "();<>|&`"
    # sudo options whose value may be given as the following argv word.
    _SUDO_SHORT_OPTIONS_WITH_VALUE = frozenset("aCcDgpRrTtUu")
    _SUDO_LONG_OPTIONS_WITH_VALUE = frozenset(
        {
            "--auth-type",
            "--chdir",
            "--chroot",
            "--close-from",
            "--command-timeout",
            "--group",
            "--login-class",
            "--other-user",
            "--prompt",
            "--role",
            "--type",
            "--user",
        }
    )

    @classmethod
    def classify(cls, command: str) -> CommandPolicyResult:
        """Classify a shell command line and decide how it may be run.

        Args:
            command: The command line to classify; surrounding whitespace is
                ignored.

        Returns:
            A :class:`CommandPolicyResult` with the action type, risk level,
            and whether approval is required, whether ``sudo`` was used, and
            whether the command is blocked.
        """
        command = command.strip()
        if not command:
            return cls._result(command, command, "empty", "low", False, False, True, "Empty command")

        parts = cls._split(command)
        if not parts:
            return cls._result(
                command, command, "invalid", "medium", True, False, False, "Command could not be parsed"
            )

        normalized_parts = cls._strip_sudo(parts)
        normalized_command = shlex.join(normalized_parts) if normalized_parts else command
        uses_sudo = parts[0] == "sudo"

        if cls._is_destructive(normalized_parts, command.lower()):
            return cls._result(
                command,
                normalized_command,
                "destructive",
                "critical",
                False,
                uses_sudo,
                True,
                "Command is blocked by safety policy",
            )

        readonly = cls._match(normalized_parts, cls.READONLY_PATTERNS)
        # A read-only prefix alone is not enough: "ls && other" or
        # "find -delete" would otherwise run unreviewed. Such commands fall
        # through to the approval paths below.
        if readonly and cls._is_readonly_safe(command, normalized_parts):
            return cls._result(
                command,
                normalized_command,
                readonly.action_type,
                "medium" if uses_sudo else readonly.risk_level,
                uses_sudo,  # requires_approval: sudo always needs a human
                uses_sudo,  # requires_password
                False,
                "Sudo command requires approval" if uses_sudo else None,
            )

        system = cls._match(normalized_parts, cls.SYSTEM_PATTERNS)
        if system:
            return cls._result(
                command,
                normalized_command,
                system.action_type,
                system.risk_level,
                True,
                uses_sudo,
                False,
                "System command requires approval",
            )

        return cls._result(
            command,
            normalized_command,
            "shell_command",
            "medium",
            True,
            uses_sudo,
            False,
            "Command is outside allowlist and requires approval",
        )

    @classmethod
    def _split(cls, command: str) -> list[str]:
        """Tokenize *command* like a POSIX shell; return [] if quoting is unbalanced."""
        try:
            return shlex.split(command)
        except ValueError:
            return []

    @classmethod
    def _strip_sudo(cls, parts: list[str]) -> list[str]:
        """Drop a leading ``sudo`` and its options, returning the real argv.

        Options that take a value (``-u root``, ``-uroot``, ``-Eu root``,
        ``--user root``, ``--user=root``) are skipped together with that value,
        and ``--`` ends option parsing. Argv lists that do not start with
        ``sudo`` are returned unchanged.
        """
        if not parts or parts[0] != "sudo":
            return parts
        index = 1
        while index < len(parts) and parts[index].startswith("-"):
            option = parts[index]
            index += 1
            if option == "--":
                break
            if option.startswith("--"):
                if option in cls._SUDO_LONG_OPTIONS_WITH_VALUE:
                    index += 1  # value is the next word: "--user root"
                continue
            flags = option[1:]
            for position, flag in enumerate(flags):
                if flag in cls._SUDO_SHORT_OPTIONS_WITH_VALUE:
                    # The rest of the cluster is the value ("-uroot"); when
                    # nothing is left, the value is the next word ("-u root").
                    if position == len(flags) - 1:
                        index += 1
                    break
        return parts[index:]

    @classmethod
    def _match(
        cls, parts: list[str], patterns: tuple[CommandPattern, ...]
    ) -> CommandPattern | None:
        """Return the first pattern whose words prefix *parts* (case-insensitive)."""
        lowered = tuple(part.lower() for part in parts)
        for pattern in patterns:
            if lowered[: len(pattern.words)] == pattern.words:
                return pattern
        return None

    @classmethod
    def _is_readonly_safe(cls, command: str, parts: list[str]) -> bool:
        """Return True if a read-only allowlist match may skip approval.

        Rejects command lines that use shell operators (which could chain,
        redirect, or feed another program) and read-only tools given one of
        ``READONLY_UNSAFE_ARGUMENTS``.
        """
        if cls._has_shell_syntax(command):
            return False
        unsafe = cls.READONLY_UNSAFE_ARGUMENTS.get(parts[0].lower(), frozenset()) if parts else frozenset()
        return not any(argument.lower().split("=", 1)[0] in unsafe for argument in parts[1:])

    @classmethod
    def _has_shell_syntax(cls, command: str) -> bool:
        """Return True if *command* chains, redirects, or substitutes commands.

        Quote-aware: ``grep "a|b" f`` and ``find . -exec x {} \\;`` are plain,
        while ``ls; id``, ``cat a > b`` and ``ls "$(id)"`` are not (command
        substitution still runs inside double quotes).
        """
        quote: str | None = None
        escaped = False
        for index, char in enumerate(command):
            if escaped:
                escaped = False
            elif quote == "'":
                if char == "'":
                    quote = None
            elif char == "\\":
                escaped = True
            elif char == "`" or command.startswith("$(", index):
                return True
            elif quote == '"':
                if char == '"':
                    quote = None
            elif char in "'\"":
                quote = char
            elif char in cls._SHELL_METACHARACTERS:
                return True
        return False

    @classmethod
    def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool:
        """Return True if the command line must be blocked outright.

        Args:
            parts: argv with any leading ``sudo`` already stripped.
            lowered_command: the full command line, lowercased. It is also split
                on shell operators so every chained, piped, or substituted
                command is checked, not only the first one.
        """
        if any(fragment in lowered_command for fragment in cls.DANGEROUS_FRAGMENTS):
            return True
        if parts and cls._is_destructive_name(cls._command_name(parts)):
            return True
        pipeline_downloads = False  # current pipeline contains curl/wget
        for piped, words in cls._command_segments(lowered_command):
            name = cls._command_name(words)
            if cls._is_destructive_name(name):
                return True
            if piped and pipeline_downloads and name in cls.SHELL_INTERPRETERS:
                return True
            pipeline_downloads = (piped and pipeline_downloads) or name in cls.DOWNLOAD_COMMANDS
        return False

    @classmethod
    def _is_destructive_name(cls, name: str) -> bool:
        """Return True if program basename *name* is on the destructive denylist."""
        return name in cls.DESTRUCTIVE_COMMANDS or name.startswith(cls.DESTRUCTIVE_PREFIXES)

    @classmethod
    def _command_name(cls, words: list[str]) -> str:
        """Return the lowercase basename of the program *words* would run.

        Leading ``NAME=value`` assignments and a ``sudo`` prefix are skipped, so
        ``FOO=1 sudo /bin/rm x`` yields ``"rm"``. Returns "" if nothing is left.
        """
        words = cls._strip_assignments(cls._strip_sudo(cls._strip_assignments(words)))
        return words[0].lower().rsplit("/", 1)[-1] if words else ""

    @staticmethod
    def _strip_assignments(words: list[str]) -> list[str]:
        """Drop leading ``NAME=value`` environment assignments from *words*."""
        index = 0
        while index < len(words):
            name, separator, _ = words[index].partition("=")
            if not (separator and name.isidentifier()):
                break
            index += 1
        return words[index:]

    @classmethod
    def _command_segments(cls, command: str) -> list[tuple[bool, list[str]]]:
        """Split *command* into the simple commands it strings together.

        Returns ``(piped, words)`` pairs, where *piped* is True when the segment
        reads the previous segment's output via ``|`` or ``|&``. Redirection
        targets are dropped. Best effort: operators inside ``"$(...)"`` are not
        seen, and a quoted argument made only of operator characters is treated
        as an operator. Returns [] if the command cannot be tokenized.
        """
        lexer = shlex.shlex(
            command.replace("\n", ";"),  # shlex would treat newlines as blanks
            posix=True,
            punctuation_chars=cls._SHELL_PUNCTUATION,
        )
        lexer.whitespace_split = True
        lexer.commenters = ""  # same as shlex.split(): "#" is an ordinary char
        segments: list[tuple[bool, list[str]]] = [(False, [])]
        skip_redirect_target = False
        try:
            for token in lexer:
                if token and not token.strip(cls._SHELL_PUNCTUATION):
                    # Operator token. Pure "<", ">" and "&" runs containing a
                    # "<" or ">" are redirections ("2>&1", "&>"); everything
                    # else (";", "&&", "|", "(", "`", "<(") starts a command.
                    if not token.strip("<>&") and ("<" in token or ">" in token):
                        skip_redirect_target = True
                    else:
                        skip_redirect_target = False
                        segments.append((token in ("|", "|&"), []))
                    continue
                if skip_redirect_target:
                    skip_redirect_target = False
                    continue
                segments[-1][1].append(token)
        except ValueError:
            return []
        return [segment for segment in segments if segment[1]]

    @classmethod
    def _result(
        cls,
        command: str,
        normalized_command: str,
        action_type: str,
        risk_level: str,
        requires_approval: bool,
        requires_password: bool,
        blocked: bool,
        reason: str | None,
    ) -> CommandPolicyResult:
        """Build a :class:`CommandPolicyResult` from positional classification fields."""
        return CommandPolicyResult(
            command=command,
            normalized_command=normalized_command,
            action_type=action_type,
            risk_level=risk_level,
            requires_approval=requires_approval,
            requires_password=requires_password,
            blocked=blocked,
            reason=reason,
        )