Add command policy audit trail
This commit is contained in:
parent
45b70d2800
commit
61cdd8bbaf
|
|
@ -344,6 +344,12 @@ def create_app() -> FastAPI:
|
|||
"reasoning_content": reasoning_content,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await task_store.cancel_task(task.task_id)
|
||||
await event_store.append(
|
||||
task.task_id, "task_cancelled", {"reason": "client_disconnected"}
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
await task_store.fail_task(task.task_id, str(exc))
|
||||
await event_store.append(task.task_id, "task_failed", {"error": str(exc)})
|
||||
|
|
@ -379,6 +385,14 @@ def create_app() -> FastAPI:
|
|||
async def get_events(task_id: str) -> list[dict[str, Any]]:
|
||||
return [event.model_dump() for event in await event_store.list_events(task_id)]
|
||||
|
||||
@app.get("/v1/audit/commands")
async def command_audit(limit: int = 100) -> list[dict[str, Any]]:
    """Return stored ``command_audit`` events as plain dictionaries.

    The requested ``limit`` is clamped into the range 1..500 before it is
    handed to the event store, so callers can neither request zero rows nor
    more than 500 rows per call.
    """
    bounded_limit = max(1, min(limit, 500))
    audit_events = await event_store.list_by_type("command_audit", bounded_limit)
    return [audit_event.model_dump() for audit_event in audit_events]
|
||||
|
||||
@app.get("/v1/tasks/{task_id}/stream")
|
||||
async def stream_events(task_id: str) -> StreamingResponse:
|
||||
async def generator():
|
||||
|
|
@ -553,6 +567,12 @@ def create_app() -> FastAPI:
|
|||
"reasoning_content": reasoning_content,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await task_store.cancel_task(task_id)
|
||||
await event_store.append(
|
||||
task_id, "task_cancelled", {"reason": "client_disconnected"}
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
await task_store.fail_task(task_id, str(exc))
|
||||
await event_store.append(task_id, "task_failed", {"error": str(exc)})
|
||||
|
|
@ -709,6 +729,12 @@ def create_app() -> FastAPI:
|
|||
"reasoning_content": reasoning_content,
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await task_store.cancel_task(task_id)
|
||||
await event_store.append(
|
||||
task_id, "task_cancelled", {"reason": "client_disconnected"}
|
||||
)
|
||||
raise
|
||||
except Exception as exc:
|
||||
await task_store.fail_task(task_id, str(exc))
|
||||
await event_store.append(task_id, "task_failed", {"error": str(exc)})
|
||||
|
|
|
|||
|
|
@ -90,3 +90,29 @@ class EventStore:
|
|||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def list_by_type(self, event_type: str, limit: int = 100) -> list[Event]:
    """Return up to ``limit`` events whose ``event_type`` matches.

    Rows are ordered by descending ``id`` (most recently inserted first,
    assuming ids grow with insertion order -- confirm against the schema).

    Args:
        event_type: Exact value of the ``event_type`` column to select.
        limit: Maximum number of events to return. Zero or a negative value
            yields an empty list.

    Returns:
        A list of ``Event`` objects built from the matching rows, with
        ``payload`` decoded from the stored ``payload_json`` text.
    """
    # SQLite interprets a negative LIMIT as "no upper bound", which would turn
    # a bad argument into a full-table dump. Treat non-positive limits as
    # "nothing requested" instead.
    if limit <= 0:
        return []
    await self.init()
    async with aiosqlite.connect(self.db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(
            """
            select * from events
            where event_type = ?
            order by id desc
            limit ?
            """,
            (event_type, limit),
        )
        rows = await cursor.fetchall()
    return [
        Event(
            id=row["id"],
            task_id=row["task_id"],
            sequence=row["sequence"],
            event_type=row["event_type"],
            payload=json.loads(row["payload_json"]),
            created_at=row["created_at"],
        )
        for row in rows
    ]
|
||||
|
|
|
|||
|
|
@ -267,6 +267,9 @@ class RuntimeLoop:
|
|||
result = await gateway.run_action(action, approved=True, password=password)
|
||||
|
||||
result_payload = result.model_dump()
|
||||
await self._append_command_audit(
|
||||
task_id, index, tool_name, action, result_payload, approved=decision != "deny"
|
||||
)
|
||||
await self.event_store.append(
|
||||
task_id,
|
||||
"tool_call_finished",
|
||||
|
|
@ -334,6 +337,9 @@ class RuntimeLoop:
|
|||
)
|
||||
result = await gateway.run_action(action, approved=approved_forever)
|
||||
result_payload = result.model_dump()
|
||||
await self._append_command_audit(
|
||||
task_id, index, tool_name, action, result_payload, approved=approved_forever
|
||||
)
|
||||
if result.metadata.get("requires_approval"):
|
||||
approval = None
|
||||
if self.approval_service is not None:
|
||||
|
|
@ -375,6 +381,37 @@ class RuntimeLoop:
|
|||
)
|
||||
return observations
|
||||
|
||||
async def _append_command_audit(
|
||||
self,
|
||||
task_id: str,
|
||||
index: int,
|
||||
tool_name: str,
|
||||
action: dict[str, Any],
|
||||
result_payload: dict[str, Any],
|
||||
approved: bool,
|
||||
) -> None:
|
||||
if tool_name != "shell_exec_safe":
|
||||
return
|
||||
metadata = result_payload.get("metadata") or {}
|
||||
command = metadata.get("command") or (action.get("args") or {}).get("command")
|
||||
await self.event_store.append(
|
||||
task_id,
|
||||
"command_audit",
|
||||
{
|
||||
"index": index,
|
||||
"tool": tool_name,
|
||||
"command": command,
|
||||
"action_type": metadata.get("action_type"),
|
||||
"risk_level": metadata.get("risk_level"),
|
||||
"requires_approval": bool(metadata.get("requires_approval")),
|
||||
"blocked": bool(metadata.get("blocked")),
|
||||
"approved": approved,
|
||||
"ok": bool(result_payload.get("ok")),
|
||||
"returncode": metadata.get("returncode"),
|
||||
"reason": metadata.get("reason") or result_payload.get("error"),
|
||||
},
|
||||
)
|
||||
|
||||
async def _run_action_loop(
|
||||
self,
|
||||
task_id: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,180 @@
|
|||
import shlex
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CommandPattern:
    """Immutable rule pairing a command word prefix with its classification."""

    # Leading command words the rule is compared against, e.g. ("git", "status").
    words: tuple[str, ...]
    # Label reported as the action type when the rule matches.
    action_type: str
    # Risk label reported when the rule matches (values used here: "low", "high").
    risk_level: str
|
||||
|
||||
|
||||
class CommandPolicyResult(BaseModel):
    """Outcome of classifying a single shell command."""

    # The command text that was classified.
    command: str
    # The command with any leading sudo invocation removed.
    normalized_command: str
    # Label describing what the command does (e.g. "package_check").
    action_type: str
    # Risk label assigned to the command (e.g. "low", "medium", "high", "critical").
    risk_level: str
    # True when the command must be approved before it runs.
    requires_approval: bool
    # True when the command needs a password (set for sudo commands).
    requires_password: bool
    # True when the policy refuses the command outright.
    blocked: bool = False
    # Optional explanation for the approval requirement or the block.
    reason: str | None = None
|
||||
|
||||
|
||||
class CommandPolicy:
    """Classify shell commands by action type, risk level and approval needs.

    ``classify`` is the entry point. It checks, in order: empty input,
    unparseable input, destructive commands (blocked), read-only allowlist
    patterns, system-administration patterns, and finally falls back to a
    generic "needs approval" result.
    """

    # Commands considered read-only; matched as a prefix of the command words.
    READONLY_PATTERNS = (
        CommandPattern(("pwd",), "working_directory", "low"),
        CommandPattern(("ls",), "directory_list", "low"),
        CommandPattern(("cat",), "file_read", "low"),
        CommandPattern(("head",), "file_read", "low"),
        CommandPattern(("tail",), "file_read", "low"),
        CommandPattern(("grep",), "text_search", "low"),
        CommandPattern(("find",), "file_search", "low"),
        CommandPattern(("apt", "list"), "package_check", "low"),
        CommandPattern(("apt-cache", "policy"), "package_check", "low"),
        CommandPattern(("git", "status"), "vcs_status", "low"),
        CommandPattern(("git", "diff"), "vcs_inspect", "low"),
        CommandPattern(("git", "log"), "vcs_inspect", "low"),
        CommandPattern(("pytest",), "test_run", "low"),
        CommandPattern(("python", "-m", "pytest"), "test_run", "low"),
        CommandPattern(("python3", "-m", "pytest"), "test_run", "low"),
    )
    # Commands that change system state; always require approval.
    SYSTEM_PATTERNS = (
        CommandPattern(("apt", "update"), "package_cache_update", "high"),
        CommandPattern(("apt", "install"), "package_install", "high"),
        CommandPattern(("apt", "remove"), "package_remove", "high"),
        CommandPattern(("apt", "upgrade"), "package_upgrade", "high"),
        CommandPattern(("systemctl",), "service_control", "high"),
        CommandPattern(("service",), "service_control", "high"),
    )
    # Executables that are always blocked.
    DESTRUCTIVE_COMMANDS = {
        "rm",
        "dd",
        "mkfs",
        "mount",
        "umount",
        "shutdown",
        "reboot",
        "poweroff",
        "su",
    }
    # Literal substrings (compared against the lower-cased command) that block.
    DANGEROUS_FRAGMENTS = ("curl | sh", "wget | sh", "chmod -r", "chown -r")
    # sudo short flags whose value is the rest of the token or the next token.
    SUDO_SHORT_VALUE_FLAGS = frozenset("CDghpRrTtUu")
    # sudo long options that take their value from the next token
    # (the "--opt=value" spelling is a single token and needs no special case).
    SUDO_LONG_VALUE_OPTIONS = frozenset(
        {
            "--close-from",
            "--chdir",
            "--group",
            "--host",
            "--prompt",
            "--chroot",
            "--role",
            "--command-timeout",
            "--type",
            "--other-user",
            "--user",
        }
    )
    # Programs that fetch remote content, and interpreters that would execute it.
    DOWNLOADERS = frozenset({"curl", "wget"})
    SHELL_INTERPRETERS = frozenset({"sh", "bash", "dash", "zsh", "ksh"})

    @classmethod
    def classify(cls, command: str) -> CommandPolicyResult:
        """Classify ``command`` and report how it must be handled.

        Args:
            command: Raw command line; surrounding whitespace is ignored.

        Returns:
            A ``CommandPolicyResult`` describing the action type, risk level,
            and whether the command is blocked, needs approval or a password.
        """
        command = command.strip()
        parts = cls._split(command)
        normalized_parts = cls._strip_sudo(parts)
        normalized_command = shlex.join(normalized_parts) if normalized_parts else command
        uses_sudo = bool(parts) and parts[0] == "sudo"
        lowered = command.lower()

        if not command:
            return cls._result(
                command, normalized_command, "empty", "low", False, False, True, "Empty command"
            )
        if not parts:
            # shlex could not tokenise the text (e.g. an unbalanced quote).
            return cls._result(
                command,
                normalized_command,
                "invalid",
                "medium",
                True,
                False,
                False,
                "Command could not be parsed",
            )
        if cls._is_destructive(normalized_parts, lowered):
            return cls._result(
                command,
                normalized_command,
                "destructive",
                "critical",
                False,
                uses_sudo,
                True,
                "Command is blocked by safety policy",
            )

        # NOTE(review): a read-only prefix followed by shell control operators
        # (e.g. "ls && <anything>") still classifies as read-only here; confirm
        # whether the executor runs commands through a shell.
        readonly = cls._match(normalized_parts, cls.READONLY_PATTERNS)
        if readonly:
            # Read-only commands run freely unless they are elevated with sudo.
            return cls._result(
                command,
                normalized_command,
                readonly.action_type,
                readonly.risk_level if not uses_sudo else "medium",
                uses_sudo,
                uses_sudo,
                False,
                "Sudo command requires approval" if uses_sudo else None,
            )

        system = cls._match(normalized_parts, cls.SYSTEM_PATTERNS)
        if system:
            return cls._result(
                command,
                normalized_command,
                system.action_type,
                system.risk_level,
                True,
                uses_sudo,
                False,
                "System command requires approval",
            )

        return cls._result(
            command,
            normalized_command,
            "shell_command",
            "medium",
            True,
            uses_sudo,
            False,
            "Command is outside allowlist and requires approval",
        )

    @classmethod
    def _split(cls, command: str) -> list[str]:
        """Tokenise ``command`` with shlex; return ``[]`` when it cannot be parsed."""
        try:
            return shlex.split(command)
        except ValueError:
            return []

    @classmethod
    def _strip_sudo(cls, parts: list[str]) -> list[str]:
        """Return ``parts`` without a leading ``sudo`` and its options.

        Options that take a value (``-u root``, ``--user root``, ``-nu root``)
        consume that value too, so the first remaining token is the command
        that sudo would actually run. A bare ``--`` ends option parsing.
        """
        if not parts or parts[0] != "sudo":
            return parts
        index = 1
        while index < len(parts):
            token = parts[index]
            if not token.startswith("-"):
                break
            index += 1
            if token == "--":
                break
            if cls._sudo_option_consumes_next(token):
                index += 1
        return parts[index:]

    @classmethod
    def _sudo_option_consumes_next(cls, token: str) -> bool:
        """Return True when sudo option ``token`` takes the following token as its value."""
        if token.startswith("--"):
            return token in cls.SUDO_LONG_VALUE_OPTIONS
        for position, flag in enumerate(token[1:], start=1):
            if flag in cls.SUDO_SHORT_VALUE_FLAGS:
                # The value is attached ("-uroot") unless the flag ends the token.
                return position == len(token) - 1
        return False

    @classmethod
    def _match(
        cls, parts: list[str], patterns: tuple[CommandPattern, ...]
    ) -> CommandPattern | None:
        """Return the first pattern whose words are a case-insensitive prefix of ``parts``."""
        lowered = tuple(part.lower() for part in parts)
        for pattern in patterns:
            if lowered[: len(pattern.words)] == pattern.words:
                return pattern
        return None

    @classmethod
    def _is_destructive(cls, parts: list[str], lowered_command: str) -> bool:
        """Return True when the command must be blocked outright.

        Args:
            parts: Command words with any sudo prefix already removed.
            lowered_command: The full command text, lower-cased.
        """
        if any(fragment in lowered_command for fragment in cls.DANGEROUS_FRAGMENTS):
            return True
        if cls._pipes_download_into_shell(lowered_command):
            return True
        if not parts:
            return False
        # Compare the basename so "/bin/rm" is treated like "rm".
        executable = cls._basename(parts[0].lower())
        # mkfs is usually invoked through a per-filesystem front end (mkfs.ext4, ...).
        return executable in cls.DESTRUCTIVE_COMMANDS or executable.startswith("mkfs.")

    @classmethod
    def _pipes_download_into_shell(cls, lowered_command: str) -> bool:
        """Return True when curl/wget output is piped into a shell interpreter.

        Works on raw text rather than shlex tokens so that it also covers
        commands whose URL or flags sit between the downloader and the pipe.
        """
        downloading = False
        # "||" is a logical OR, not a pipe, so neutralise it before splitting.
        for stage in lowered_command.replace("||", ";").split("|"):
            # lstrip("&") tolerates the "|&" pipe spelling.
            words = stage.lstrip("&").split()
            if not words:
                continue
            # After sudo any later word may be the interpreter (its options
            # are not parsed here); otherwise only the stage's first word counts.
            candidates = words[1:] if words[0] == "sudo" else words[:1]
            if downloading and any(
                cls._basename(word) in cls.SHELL_INTERPRETERS for word in candidates
            ):
                return True
            downloading = downloading or any(
                cls._basename(word) in cls.DOWNLOADERS for word in words
            )
        return False

    @staticmethod
    def _basename(word: str) -> str:
        """Return the part of ``word`` after its last ``/``."""
        return word.rsplit("/", 1)[-1]

    @classmethod
    def _result(
        cls,
        command: str,
        normalized_command: str,
        action_type: str,
        risk_level: str,
        requires_approval: bool,
        requires_password: bool,
        blocked: bool,
        reason: str | None,
    ) -> CommandPolicyResult:
        """Build a ``CommandPolicyResult`` from positional classification values."""
        return CommandPolicyResult(
            command=command,
            normalized_command=normalized_command,
            action_type=action_type,
            risk_level=risk_level,
            requires_approval=requires_approval,
            requires_password=requires_password,
            blocked=blocked,
            reason=reason,
        )
|
||||
|
|
@ -3,6 +3,7 @@ import subprocess
|
|||
from typing import Any
|
||||
|
||||
from duck_core.tools.base import ToolResult
|
||||
from duck_core.tools.command_policy import CommandPolicy
|
||||
|
||||
|
||||
ALLOWLIST = {
|
||||
|
|
@ -60,15 +61,19 @@ class ShellExecSafeTool:
|
|||
command = str(args.get("command", "")).strip()
|
||||
approved = bool(args.get("_approved"))
|
||||
password = args.get("_password")
|
||||
policy = CommandPolicy.classify(command)
|
||||
allowed, reason, blocked = self._is_allowed(command, approved=approved)
|
||||
if not allowed:
|
||||
metadata = {"blocked": True} if blocked else {"requires_approval": True}
|
||||
metadata = {
|
||||
**policy.model_dump(),
|
||||
**({"blocked": True} if blocked else {"requires_approval": True}),
|
||||
}
|
||||
return ToolResult(ok=False, error=reason, metadata=metadata)
|
||||
if self._is_sudo_command(command) and not password:
|
||||
return ToolResult(
|
||||
ok=False,
|
||||
error="Sudo password is required to run this command.",
|
||||
metadata={"requires_password": True},
|
||||
metadata={**policy.model_dump(), "requires_password": True},
|
||||
)
|
||||
run_command = self._sudo_stdin_command(command) if self._is_sudo_command(command) else command
|
||||
input_text = f"{password}\n" if self._is_sudo_command(command) else None
|
||||
|
|
@ -89,23 +94,27 @@ class ShellExecSafeTool:
|
|||
ok=completed.returncode == 0,
|
||||
output=completed.stdout,
|
||||
error=completed.stderr if completed.returncode else None,
|
||||
metadata={"returncode": completed.returncode, "command": command},
|
||||
metadata={
|
||||
**{
|
||||
key: value
|
||||
for key, value in policy.model_dump().items()
|
||||
if key not in {"requires_approval", "requires_password"}
|
||||
},
|
||||
"returncode": completed.returncode,
|
||||
},
|
||||
)
|
||||
|
||||
def _is_allowed(
|
||||
self, command: str, approved: bool = False
|
||||
) -> tuple[bool, str | None, bool]:
|
||||
if not command:
|
||||
return False, "Empty command", False
|
||||
lowered = command.lower()
|
||||
parts = shlex.split(command)
|
||||
for blocked in BLOCKLIST:
|
||||
if self._matches_blocked_command(lowered, parts, blocked):
|
||||
return False, f"Command is blocked: {blocked}", True
|
||||
policy = CommandPolicy.classify(command)
|
||||
if policy.blocked:
|
||||
return False, policy.reason, True
|
||||
if approved:
|
||||
return True, None, False
|
||||
if self._is_sudo_command(command):
|
||||
return False, "Sudo command requires approval.", False
|
||||
if policy.requires_approval:
|
||||
return False, policy.reason, False
|
||||
parts = shlex.split(command)
|
||||
prefix1 = parts[0] if parts else ""
|
||||
prefix2 = " ".join(parts[:2])
|
||||
prefix3 = " ".join(parts[:3])
|
||||
|
|
|
|||
|
|
@ -323,3 +323,57 @@ def test_password_stream_runs_sudo_with_password_and_streams_answer(tmp_path, mo
|
|||
assert "event: content_delta" in body
|
||||
assert "sudo command completed" in body
|
||||
assert "secret" not in body
|
||||
|
||||
|
||||
def test_command_audit_endpoint_exposes_redacted_shell_events(tmp_path, monkeypatch):
    """Stream a chat that runs one shell command, then inspect the audit endpoint.

    The fake model proposes ``apt list --upgradable`` once and returns no
    further actions after tool observations appear in the conversation.
    """
    monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3"))

    async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
        assert role == "action"
        already_observed = any("tool_observations" in message["content"] for message in messages)
        if already_observed:
            actions = []
        else:
            actions = [
                {
                    "tool": "shell_exec_safe",
                    "args": {"command": "apt list --upgradable"},
                    "reason": "Check available updates",
                }
            ]
        directive = {
            "kind": "action_directive",
            "intent": "check updates",
            "risk_level": "low",
            "actions": actions,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(directive),
            reasoning_content=None,
            raw={},
            latency_ms=1.0,
        )

    async def fake_stream_chat(self, role, messages):
        yield {"type": "content_delta", "delta": "updates checked"}

    monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat)
    monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat)
    client = TestClient(create_app())

    request_body = {"message": "check updates", "workspace": str(tmp_path), "debug": True}
    with client.stream("POST", "/v1/chat/stream", json=request_body) as response:
        # Drain the stream so the run finishes before the audit is queried.
        _ = "".join(response.iter_text())
    audit = client.get("/v1/audit/commands").json()

    assert response.status_code == 200
    latest = audit[0]
    assert latest["event_type"] == "command_audit"
    payload = latest["payload"]
    assert payload["command"] == "apt list --upgradable"
    assert payload["risk_level"] == "low"
    assert payload["action_type"] == "package_check"
    assert payload["approved"] is False
    assert "password" not in json.dumps(audit).lower()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
from duck_core.tools.file_read import FileReadTool
|
||||
from duck_core.tools.file_write import FileWriteTool
|
||||
from duck_core.tools.gateway import ToolGateway
|
||||
from duck_core.tools.command_policy import CommandPolicy
|
||||
from duck_core.tools.shell_exec_safe import ShellExecSafeTool
|
||||
|
||||
|
||||
|
|
@ -139,3 +140,22 @@ async def test_shell_tool_allows_read_only_apt_update_check(monkeypatch, tmp_pat
|
|||
assert result.ok is True
|
||||
assert "bootlogd" in result.output
|
||||
assert result.metadata["command"] == "apt list --upgradable"
|
||||
|
||||
|
||||
def test_command_policy_classifies_common_system_commands():
    """Check classification of a read-only, two sudo and one destructive command."""
    classify = CommandPolicy.classify

    # Read-only package query: low risk, no approval.
    readonly = classify("apt list --upgradable")
    assert readonly.risk_level == "low"
    assert readonly.action_type == "package_check"
    assert readonly.requires_approval is False

    # sudo-prefixed package cache refresh: high risk and password-gated.
    sudo_update = classify("sudo apt update")
    assert sudo_update.risk_level == "high"
    assert sudo_update.action_type == "package_cache_update"
    assert sudo_update.requires_password is True

    # sudo-prefixed install: high risk and password-gated.
    install = classify("sudo apt install vim")
    assert install.risk_level == "high"
    assert install.action_type == "package_install"
    assert install.requires_password is True

    # Recursive delete: blocked outright.
    destructive = classify("rm -rf .")
    assert destructive.risk_level == "critical"
    assert destructive.blocked is True
|
||||
|
|
|
|||
Loading…
Reference in New Issue