Add command policy audit trail

This commit is contained in:
mirivlad 2026-05-20 23:03:59 +08:00
parent 45b70d2800
commit 61cdd8bbaf
7 changed files with 364 additions and 12 deletions

View File

@ -344,6 +344,12 @@ def create_app() -> FastAPI:
"reasoning_content": reasoning_content, "reasoning_content": reasoning_content,
}, },
) )
except asyncio.CancelledError:
await task_store.cancel_task(task.task_id)
await event_store.append(
task.task_id, "task_cancelled", {"reason": "client_disconnected"}
)
raise
except Exception as exc: except Exception as exc:
await task_store.fail_task(task.task_id, str(exc)) await task_store.fail_task(task.task_id, str(exc))
await event_store.append(task.task_id, "task_failed", {"error": str(exc)}) await event_store.append(task.task_id, "task_failed", {"error": str(exc)})
@ -379,6 +385,14 @@ def create_app() -> FastAPI:
async def get_events(task_id: str) -> list[dict[str, Any]]: async def get_events(task_id: str) -> list[dict[str, Any]]:
return [event.model_dump() for event in await event_store.list_events(task_id)] return [event.model_dump() for event in await event_store.list_events(task_id)]
@app.get("/v1/audit/commands")
async def command_audit(limit: int = 100) -> list[dict[str, Any]]:
    """Return stored ``command_audit`` events serialised as plain dicts.

    Args:
        limit: Requested number of events. Values below 1 are raised to 1
            and values above 500 are lowered to 500 before querying.

    Returns:
        One ``model_dump()`` dict per event, in the order the event store
        returns them.
    """
    # Clamp the caller-supplied limit so a single request stays bounded.
    if limit < 1:
        page_size = 1
    elif limit > 500:
        page_size = 500
    else:
        page_size = limit

    audit_events = await event_store.list_by_type("command_audit", page_size)
    serialized: list[dict[str, Any]] = []
    for audit_event in audit_events:
        serialized.append(audit_event.model_dump())
    return serialized
@app.get("/v1/tasks/{task_id}/stream") @app.get("/v1/tasks/{task_id}/stream")
async def stream_events(task_id: str) -> StreamingResponse: async def stream_events(task_id: str) -> StreamingResponse:
async def generator(): async def generator():
@ -553,6 +567,12 @@ def create_app() -> FastAPI:
"reasoning_content": reasoning_content, "reasoning_content": reasoning_content,
}, },
) )
except asyncio.CancelledError:
await task_store.cancel_task(task_id)
await event_store.append(
task_id, "task_cancelled", {"reason": "client_disconnected"}
)
raise
except Exception as exc: except Exception as exc:
await task_store.fail_task(task_id, str(exc)) await task_store.fail_task(task_id, str(exc))
await event_store.append(task_id, "task_failed", {"error": str(exc)}) await event_store.append(task_id, "task_failed", {"error": str(exc)})
@ -709,6 +729,12 @@ def create_app() -> FastAPI:
"reasoning_content": reasoning_content, "reasoning_content": reasoning_content,
}, },
) )
except asyncio.CancelledError:
await task_store.cancel_task(task_id)
await event_store.append(
task_id, "task_cancelled", {"reason": "client_disconnected"}
)
raise
except Exception as exc: except Exception as exc:
await task_store.fail_task(task_id, str(exc)) await task_store.fail_task(task_id, str(exc))
await event_store.append(task_id, "task_failed", {"error": str(exc)}) await event_store.append(task_id, "task_failed", {"error": str(exc)})

View File

@ -90,3 +90,29 @@ class EventStore:
) )
for row in rows for row in rows
] ]
async def list_by_type(self, event_type: str, limit: int = 100) -> list[Event]:
    """Load events of a single type, highest ``id`` first.

    Args:
        event_type: Value matched against the ``event_type`` column.
        limit: Maximum number of rows to fetch (passed to SQL ``limit``).

    Returns:
        ``Event`` objects built from the matching rows, with
        ``payload_json`` decoded into ``payload``.
    """
    await self.init()
    query = """
        select * from events
        where event_type = ?
        order by id desc
        limit ?
        """
    async with aiosqlite.connect(self.db_path) as db:
        # Row factory gives name-based access to the columns below.
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(query, (event_type, limit))
        records = await cursor.fetchall()
        events: list[Event] = []
        for record in records:
            events.append(
                Event(
                    id=record["id"],
                    task_id=record["task_id"],
                    sequence=record["sequence"],
                    event_type=record["event_type"],
                    payload=json.loads(record["payload_json"]),
                    created_at=record["created_at"],
                )
            )
        return events

