From 61cdd8bbafd7c835f14ad71fd1f6dfece1ea1eae Mon Sep 17 00:00:00 2001 From: mirivlad Date: Wed, 20 May 2026 23:03:59 +0800 Subject: [PATCH] Add command policy audit trail --- duck_core/api.py | 26 ++++ duck_core/events/store.py | 26 ++++ duck_core/runtime_loop.py | 37 ++++++ duck_core/tools/command_policy.py | 180 ++++++++++++++++++++++++++++ duck_core/tools/shell_exec_safe.py | 33 +++-- tests/smoke/test_api_stream_chat.py | 54 +++++++++ tests/smoke/test_tool_gateway.py | 20 ++++ 7 files changed, 364 insertions(+), 12 deletions(-) create mode 100644 duck_core/tools/command_policy.py diff --git a/duck_core/api.py b/duck_core/api.py index a20f94e..dfbd252 100644 --- a/duck_core/api.py +++ b/duck_core/api.py @@ -344,6 +344,12 @@ def create_app() -> FastAPI: "reasoning_content": reasoning_content, }, ) + except asyncio.CancelledError: + await task_store.cancel_task(task.task_id) + await event_store.append( + task.task_id, "task_cancelled", {"reason": "client_disconnected"} + ) + raise except Exception as exc: await task_store.fail_task(task.task_id, str(exc)) await event_store.append(task.task_id, "task_failed", {"error": str(exc)}) @@ -379,6 +385,14 @@ def create_app() -> FastAPI: async def get_events(task_id: str) -> list[dict[str, Any]]: return [event.model_dump() for event in await event_store.list_events(task_id)] + @app.get("/v1/audit/commands") + async def command_audit(limit: int = 100) -> list[dict[str, Any]]: + bounded_limit = min(max(limit, 1), 500) + return [ + event.model_dump() + for event in await event_store.list_by_type("command_audit", bounded_limit) + ] + @app.get("/v1/tasks/{task_id}/stream") async def stream_events(task_id: str) -> StreamingResponse: async def generator(): @@ -553,6 +567,12 @@ def create_app() -> FastAPI: "reasoning_content": reasoning_content, }, ) + except asyncio.CancelledError: + await task_store.cancel_task(task_id) + await event_store.append( + task_id, "task_cancelled", {"reason": "client_disconnected"} + ) + raise except Exception as exc: await task_store.fail_task(task_id, str(exc)) await event_store.append(task_id, "task_failed", {"error": str(exc)}) @@ -709,6 +729,12 @@ def create_app() -> FastAPI: "reasoning_content": reasoning_content, }, ) + except asyncio.CancelledError: + await task_store.cancel_task(task_id) + await event_store.append( + task_id, "task_cancelled", {"reason": "client_disconnected"} + ) + raise except Exception as exc: await task_store.fail_task(task_id, str(exc)) await event_store.append(task_id, "task_failed", {"error": str(exc)}) diff --git a/duck_core/events/store.py b/duck_core/events/store.py index a9f8bb3..8de8b83 100644 --- a/duck_core/events/store.py +++ b/duck_core/events/store.py @@ -90,3 +90,29 @@ class EventStore: ) for row in rows ] + + async def list_by_type(self, event_type: str, limit: int = 100) -> list[Event]: + await self.init() + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + """ + select * from events + where event_type = ? + order by id desc + limit ? + """, + (event_type, limit), + ) + rows = await cursor.fetchall() + return [ + Event( + id=row["id"], + task_id=row["task_id"], + sequence=row["sequence"], + event_type=row["event_type"], + payload=json.loads(row["payload_json"]), + created_at=row["created_at"], + ) + for row in rows + ] diff --git a/duck_core/runtime_loop.py b/duck_core/runtime_loop.py index 43c8f1b..8487e07 100644 --- a/duck_core/runtime_loop.py +++ b/duck_core/runtime_loop.py @@ -267,6 +267,9 @@ class RuntimeLoop: result = await gateway.run_action(action, approved=True, password=password) result_payload = result.model_dump() + await self._append_command_audit( + task_id, index, tool_name, action, result_payload, approved=decision != "deny" + ) await self.event_store.append( task_id, "tool_call_finished", @@ -334,6 +337,9 @@ class RuntimeLoop: ) result = await gateway.run_action(action, approved=approved_forever) result_payload = result.model_dump() + await self._append_command_audit( + task_id, index, tool_name, action, result_payload, approved=approved_forever + ) if result.metadata.get("requires_approval"): approval = None if self.approval_service is not None: @@ -375,6 +381,37 @@ class RuntimeLoop: ) return observations + async def _append_command_audit( + self, + task_id: str, + index: int, + tool_name: str, + action: dict[str, Any], + result_payload: dict[str, Any], + approved: bool, + ) -> None: + if tool_name != "shell_exec_safe": + return + metadata = result_payload.get("metadata") or {} + command = metadata.get("command") or (action.get("args") or {}).get("command") + await self.event_store.append( + task_id, + "command_audit", + { + "index": index, + "tool": tool_name, + "command": command, + "action_type": metadata.get("action_type"), + "risk_level": metadata.get("risk_level"), + "requires_approval": bool(metadata.get("requires_approval")), + "blocked": bool(metadata.get("blocked")), + "approved": approved, + "ok": bool(result_payload.get("ok")), + "returncode": metadata.get("returncode"), + "reason": metadata.get("reason") or result_payload.get("error"), + }, + ) + async def _run_action_loop( self, task_id: str, diff --git a/duck_core/tools/command_policy.py b/duck_core/tools/command_policy.py new file mode 100644 index 0000000..55f8876 --- /dev/null +++ b/duck_core/tools/command_policy.py @@ -0,0 +1,180 @@ +import shlex +from dataclasses import dataclass + +from pydantic import BaseModel + + +@dataclass(frozen=True) +class CommandPattern: + words: tuple[str, ...] + action_type: str + risk_level: str + + +class CommandPolicyResult(BaseModel): + command: str + normalized_command: str + action_type: str + risk_level: str + requires_approval: bool + requires_password: bool + blocked: bool = False + reason: str | None = None + + +class CommandPolicy: + READONLY_PATTERNS = ( + CommandPattern(("pwd",), "working_directory", "low"), + CommandPattern(("ls",), "directory_list", "low"), + CommandPattern(("cat",), "file_read", "low"), + CommandPattern(("head",), "file_read", "low"), + CommandPattern(("tail",), "file_read", "low"), + CommandPattern(("grep",), "text_search", "low"), + CommandPattern(("find",), "file_search", "low"), + CommandPattern(("apt", "list"), "package_check", "low"), + CommandPattern(("apt-cache", "policy"), "package_check", "low"), + CommandPattern(("git", "status"), "vcs_status", "low"), + CommandPattern(("git", "diff"), "vcs_inspect", "low"), + CommandPattern(("git", "log"), "vcs_inspect", "low"), + CommandPattern(("pytest",), "test_run", "low"), + CommandPattern(("python", "-m", "pytest"), "test_run", "low"), + CommandPattern(("python3", "-m", "pytest"), "test_run", "low"), + ) + SYSTEM_PATTERNS = ( + CommandPattern(("apt", "update"), "package_cache_update", "high"), + CommandPattern(("apt", "install"), "package_install", "high"), + CommandPattern(("apt", "remove"), "package_remove", "high"), + CommandPattern(("apt", "upgrade"), "package_upgrade", "high"), + CommandPattern(("systemctl",), "service_control", "high"), + CommandPattern(("service",), "service_control", "high"), + ) + DESTRUCTIVE_COMMANDS = { + "rm", + "dd", + "mkfs", + "mount", + "umount", + "shutdown", + "reboot", + "poweroff", + "su", + } + DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r") + + @classmethod + def classify(cls, command: str) -> CommandPolicyResult: + command = command.strip() + parts = cls._split(command) + normalized_parts = cls._strip_sudo(parts) + normalized_command = shlex.join(normalized_parts) if normalized_parts else command + uses_sudo = bool(parts) and parts[0] == "sudo" + lowered = command.lower() + + if not command: + return cls._result(command, normalized_command, "empty", "low", False, False, True, "Empty command") + if not parts: + return cls._result(command, normalized_command, "invalid", "medium", True, False, False, "Command could not be parsed") + if cls._is_destructive(normalized_parts, lowered): + return cls._result( + command, + normalized_command, + "destructive", + "critical", + False, + uses_sudo, + True, + "Command is blocked by safety policy", + ) + + readonly = cls._match(normalized_parts, cls.READONLY_PATTERNS) + if readonly: + requires_approval = uses_sudo + return cls._result( + command, + normalized_command, + readonly.action_type, + readonly.risk_level if not uses_sudo else "medium", + requires_approval, + uses_sudo, + False, + "Sudo command requires approval" if uses_sudo else None, + ) + + system = cls._match(normalized_parts, cls.SYSTEM_PATTERNS) + if system: + return cls._result( + command, + normalized_command, + system.action_type, + system.risk_level, + True, + uses_sudo, + False, + "System command requires approval", + ) + + return cls._result( + command, + normalized_command, + "shell_command", + "medium", + True, + uses_sudo, + False, + "Command is outside allowlist and requires approval", + ) + + @classmethod + def _split(cls, command: str) -> list[str]: + try: + return shlex.split(command) + except ValueError: + return [] + + @classmethod + def _strip_sudo(cls, parts: list[str]) -> list[str]: + if parts and parts[0] == "sudo": + stripped = parts[1:] + while stripped and stripped[0].startswith("-"): + stripped = stripped[1:] + return stripped + return parts + + @classmethod + def _match( + cls, parts: list[str], patterns: tuple[CommandPattern, ...] + ) -> CommandPattern | None: + lowered = tuple(part.lower() for part in parts) + for pattern in patterns: + if lowered[: len(pattern.words)] == pattern.words: + return pattern + return None + + @classmethod + def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool: + if any(fragment in lowered_command for fragment in cls.DANGEROUS_FRAGMENTS): + return True + return bool(parts) and parts[0].lower() in cls.DESTRUCTIVE_COMMANDS + + @classmethod + def _result( + cls, + command: str, + normalized_command: str, + action_type: str, + risk_level: str, + requires_approval: bool, + requires_password: bool, + blocked: bool, + reason: str | None, + ) -> CommandPolicyResult: + return CommandPolicyResult( + command=command, + normalized_command=normalized_command, + action_type=action_type, + risk_level=risk_level, + requires_approval=requires_approval, + requires_password=requires_password, + blocked=blocked, + reason=reason, + ) diff --git a/duck_core/tools/shell_exec_safe.py b/duck_core/tools/shell_exec_safe.py index 0664dce..8b4509c 100644 --- a/duck_core/tools/shell_exec_safe.py +++ b/duck_core/tools/shell_exec_safe.py @@ -3,6 +3,7 @@ import subprocess from typing import Any from duck_core.tools.base import ToolResult +from duck_core.tools.command_policy import CommandPolicy ALLOWLIST = { @@ -60,15 +61,19 @@ class ShellExecSafeTool: command = str(args.get("command", "")).strip() approved = bool(args.get("_approved")) password = args.get("_password") + policy = CommandPolicy.classify(command) allowed, reason, blocked = self._is_allowed(command, approved=approved) if not allowed: - metadata = {"blocked": True} if blocked else {"requires_approval": True} + metadata = { + **policy.model_dump(), + **({"blocked": True} if blocked else {"requires_approval": True}), + } return ToolResult(ok=False, error=reason, metadata=metadata) if self._is_sudo_command(command) and not password: return ToolResult( ok=False, error="Sudo password is required to run this command.", - metadata={"requires_password": True}, + metadata={**policy.model_dump(), "requires_password": True}, ) run_command = self._sudo_stdin_command(command) if self._is_sudo_command(command) else command input_text = f"{password}\n" if self._is_sudo_command(command) else None @@ -89,23 +94,27 @@ class ShellExecSafeTool: ok=completed.returncode == 0, output=completed.stdout, error=completed.stderr if completed.returncode else None, - metadata={"returncode": completed.returncode, "command": command}, + metadata={ + **{ + key: value + for key, value in policy.model_dump().items() + if key not in {"requires_approval", "requires_password"} + }, + "returncode": completed.returncode, + }, ) def _is_allowed( self, command: str, approved: bool = False ) -> tuple[bool, str | None, bool]: - if not command: - return False, "Empty command", False - lowered = command.lower() - parts = shlex.split(command) - for blocked in BLOCKLIST: - if self._matches_blocked_command(lowered, parts, blocked): - return False, f"Command is blocked: {blocked}", True + policy = CommandPolicy.classify(command) + if policy.blocked: + return False, policy.reason, True if approved: return True, None, False - if self._is_sudo_command(command): - return False, "Sudo command requires approval.", False + if policy.requires_approval: + return False, policy.reason, False + parts = shlex.split(command) prefix1 = parts[0] if parts else "" prefix2 = " ".join(parts[:2]) prefix3 = " ".join(parts[:3]) diff --git a/tests/smoke/test_api_stream_chat.py b/tests/smoke/test_api_stream_chat.py index 0a1921f..8b3e42e 100644 --- a/tests/smoke/test_api_stream_chat.py +++ b/tests/smoke/test_api_stream_chat.py @@ -323,3 +323,57 @@ def test_password_stream_runs_sudo_with_password_and_streams_answer(tmp_path, mo assert "event: content_delta" in body assert "sudo command completed" in body assert "secret" not in body + + +def test_command_audit_endpoint_exposes_redacted_shell_events(tmp_path, monkeypatch): + monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3")) + + async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None): + assert role == "action" + actions = [] + if not any("tool_observations" in message["content"] for message in messages): + actions = [ + { + "tool": "shell_exec_safe", + "args": {"command": "apt list --upgradable"}, + "reason": "Check available updates", + } + ] + return ModelResponse( + role=role, + model="local-main", + content=json.dumps( + { + "kind": "action_directive", + "intent": "check updates", + "risk_level": "low", + "actions": actions, + } + ), + reasoning_content=None, + raw={}, + latency_ms=1.0, + ) + + async def fake_stream_chat(self, role, messages): + yield {"type": "content_delta", "delta": "updates checked"} + + monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat) + monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat) + client = TestClient(create_app()) + + with client.stream( + "POST", + "/v1/chat/stream", + json={"message": "check updates", "workspace": str(tmp_path), "debug": True}, + ) as response: + _ = "".join(response.iter_text()) + audit = client.get("/v1/audit/commands").json() + + assert response.status_code == 200 + assert audit[0]["event_type"] == "command_audit" + assert audit[0]["payload"]["command"] == "apt list --upgradable" + assert audit[0]["payload"]["risk_level"] == "low" + assert audit[0]["payload"]["action_type"] == "package_check" + assert audit[0]["payload"]["approved"] is False + assert "password" not in json.dumps(audit).lower() diff --git a/tests/smoke/test_tool_gateway.py b/tests/smoke/test_tool_gateway.py index ddb030e..13a1331 100644 --- a/tests/smoke/test_tool_gateway.py +++ b/tests/smoke/test_tool_gateway.py @@ -3,6 +3,7 @@ import pytest from duck_core.tools.file_read import FileReadTool from duck_core.tools.file_write import FileWriteTool from duck_core.tools.gateway import ToolGateway +from duck_core.tools.command_policy import CommandPolicy from duck_core.tools.shell_exec_safe import ShellExecSafeTool @@ -139,3 +140,22 @@ async def test_shell_tool_allows_read_only_apt_update_check(monkeypatch, tmp_pat assert result.ok is True assert "bootlogd" in result.output assert result.metadata["command"] == "apt list --upgradable" + + +def test_command_policy_classifies_common_system_commands(): + readonly = CommandPolicy.classify("apt list --upgradable") + sudo_update = CommandPolicy.classify("sudo apt update") + install = CommandPolicy.classify("sudo apt install vim") + destructive = CommandPolicy.classify("rm -rf .") + + assert readonly.risk_level == "low" + assert readonly.action_type == "package_check" + assert readonly.requires_approval is False + assert sudo_update.risk_level == "high" + assert sudo_update.action_type == "package_cache_update" + assert sudo_update.requires_password is True + assert install.risk_level == "high" + assert install.action_type == "package_install" + assert install.requires_password is True + assert destructive.risk_level == "critical" + assert destructive.blocked is True