View File

@ -267,6 +267,9 @@ class RuntimeLoop:
result = await gateway.run_action(action, approved=True, password=password) result = await gateway.run_action(action, approved=True, password=password)
result_payload = result.model_dump() result_payload = result.model_dump()
await self._append_command_audit(
task_id, index, tool_name, action, result_payload, approved=decision != "deny"
)
await self.event_store.append( await self.event_store.append(
task_id, task_id,
"tool_call_finished", "tool_call_finished",
@ -334,6 +337,9 @@ class RuntimeLoop:
) )
result = await gateway.run_action(action, approved=approved_forever) result = await gateway.run_action(action, approved=approved_forever)
result_payload = result.model_dump() result_payload = result.model_dump()
await self._append_command_audit(
task_id, index, tool_name, action, result_payload, approved=approved_forever
)
if result.metadata.get("requires_approval"): if result.metadata.get("requires_approval"):
approval = None approval = None
if self.approval_service is not None: if self.approval_service is not None:
@ -375,6 +381,37 @@ class RuntimeLoop:
) )
return observations return observations
async def _append_command_audit(
    self,
    task_id: str,
    index: int,
    tool_name: str,
    action: dict[str, Any],
    result_payload: dict[str, Any],
    approved: bool,
) -> None:
    """Record a ``command_audit`` event for a ``shell_exec_safe`` call.

    Calls for any other tool are ignored. The payload is assembled from
    the tool result's ``metadata`` dict, falling back to the action's
    ``args`` for the command text and to the result's ``error`` for the
    reason.

    Args:
        task_id: Task the event is appended to.
        index: Position of the action, stored verbatim in the payload.
        tool_name: Name of the tool that was invoked.
        action: Action dict; only ``args.command`` is read.
        result_payload: Dumped tool result (``ok``, ``error``, ``metadata``).
        approved: Whether the call ran with approval; stored verbatim.
    """
    if tool_name != "shell_exec_safe":
        return

    details = result_payload.get("metadata") or {}

    # Prefer the command reported by the tool; otherwise use what was requested.
    command = details.get("command")
    if not command:
        requested_args = action.get("args") or {}
        command = requested_args.get("command")

    reason = details.get("reason")
    if not reason:
        reason = result_payload.get("error")

    audit_payload = {
        "index": index,
        "tool": tool_name,
        "command": command,
        "action_type": details.get("action_type"),
        "risk_level": details.get("risk_level"),
        "requires_approval": bool(details.get("requires_approval")),
        "blocked": bool(details.get("blocked")),
        "approved": approved,
        "ok": bool(result_payload.get("ok")),
        "returncode": details.get("returncode"),
        "reason": reason,
    }
    await self.event_store.append(task_id, "command_audit", audit_payload)
async def _run_action_loop( async def _run_action_loop(
self, self,
task_id: str, task_id: str,

View File

@ -0,0 +1,180 @@
import shlex
from dataclasses import dataclass
from pydantic import BaseModel
@dataclass(frozen=True)
class CommandPattern:
    """Immutable rule tying a leading word sequence to a classification."""

    # Leading tokens a command must start with for this rule to apply.
    words: tuple[str, ...]
    # Label reported for commands matching this rule (e.g. "file_read").
    action_type: str
    # Risk label reported for commands matching this rule (e.g. "low").
    risk_level: str
class CommandPolicyResult(BaseModel):
    """Outcome of classifying one shell command."""

    # The command text that was classified.
    command: str
    # The command in normalised form.
    normalized_command: str
    # Classification label for the command.
    action_type: str
    # Risk label for the command.
    risk_level: str
    # Whether the command needs approval before it may run.
    requires_approval: bool
    # Whether the command needs a password to run.
    requires_password: bool
    # Whether the command is refused; defaults to not blocked.
    blocked: bool = False
    # Optional human-readable explanation of the verdict.
    reason: str | None = None
class CommandPolicy:
    """Rule-based classifier for shell commands.

    ``classify`` tokenises the command with ``shlex``, drops a leading
    ``sudo`` together with its options, and then applies these checks in
    order: destructive blocklist, read-only allowlist, system-command list,
    and finally a catch-all that requires approval.

    NOTE(review): only the first command of a compound line is inspected,
    so ``ls && <anything>`` matches the read-only ``ls`` rule. Confirm that
    the executor does not pass commands through a shell, or tighten this.
    """

    # Commands starting with these words run without approval (unless sudo).
    READONLY_PATTERNS = (
        CommandPattern(("pwd",), "working_directory", "low"),
        CommandPattern(("ls",), "directory_list", "low"),
        CommandPattern(("cat",), "file_read", "low"),
        CommandPattern(("head",), "file_read", "low"),
        CommandPattern(("tail",), "file_read", "low"),
        CommandPattern(("grep",), "text_search", "low"),
        CommandPattern(("find",), "file_search", "low"),
        CommandPattern(("apt", "list"), "package_check", "low"),
        CommandPattern(("apt-cache", "policy"), "package_check", "low"),
        CommandPattern(("git", "status"), "vcs_status", "low"),
        CommandPattern(("git", "diff"), "vcs_inspect", "low"),
        CommandPattern(("git", "log"), "vcs_inspect", "low"),
        CommandPattern(("pytest",), "test_run", "low"),
        CommandPattern(("python", "-m", "pytest"), "test_run", "low"),
        CommandPattern(("python3", "-m", "pytest"), "test_run", "low"),
    )

    # Commands starting with these words always require approval.
    SYSTEM_PATTERNS = (
        CommandPattern(("apt", "update"), "package_cache_update", "high"),
        CommandPattern(("apt", "install"), "package_install", "high"),
        CommandPattern(("apt", "remove"), "package_remove", "high"),
        CommandPattern(("apt", "upgrade"), "package_upgrade", "high"),
        CommandPattern(("systemctl",), "service_control", "high"),
        CommandPattern(("service",), "service_control", "high"),
    )

    # Executables that are always blocked.
    DESTRUCTIVE_COMMANDS = {
        "rm",
        "dd",
        "mkfs",
        "mount",
        "umount",
        "shutdown",
        "reboot",
        "poweroff",
        "su",
    }

    # Substrings (matched against the lower-cased command) that block a command.
    DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r")

    # Executable-name prefixes that are blocked like DESTRUCTIVE_COMMANDS,
    # covering the filesystem-specific mkfs front ends (mkfs.ext4, ...).
    _DESTRUCTIVE_PREFIXES = ("mkfs.",)

    # sudo options whose value is a separate token; that value must be
    # skipped too or it would be mistaken for the command being run.
    _SUDO_OPTIONS_WITH_VALUE = frozenset(
        {
            "-C", "--close-from",
            "-D", "--chdir",
            "-g", "--group",
            "-p", "--prompt",
            "-R", "--chroot",
            "-r", "--role",
            "-T", "--command-timeout",
            "-t", "--type",
            "-U", "--other-user",
            "-u", "--user",
        }
    )

    # Used to detect "download piped straight into a shell".
    _DOWNLOADERS = frozenset({"curl", "wget"})
    _SHELLS = frozenset({"sh", "bash", "dash", "zsh", "ksh"})

    @classmethod
    def classify(cls, command: str) -> CommandPolicyResult:
        """Classify ``command`` and say whether it may run.

        Args:
            command: Raw command line; surrounding whitespace is ignored.

        Returns:
            A ``CommandPolicyResult``. Empty and destructive commands are
            blocked; unparseable, system and unknown commands require
            approval; read-only commands require approval only under sudo.
            ``requires_password`` is set whenever the command starts with
            ``sudo``.
        """
        command = command.strip()
        parts = cls._split(command)
        normalized_parts = cls._strip_sudo(parts)
        normalized_command = shlex.join(normalized_parts) if normalized_parts else command
        uses_sudo = bool(parts) and parts[0] == "sudo"
        lowered = command.lower()

        if not command:
            return cls._result(
                command, normalized_command, "empty", "low", False, False, True, "Empty command"
            )
        if not parts:
            # shlex could not tokenise the text (e.g. unbalanced quotes).
            return cls._result(
                command,
                normalized_command,
                "invalid",
                "medium",
                True,
                False,
                False,
                "Command could not be parsed",
            )
        if cls._is_destructive(normalized_parts, lowered):
            return cls._result(
                command,
                normalized_command,
                "destructive",
                "critical",
                False,
                uses_sudo,
                True,
                "Command is blocked by safety policy",
            )

        readonly = cls._match(normalized_parts, cls.READONLY_PATTERNS)
        if readonly:
            # Read-only commands are free to run unless elevated with sudo.
            return cls._result(
                command,
                normalized_command,
                readonly.action_type,
                readonly.risk_level if not uses_sudo else "medium",
                uses_sudo,
                uses_sudo,
                False,
                "Sudo command requires approval" if uses_sudo else None,
            )

        system = cls._match(normalized_parts, cls.SYSTEM_PATTERNS)
        if system:
            return cls._result(
                command,
                normalized_command,
                system.action_type,
                system.risk_level,
                True,
                uses_sudo,
                False,
                "System command requires approval",
            )

        return cls._result(
            command,
            normalized_command,
            "shell_command",
            "medium",
            True,
            uses_sudo,
            False,
            "Command is outside allowlist and requires approval",
        )

    @classmethod
    def _split(cls, command: str) -> list[str]:
        """Tokenise ``command`` shell-style; return ``[]`` if it cannot be parsed."""
        try:
            return shlex.split(command)
        except ValueError:
            return []

    @classmethod
    def _strip_sudo(cls, parts: list[str]) -> list[str]:
        """Return ``parts`` without a leading ``sudo`` and its options.

        Options listed in ``_SUDO_OPTIONS_WITH_VALUE`` also consume the
        following token, so ``sudo -u root rm x`` yields ``["rm", "x"]``
        rather than ``["root", "rm", "x"]``. A bare ``--`` ends option
        parsing. Input not starting with ``sudo`` is returned unchanged.
        """
        if not parts or parts[0] != "sudo":
            return parts
        position = 1
        while position < len(parts):
            token = parts[position]
            if token == "--":
                position += 1
                break
            if not token.startswith("-"):
                break
            position += 1
            if token in cls._SUDO_OPTIONS_WITH_VALUE:
                # Skip the option's value as well.
                position += 1
        return parts[position:]

    @classmethod
    def _match(
        cls, parts: list[str], patterns: tuple[CommandPattern, ...]
    ) -> CommandPattern | None:
        """Return the first pattern whose words prefix ``parts`` (case-insensitive)."""
        lowered = tuple(part.lower() for part in parts)
        for pattern in patterns:
            if lowered[: len(pattern.words)] == pattern.words:
                return pattern
        return None

    @classmethod
    def _executable_name(cls, token: str) -> str:
        """Lower-case ``token`` and drop any directory part (``/bin/rm`` -> ``rm``)."""
        return token.lower().rsplit("/", 1)[-1]

    @classmethod
    def _stage_executable(cls, words: list[str]) -> str:
        """Return the executable name of one pipeline stage, ignoring ``sudo``."""
        words = cls._strip_sudo(words)
        return cls._executable_name(words[0]) if words else ""

    @classmethod
    def _pipes_download_into_shell(cls, lowered_command: str) -> bool:
        """Detect ``curl``/``wget`` output piped directly into a shell.

        The text is split on ``|`` and adjacent stages are compared, so
        ``curl https://x | sh`` is caught. ``a || b`` produces an empty
        middle stage and therefore does not count as a pipe.
        """
        stages = [stage.split() for stage in lowered_command.split("|")]
        return any(
            cls._stage_executable(producer) in cls._DOWNLOADERS
            and cls._stage_executable(consumer) in cls._SHELLS
            for producer, consumer in zip(stages, stages[1:])
        )

    @classmethod
    def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool:
        """Return True when the command must be blocked outright.

        Args:
            parts: Command tokens with any ``sudo`` prefix already removed.
            lowered_command: The full command text, lower-cased.
        """
        if any(fragment in lowered_command for fragment in cls.DANGEROUS_FRAGMENTS):
            return True
        if cls._pipes_download_into_shell(lowered_command):
            return True
        if not parts:
            return False
        # Compare the bare executable name so /bin/rm is treated like rm.
        executable = cls._executable_name(parts[0])
        return executable in cls.DESTRUCTIVE_COMMANDS or executable.startswith(
            cls._DESTRUCTIVE_PREFIXES
        )

    @classmethod
    def _result(
        cls,
        command: str,
        normalized_command: str,
        action_type: str,
        risk_level: str,
        requires_approval: bool,
        requires_password: bool,
        blocked: bool,
        reason: str | None,
    ) -> CommandPolicyResult:
        """Build a ``CommandPolicyResult`` from positional verdict fields."""
        return CommandPolicyResult(
            command=command,
            normalized_command=normalized_command,
            action_type=action_type,
            risk_level=risk_level,
            requires_approval=requires_approval,
            requires_password=requires_password,
            blocked=blocked,
            reason=reason,
        )

View File

@ -3,6 +3,7 @@ import subprocess
from typing import Any from typing import Any
from duck_core.tools.base import ToolResult from duck_core.tools.base import ToolResult
from duck_core.tools.command_policy import CommandPolicy
ALLOWLIST = { ALLOWLIST = {
@ -60,15 +61,19 @@ class ShellExecSafeTool:
command = str(args.get("command", "")).strip() command = str(args.get("command", "")).strip()
approved = bool(args.get("_approved")) approved = bool(args.get("_approved"))
password = args.get("_password") password = args.get("_password")
policy = CommandPolicy.classify(command)
allowed, reason, blocked = self._is_allowed(command, approved=approved) allowed, reason, blocked = self._is_allowed(command, approved=approved)
if not allowed: if not allowed:
metadata = {"blocked": True} if blocked else {"requires_approval": True} metadata = {
**policy.model_dump(),
**({"blocked": True} if blocked else {"requires_approval": True}),
}
return ToolResult(ok=False, error=reason, metadata=metadata) return ToolResult(ok=False, error=reason, metadata=metadata)
if self._is_sudo_command(command) and not password: if self._is_sudo_command(command) and not password:
return ToolResult( return ToolResult(
ok=False, ok=False,
error="Sudo password is required to run this command.", error="Sudo password is required to run this command.",
metadata={"requires_password": True}, metadata={**policy.model_dump(), "requires_password": True},
) )
run_command = self._sudo_stdin_command(command) if self._is_sudo_command(command) else command run_command = self._sudo_stdin_command(command) if self._is_sudo_command(command) else command
input_text = f"{password}\n" if self._is_sudo_command(command) else None input_text = f"{password}\n" if self._is_sudo_command(command) else None
@ -89,23 +94,27 @@ class ShellExecSafeTool:
ok=completed.returncode == 0, ok=completed.returncode == 0,
output=completed.stdout, output=completed.stdout,
error=completed.stderr if completed.returncode else None, error=completed.stderr if completed.returncode else None,
metadata={"returncode": completed.returncode, "command": command}, metadata={
**{
key: value
for key, value in policy.model_dump().items()
if key not in {"requires_approval", "requires_password"}
},
"returncode": completed.returncode,
},
) )
def _is_allowed( def _is_allowed(
self, command: str, approved: bool = False self, command: str, approved: bool = False
) -> tuple[bool, str | None, bool]: ) -> tuple[bool, str | None, bool]:
if not command: policy = CommandPolicy.classify(command)
return False, "Empty command", False if policy.blocked:
lowered = command.lower() return False, policy.reason, True
parts = shlex.split(command)
for blocked in BLOCKLIST:
if self._matches_blocked_command(lowered, parts, blocked):
return False, f"Command is blocked: {blocked}", True
if approved: if approved:
return True, None, False return True, None, False
if self._is_sudo_command(command): if policy.requires_approval:
return False, "Sudo command requires approval.", False return False, policy.reason, False
parts = shlex.split(command)
prefix1 = parts[0] if parts else "" prefix1 = parts[0] if parts else ""
prefix2 = " ".join(parts[:2]) prefix2 = " ".join(parts[:2])
prefix3 = " ".join(parts[:3]) prefix3 = " ".join(parts[:3])

View File

@ -323,3 +323,57 @@ def test_password_stream_runs_sudo_with_password_and_streams_answer(tmp_path, mo
assert "event: content_delta" in body assert "event: content_delta" in body
assert "sudo command completed" in body assert "sudo command completed" in body
assert "secret" not in body assert "secret" not in body
def test_command_audit_endpoint_exposes_redacted_shell_events(tmp_path, monkeypatch):
    """A streamed chat that runs a shell command shows up in the audit endpoint.

    The model client is replaced so the first action-role call requests
    ``apt list --upgradable`` and later calls (which carry tool
    observations) request nothing further.
    """
    monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3"))

    async def fake_chat(
        self, role, messages, temperature=None, max_output_tokens=None, response_format=None
    ):
        assert role == "action"
        already_observed = any(
            "tool_observations" in message["content"] for message in messages
        )
        if already_observed:
            actions = []
        else:
            actions = [
                {
                    "tool": "shell_exec_safe",
                    "args": {"command": "apt list --upgradable"},
                    "reason": "Check available updates",
                }
            ]
        directive = {
            "kind": "action_directive",
            "intent": "check updates",
            "risk_level": "low",
            "actions": actions,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(directive),
            reasoning_content=None,
            raw={},
            latency_ms=1.0,
        )

    async def fake_stream_chat(self, role, messages):
        yield {"type": "content_delta", "delta": "updates checked"}

    monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat)
    monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat)

    client = TestClient(create_app())
    request_body = {"message": "check updates", "workspace": str(tmp_path), "debug": True}
    with client.stream("POST", "/v1/chat/stream", json=request_body) as response:
        # Drain the stream so the whole task runs before the audit is read.
        for _chunk in response.iter_text():
            pass

    audit = client.get("/v1/audit/commands").json()
    first_entry = audit[0]
    payload = first_entry["payload"]

    assert response.status_code == 200
    assert first_entry["event_type"] == "command_audit"
    assert payload["command"] == "apt list --upgradable"
    assert payload["risk_level"] == "low"
    assert payload["action_type"] == "package_check"
    assert payload["approved"] is False
    assert "password" not in json.dumps(audit).lower()

View File

@ -3,6 +3,7 @@ import pytest
from duck_core.tools.file_read import FileReadTool from duck_core.tools.file_read import FileReadTool
from duck_core.tools.file_write import FileWriteTool from duck_core.tools.file_write import FileWriteTool
from duck_core.tools.gateway import ToolGateway from duck_core.tools.gateway import ToolGateway
from duck_core.tools.command_policy import CommandPolicy
from duck_core.tools.shell_exec_safe import ShellExecSafeTool from duck_core.tools.shell_exec_safe import ShellExecSafeTool
@ -139,3 +140,22 @@ async def test_shell_tool_allows_read_only_apt_update_check(monkeypatch, tmp_pat
assert result.ok is True assert result.ok is True
assert "bootlogd" in result.output assert "bootlogd" in result.output
assert result.metadata["command"] == "apt list --upgradable" assert result.metadata["command"] == "apt list --upgradable"
def test_command_policy_classifies_common_system_commands():
    """Spot-check the policy verdicts for read-only, sudo and destructive commands."""
    readonly = CommandPolicy.classify("apt list --upgradable")
    assert readonly.risk_level == "low"
    assert readonly.action_type == "package_check"
    assert readonly.requires_approval is False

    # Both sudo package commands are high risk and need a password.
    sudo_expectations = (
        ("sudo apt update", "package_cache_update"),
        ("sudo apt install vim", "package_install"),
    )
    for command_line, expected_action in sudo_expectations:
        verdict = CommandPolicy.classify(command_line)
        assert verdict.risk_level == "high"
        assert verdict.action_type == expected_action
        assert verdict.requires_password is True

    destructive = CommandPolicy.classify("rm -rf .")
    assert destructive.risk_level == "critical"
    assert destructive.blocked is True