Add approval continuation and multi-step tools

This commit is contained in:
mirivlad 2026-05-20 03:34:13 +08:00
parent 2d3a047548
commit a4b7ef034a
13 changed files with 982 additions and 85 deletions

View File

@ -30,6 +30,10 @@ class ChatRequest(BaseModel):
debug: bool = False
class ContinueRequest(BaseModel):
approval_id: str
def create_app() -> FastAPI:
settings = get_settings()
if settings.api_host == "0.0.0.0":
@ -142,7 +146,7 @@ def create_app() -> FastAPI:
content_parts: list[str] = []
try:
messages = runtime.context_builder.build_basic_messages(task)
tool_observations = await runtime._run_action_tools(
tool_observations = await runtime._run_action_loop(
task.task_id, messages, body.workspace or settings.workspace
)
async for tool_event in emit_tool_events(task.task_id, task_event.sequence):
@ -283,6 +287,128 @@ def create_app() -> FastAPI:
await event_store.append(task_id, "task_continued", {})
return {"status": "running"}
@app.post("/v1/tasks/{task_id}/continue/stream")
async def continue_task_stream(task_id: str, body: ContinueRequest) -> StreamingResponse:
task = await task_store.get_task(task_id)
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
approval = await approvals.get(body.approval_id)
if approval is None or approval.task_id != task_id:
raise HTTPException(status_code=404, detail="Approval not found for task")
if approval.decision is None:
raise HTTPException(status_code=409, detail="Approval is still pending")
async def generator():
reasoning_parts: list[str] = []
content_parts: list[str] = []
try:
await task_store.update_status(task_id, "running")
continued_event = await event_store.append(
task_id,
"task_continued",
{"approval_id": body.approval_id, "decision": approval.decision},
)
tool_observation = await runtime._run_approved_or_denied_action(
task_id, approval.normalized_action, approval.decision
)
messages = runtime.context_builder.build_basic_messages(task)
tool_observations = [tool_observation]
if approval.decision != "deny":
tool_observations = await runtime._run_action_loop(
task_id,
messages,
task.workspace,
initial_observations=tool_observations,
)
async for tool_event in emit_tool_events(task_id, continued_event.sequence):
yield tool_event
if any(observation.get("requires_approval") for observation in tool_observations):
await task_store.waiting_for_approval(task_id)
await event_store.append(
task_id,
"task_waiting_for_approval",
{"observations": tool_observations},
)
yield sse(
"done",
{
"task_id": task_id,
"status": "waiting_for_approval",
"final_response": "Waiting for approval.",
"reasoning_content": None,
},
)
return
messages = [
*messages,
{
"role": "user",
"content": "tool_observations:\n"
+ json.dumps(tool_observations, ensure_ascii=False, indent=2),
},
]
await event_store.append(task_id, "model_call_started", {"role": "thinker"})
async for chunk in model_client.stream_chat("thinker", messages):
delta = str(chunk.get("delta") or "")
if chunk.get("type") == "reasoning_delta":
reasoning_parts.append(delta)
yield sse("reasoning_delta", {"task_id": task_id, "delta": delta})
elif chunk.get("type") == "content_delta":
content_parts.append(delta)
yield sse("content_delta", {"task_id": task_id, "delta": delta})
content = "".join(content_parts)
reasoning_content = "".join(reasoning_parts) or None
await event_store.append(
task_id,
"cognition_response",
{
"role": "thinker",
"content": content,
"reasoning_content": reasoning_content,
},
)
await event_store.append(
task_id,
"model_call_finished",
{
"role": "thinker",
"model": model_client.get_role_config("thinker").model,
},
)
await task_store.complete_task(task_id, content)
await event_store.append(
task_id,
"task_completed",
{
"final_response": content,
"reasoning_content": reasoning_content,
},
)
yield sse(
"done",
{
"task_id": task_id,
"status": "completed",
"final_response": content,
"reasoning_content": reasoning_content,
},
)
except Exception as exc:
await task_store.fail_task(task_id, str(exc))
await event_store.append(task_id, "task_failed", {"error": str(exc)})
yield sse(
"error",
{
"task_id": task_id,
"status": "failed",
"error": str(exc),
},
)
return StreamingResponse(generator(), media_type="text/event-stream")
@app.post("/v1/tasks/{task_id}/cancel")
async def cancel_task(task_id: str) -> dict[str, str]:
await task_store.cancel_task(task_id)

View File

@ -93,6 +93,16 @@ class ApprovalService:
rows = await cursor.fetchall()
return [self._row_to_approval(row) for row in rows]
async def get(self, approval_id: str) -> Approval | None:
await self.init()
async with aiosqlite.connect(self.db_path) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"select * from approvals where approval_id = ?", (approval_id,)
)
row = await cursor.fetchone()
return self._row_to_approval(row) if row else None
async def allow_once(self, approval_id: str) -> None:
await self._decide(approval_id, "resolved", "allow_once")

View File

@ -8,6 +8,7 @@ from duck_core.events.store import EventStore
from duck_core.model_client import ModelClient
from duck_core.tasks.store import TaskStore
from duck_core.tools.gateway import ToolGateway
from duck_core.tools.base import ToolResult
@dataclass
@ -26,12 +27,14 @@ class RuntimeLoop:
model_client: ModelClient | None = None,
context_builder: ContextBuilder | None = None,
approval_service: ApprovalService | None = None,
max_tool_iterations: int = 4,
):
self.task_store = task_store
self.event_store = event_store
self.model_client = model_client or ModelClient()
self.context_builder = context_builder or ContextBuilder()
self.approval_service = approval_service
self.max_tool_iterations = max_tool_iterations
async def run_chat(
self, message: str, workspace: str | None = None, debug: bool = False
@ -44,7 +47,7 @@ class RuntimeLoop:
)
try:
messages = self.context_builder.build_basic_messages(task)
tool_observations = await self._run_action_tools(task.task_id, messages, workspace)
tool_observations = await self._run_action_loop(task.task_id, messages, workspace)
if any(observation.get("requires_approval") for observation in tool_observations):
await self.task_store.waiting_for_approval(task.task_id)
await self.event_store.append(
@ -119,8 +122,172 @@ class RuntimeLoop:
reasoning_content=None,
)
async def continue_after_approval(self, task_id: str, approval_id: str) -> ChatResult:
if self.approval_service is None:
return ChatResult(
task_id=task_id,
status="failed",
final_response="Approval service is not configured.",
reasoning_content=None,
)
task = await self.task_store.get_task(task_id)
approval = await self.approval_service.get(approval_id)
if task is None:
return ChatResult(
task_id=task_id,
status="failed",
final_response="Task not found.",
reasoning_content=None,
)
if approval is None or approval.task_id != task_id:
return ChatResult(
task_id=task_id,
status="failed",
final_response="Approval not found for task.",
reasoning_content=None,
)
if approval.decision is None:
return ChatResult(
task_id=task_id,
status="waiting_for_approval",
final_response="Waiting for approval.",
reasoning_content=None,
)
await self.task_store.update_status(task_id, "running")
await self.event_store.append(
task_id,
"task_continued",
{"approval_id": approval_id, "decision": approval.decision},
)
try:
tool_observation = await self._run_approved_or_denied_action(
task_id, approval.normalized_action, approval.decision
)
messages = self.context_builder.build_basic_messages(task)
tool_observations = [tool_observation]
if approval.decision != "deny":
tool_observations = await self._run_action_loop(
task_id,
messages,
task.workspace,
initial_observations=tool_observations,
)
if any(observation.get("requires_approval") for observation in tool_observations):
await self.task_store.waiting_for_approval(task_id)
await self.event_store.append(
task_id,
"task_waiting_for_approval",
{"observations": tool_observations},
)
return ChatResult(
task_id=task_id,
status="waiting_for_approval",
final_response="Waiting for approval.",
reasoning_content=None,
)
messages = [
*messages,
{
"role": "user",
"content": "tool_observations:\n"
+ json.dumps(tool_observations, ensure_ascii=False, indent=2),
},
]
await self.event_store.append(task_id, "model_call_started", {"role": "thinker"})
response = await self.model_client.chat("thinker", messages)
await self.event_store.append(
task_id,
"cognition_response",
{
"role": response.role,
"content": response.content,
"reasoning_content": response.reasoning_content,
},
)
await self.event_store.append(
task_id,
"model_call_finished",
{
"role": response.role,
"model": response.model,
"latency_ms": response.latency_ms,
"prompt_tokens": response.prompt_tokens,
"completion_tokens": response.completion_tokens,
"total_tokens": response.total_tokens,
},
)
await self.task_store.complete_task(task_id, response.content)
await self.event_store.append(
task_id,
"task_completed",
{
"final_response": response.content,
"reasoning_content": response.reasoning_content,
},
)
return ChatResult(
task_id=task_id,
status="completed",
final_response=response.content,
reasoning_content=response.reasoning_content,
)
except Exception as exc:
await self.task_store.fail_task(task_id, str(exc))
await self.event_store.append(task_id, "task_failed", {"error": str(exc)})
return ChatResult(
task_id=task_id,
status="failed",
final_response=str(exc),
reasoning_content=None,
)
async def _run_approved_or_denied_action(
self, task_id: str, action: dict[str, Any], decision: str
) -> dict[str, Any]:
tool_name = str(action.get("tool", ""))
index = await self._approval_action_index(task_id, action)
if decision == "deny":
result = ToolResult(
ok=False,
error="Tool action denied by user.",
metadata={"decision": "deny"},
)
else:
gateway = ToolGateway.default((await self.task_store.get_task(task_id)).workspace or ".")
result = await gateway.run_action(action, approved=True)
result_payload = result.model_dump()
await self.event_store.append(
task_id,
"tool_call_finished",
{"index": index, "tool": tool_name, "result": result_payload},
)
return {
"index": index,
"tool": tool_name,
"reason": action.get("reason"),
"decision": decision,
"result": result_payload,
}
async def _approval_action_index(self, task_id: str, action: dict[str, Any]) -> int:
events = await self.event_store.list_events(task_id)
for event in reversed(events):
if (
event.event_type == "tool_approval_requested"
and event.payload.get("action") == action
):
return int(event.payload.get("index") or 1)
return 1
async def _run_action_tools(
self, task_id: str, messages: list[dict[str, str]], workspace: str | None
self,
task_id: str,
messages: list[dict[str, str]],
workspace: str | None,
start_index: int = 1,
) -> list[dict[str, Any]]:
try:
await self.event_store.append(task_id, "model_call_started", {"role": "action"})
@ -141,7 +308,7 @@ class RuntimeLoop:
gateway = ToolGateway.default(workspace or ".")
observations: list[dict[str, Any]] = []
for index, action in enumerate(actions, start=1):
for index, action in enumerate(actions, start=start_index):
if not isinstance(action, dict):
observations.append(
{"index": index, "ok": False, "error": "Action must be an object"}
@ -153,7 +320,11 @@ class RuntimeLoop:
"tool_call_started",
{"index": index, "tool": tool_name, "args": action.get("args") or {}},
)
result = await gateway.run_action(action)
approved_forever = (
self.approval_service is not None
and await self.approval_service.is_allowed_forever(action)
)
result = await gateway.run_action(action, approved=approved_forever)
result_payload = result.model_dump()
if result.metadata.get("requires_approval"):
approval = None
@ -195,3 +366,48 @@ class RuntimeLoop:
}
)
return observations
async def _run_action_loop(
self,
task_id: str,
messages: list[dict[str, str]],
workspace: str | None,
initial_observations: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]]:
all_observations = list(initial_observations or [])
current_messages = messages
if all_observations:
current_messages = [
*messages,
{
"role": "user",
"content": "tool_observations:\n"
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
},
]
for _ in range(self.max_tool_iterations):
observations = await self._run_action_tools(
task_id,
current_messages,
workspace,
start_index=len(all_observations) + 1,
)
if not observations:
return all_observations
all_observations.extend(observations)
if any(observation.get("requires_approval") for observation in observations):
return all_observations
current_messages = [
*messages,
{
"role": "user",
"content": "tool_observations:\n"
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
},
]
await self.event_store.append(
task_id,
"tool_iteration_limit_reached",
{"max_tool_iterations": self.max_tool_iterations},
)
return all_observations

View File

@ -3,6 +3,8 @@ from typing import Any
from duck_core.tools.base import Tool, ToolResult
from duck_core.tools.file_read import FileReadTool
from duck_core.tools.file_write import FileWriteTool
from duck_core.tools.list_dir import ListDirTool
from duck_core.tools.search_files import SearchFilesTool
from duck_core.tools.shell_exec_safe import ShellExecSafeTool
@ -16,11 +18,13 @@ class ToolGateway:
[
FileReadTool(workspace),
FileWriteTool(workspace),
ListDirTool(workspace),
SearchFilesTool(workspace),
ShellExecSafeTool(workspace),
]
)
async def run_action(self, action: dict[str, Any]) -> ToolResult:
async def run_action(self, action: dict[str, Any], approved: bool = False) -> ToolResult:
tool_name = str(action.get("tool", ""))
tool = self.tools.get(tool_name)
if tool is None:
@ -28,4 +32,6 @@ class ToolGateway:
args = action.get("args") or {}
if not isinstance(args, dict):
return ToolResult(ok=False, error="Tool args must be an object")
if approved:
args = {**args, "_approved": True}
return await tool.run(args)

View File

@ -0,0 +1,42 @@
from typing import Any
from duck_core.tools.base import ToolResult
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
class ListDirTool:
name = "list_dir"
risk_level = "low"
def __init__(self, workspace: str, max_entries: int = 200):
self.workspace = workspace
self.max_entries = max_entries
async def run(self, args: dict[str, Any]) -> ToolResult:
raw_path = str(args.get("path") or ".")
try:
path = resolve_workspace_path(self.workspace, raw_path)
except WorkspacePathError as exc:
return ToolResult(ok=False, error=str(exc))
if not path.exists():
return ToolResult(ok=False, error=f"Directory not found: {raw_path}")
if not path.is_dir():
return ToolResult(ok=False, error=f"Not a directory: {raw_path}")
entries = []
for child in sorted(path.iterdir(), key=lambda item: (not item.is_dir(), item.name.lower())):
suffix = "/" if child.is_dir() else ""
entries.append(f"{child.name}{suffix}")
if len(entries) >= self.max_entries:
break
truncated = len(entries) >= self.max_entries
return ToolResult(
ok=True,
output="\n".join(entries),
metadata={
"path": str(path),
"entries": len(entries),
"truncated": truncated,
},
)

View File

@ -0,0 +1,66 @@
from fnmatch import fnmatch
from typing import Any
from duck_core.tools.base import ToolResult
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
class SearchFilesTool:
name = "search_files"
risk_level = "low"
def __init__(self, workspace: str, max_matches: int = 100, max_file_bytes: int = 1_000_000):
self.workspace = workspace
self.max_matches = max_matches
self.max_file_bytes = max_file_bytes
async def run(self, args: dict[str, Any]) -> ToolResult:
query = str(args.get("query") or "")
raw_path = str(args.get("path") or ".")
pattern = str(args.get("glob") or "*")
case_sensitive = bool(args.get("case_sensitive", True))
max_matches = min(int(args.get("max_matches") or self.max_matches), self.max_matches)
if not query:
return ToolResult(ok=False, error="Search query is required")
try:
root = resolve_workspace_path(self.workspace, ".")
path = resolve_workspace_path(self.workspace, raw_path)
except WorkspacePathError as exc:
return ToolResult(ok=False, error=str(exc))
if not path.exists():
return ToolResult(ok=False, error=f"Search path not found: {raw_path}")
needle = query if case_sensitive else query.lower()
matches: list[str] = []
files_scanned = 0
candidates = [path] if path.is_file() else path.rglob("*")
for candidate in candidates:
if len(matches) >= max_matches:
break
if not candidate.is_file() or ".git" in candidate.parts:
continue
relative = candidate.relative_to(root).as_posix()
if not fnmatch(candidate.name, pattern) and not fnmatch(relative, pattern):
continue
if candidate.stat().st_size > self.max_file_bytes:
continue
files_scanned += 1
text = candidate.read_text(errors="replace")
for line_number, line in enumerate(text.splitlines(), start=1):
haystack = line if case_sensitive else line.lower()
if needle in haystack:
matches.append(f"{relative}:{line_number}:{line}")
if len(matches) >= max_matches:
break
return ToolResult(
ok=True,
output="\n".join(matches),
metadata={
"path": str(path),
"query": query,
"matches": len(matches),
"files_scanned": files_scanned,
"truncated": len(matches) >= max_matches,
},
)

View File

@ -57,7 +57,8 @@ class ShellExecSafeTool:
async def run(self, args: dict[str, Any]) -> ToolResult:
command = str(args.get("command", "")).strip()
allowed, reason = self._is_allowed(command)
approved = bool(args.get("_approved"))
allowed, reason = self._is_allowed(command, approved=approved)
if not allowed:
return ToolResult(ok=False, error=reason, metadata={"requires_approval": True})
try:
@ -79,13 +80,15 @@ class ShellExecSafeTool:
metadata={"returncode": completed.returncode, "command": command},
)
def _is_allowed(self, command: str) -> tuple[bool, str | None]:
def _is_allowed(self, command: str, approved: bool = False) -> tuple[bool, str | None]:
if not command:
return False, "Empty command"
lowered = command.lower()
for blocked in BLOCKLIST:
if lowered.startswith(blocked.lower()) or blocked.lower() in lowered:
return False, f"Command is blocked: {blocked}"
if approved:
return True, None
parts = shlex.split(command)
prefix1 = parts[0] if parts else ""
prefix2 = " ".join(parts[:2])

View File

@ -165,6 +165,7 @@ function appendApprovalTerminal(article, eventPayload) {
const command = formatToolCommand(action.tool || payload.tool, action.args || {});
terminal?.classList.add("is-waiting");
if (terminal && payload.approval_id) terminal.dataset.approvalId = payload.approval_id;
if (terminal && eventPayload.task_id) terminal.dataset.taskId = eventPayload.task_id;
if (status) status.textContent = "approval";
if (body) {
body.textContent = [
@ -207,6 +208,7 @@ function inlineApprovalButton(label, action, tone = "") {
async function resolveInlineApproval(button) {
const terminal = button.closest(".tool-terminal");
const approvalId = terminal?.dataset.approvalId;
const taskId = terminal?.dataset.taskId;
const action = button.dataset.inlineApprovalAction;
if (!terminal || !approvalId || !action) return;
@ -226,6 +228,9 @@ async function resolveInlineApproval(button) {
if (status) status.textContent = decision;
if (body) body.textContent = `${command}\n\n${decision}: ${humanApprovalDecision(action)}`;
actions?.remove();
if (taskId) {
await continueAfterInlineApproval(terminal.closest(".message"), taskId, approvalId);
}
}
function humanApprovalDecision(action) {
@ -320,8 +325,8 @@ function parseSseBlock(block) {
return {name: event.name, data: JSON.parse(event.data)};
}
async function streamChat(payload, onEvent) {
const response = await fetch("/v1/chat/stream", {
async function streamSse(url, payload, onEvent) {
const response = await fetch(url, {
method: "POST",
headers: {"Content-Type": "application/json"},
body: JSON.stringify(payload),
@ -350,31 +355,15 @@ async function streamChat(payload, onEvent) {
}
}
async function sendMessage() {
if (state.running) return;
const input = document.querySelector("#message");
const message = input.value.trim();
if (!message) return;
async function streamChat(payload, onEvent) {
await streamSse("/v1/chat/stream", payload, onEvent);
}
state.running = true;
document.querySelector("#run").disabled = true;
setStatus("#task-status", "running", "warn");
addMessage("user", message, "submitted");
input.value = "";
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
let taskId = "";
let contentStarted = false;
try {
await streamChat({
message,
workspace: document.querySelector("#workspace").value,
debug: document.querySelector("#debug").checked,
}, async ({name, data}) => {
if (data.task_id) taskId = data.task_id;
async function handleAssistantStreamEvent(pending, name, data, context) {
if (data.task_id) context.taskId = data.task_id;
if (name === "task_created") {
taskId = data.task_id;
setStatus("#task-status", taskId, "warn");
context.taskId = data.task_id;
setStatus("#task-status", data.task_id, "warn");
return;
}
if (name === "reasoning_delta") {
@ -398,8 +387,8 @@ async function sendMessage() {
return;
}
if (name === "content_delta") {
if (!contentStarted) {
contentStarted = true;
if (!context.contentStarted) {
context.contentStarted = true;
setMessagePending(pending, "");
}
pending.querySelector(".message-meta span").textContent = "answering";
@ -407,7 +396,7 @@ async function sendMessage() {
return;
}
if (name === "done") {
if (!contentStarted) {
if (!context.contentStarted) {
setMessagePending(pending, data.final_response || "No final content returned.");
}
pending.querySelector(".message-meta span").textContent = data.status;
@ -419,13 +408,60 @@ async function sendMessage() {
if (name === "error") {
throw new Error(data.error || "Stream failed.");
}
}
async function continueAfterInlineApproval(article, taskId, approvalId) {
if (!article || state.running) return;
state.running = true;
document.querySelector("#run").disabled = true;
setStatus("#task-status", taskId, "warn");
const context = {taskId, contentStarted: false};
try {
await streamSse(
`/v1/tasks/${encodeURIComponent(taskId)}/continue/stream`,
{approval_id: approvalId},
async ({name, data}) => handleAssistantStreamEvent(article, name, data, context),
);
} catch (error) {
setMessagePending(article, error.message);
article.querySelector(".message-meta span").textContent = "failed";
setStatus("#task-status", "failed", "bad");
await refreshEvents(taskId);
} finally {
state.running = false;
document.querySelector("#run").disabled = false;
document.querySelector("#message")?.focus();
}
}
async function sendMessage() {
if (state.running) return;
const input = document.querySelector("#message");
const message = input.value.trim();
if (!message) return;
state.running = true;
document.querySelector("#run").disabled = true;
setStatus("#task-status", "running", "warn");
addMessage("user", message, "submitted");
input.value = "";
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
const context = {taskId: "", contentStarted: false};
try {
await streamChat({
message,
workspace: document.querySelector("#workspace").value,
debug: document.querySelector("#debug").checked,
}, async ({name, data}) => {
await handleAssistantStreamEvent(pending, name, data, context);
});
} catch (error) {
if (!taskId) input.value = message;
if (!context.taskId) input.value = message;
setMessagePending(pending, error.message);
pending.querySelector(".message-meta span").textContent = "failed";
setStatus("#task-status", "failed", "bad");
if (taskId) await refreshEvents(taskId);
if (context.taskId) await refreshEvents(context.taskId);
} finally {
state.running = false;
document.querySelector("#run").disabled = false;

View File

@ -8,9 +8,16 @@ Available tools:
Args: {"path": "relative/path.txt"}
- file_write: write a file inside the current workspace.
Args: {"path": "relative/path.txt", "content": "text", "overwrite": false}
- list_dir: list direct children of a directory inside the current workspace.
Args: {"path": "."}
- search_files: search text inside files in the current workspace.
Args: {"query": "text to find", "path": ".", "glob": "*.py"}
- shell_exec_safe: run a safe allowlisted shell command in the current workspace.
Args: {"command": "pwd"}
Return actions=[] when the user can be answered directly without tools.
When tool_observations are already present, request only genuinely missing
follow-up information. Return actions=[] when the observations are sufficient
for the thinker to answer.
Use only the listed tools. Keep actions minimal and directly tied to the user's
request. Do not invent tool names.

View File

@ -1,5 +1,6 @@
from fastapi.testclient import TestClient
import json
import re
from duck_core.model_client import ModelResponse
@ -56,6 +57,16 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
assert role == "action"
if any("tool_observations" in message["content"] for message in messages):
actions = []
else:
actions = [
{
"tool": "file_read",
"args": {"path": "note.txt"},
"reason": "User asked for file contents",
}
]
return ModelResponse(
role=role,
model="local-main",
@ -64,13 +75,7 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
"kind": "action_directive",
"intent": "read requested file",
"risk_level": "low",
"actions": [
{
"tool": "file_read",
"args": {"path": "note.txt"},
"reason": "User asked for file contents",
}
],
"actions": actions,
}
),
reasoning_content=None,
@ -101,3 +106,70 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
assert "event: content_delta" in body
assert "answer from tool" in body
assert "event: done" in body
def test_continue_stream_executes_approved_tool_and_streams_answer(tmp_path, monkeypatch):
monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3"))
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
assert role == "action"
if any("tool_observations" in message["content"] for message in messages):
actions = []
else:
actions = [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"reason": "User asked for system information",
}
]
return ModelResponse(
role=role,
model="local-main",
content=json.dumps(
{
"kind": "action_directive",
"intent": "run command",
"risk_level": "medium",
"actions": actions,
}
),
reasoning_content=None,
raw={},
latency_ms=1.0,
)
async def fake_stream_chat(self, role, messages):
assert role == "thinker"
observation_message = next(message for message in messages if "tool_observations" in message["content"])
assert "uname" in observation_message["content"]
yield {"type": "content_delta", "delta": "continued after approval"}
monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat)
monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat)
client = TestClient(create_app())
with client.stream(
"POST",
"/v1/chat/stream",
json={"message": "run uname", "workspace": str(tmp_path), "debug": True},
) as response:
initial_body = "".join(response.iter_text())
task_id = re.search(r'"task_id"\s*:\s*"([^"]+)"', initial_body).group(1)
pending = client.get("/v1/approvals/pending").json()
approval = next(item for item in pending if item["task_id"] == task_id)
client.post(f"/v1/approvals/{approval['approval_id']}/allow_once")
with client.stream(
"POST",
f"/v1/tasks/{approval['task_id']}/continue/stream",
json={"approval_id": approval["approval_id"]},
) as response:
body = "".join(response.iter_text())
assert "event: tool_approval_requested" in initial_body
assert response.status_code == 200
assert "event: tool_call_finished" in body
assert "event: content_delta" in body
assert "continued after approval" in body
assert "event: done" in body

View File

@ -73,7 +73,7 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
"actions": [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"args": {"command": "hostname --pending-approval-test"},
"reason": "needs shell command",
}
],
@ -90,7 +90,11 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
response = client.post("/v1/chat", json={"message": "run uname", "debug": True})
approvals = client.get("/v1/approvals/pending").json()
approval = next(
item for item in approvals if item["task_id"] == response.json()["task_id"]
)
assert response.status_code == 200
assert response.json()["status"] == "waiting_for_approval"
assert approvals[0]["normalized_action"]["tool"] == "shell_exec_safe"
assert approval["normalized_action"]["tool"] == "shell_exec_safe"
assert approval["normalized_action"]["args"]["command"] == "hostname --pending-approval-test"

View File

@ -12,6 +12,16 @@ from duck_core.tasks.store import TaskStore
class FakeToolModelClient:
async def chat(self, role, messages):
if role == "action":
if any("tool_observations" in message["content"] for message in messages):
actions = []
else:
actions = [
{
"tool": "file_read",
"args": {"path": "note.txt"},
"reason": "User asked for file contents",
}
]
return ModelResponse(
role=role,
model="local-main",
@ -20,13 +30,7 @@ class FakeToolModelClient:
"kind": "action_directive",
"intent": "read requested file",
"risk_level": "low",
"actions": [
{
"tool": "file_read",
"args": {"path": "note.txt"},
"reason": "User asked for file contents",
}
],
"actions": actions,
}
),
reasoning_content=None,
@ -45,6 +49,58 @@ class FakeToolModelClient:
)
class FakeMultiStepToolModelClient:
async def chat(self, role, messages):
if role == "action":
observation_text = "\n".join(message["content"] for message in messages)
if "tool_observations" not in observation_text:
actions = [
{
"tool": "list_dir",
"args": {"path": "."},
"reason": "Find available files",
}
]
elif "README.md" in observation_text and "readme contents" not in observation_text:
actions = [
{
"tool": "file_read",
"args": {"path": "README.md"},
"reason": "Read discovered README",
}
]
else:
actions = []
return ModelResponse(
role=role,
model="local-main",
content=json.dumps(
{
"kind": "action_directive",
"intent": "multi-step file inspection",
"risk_level": "low",
"actions": actions,
}
),
reasoning_content=None,
raw={},
latency_ms=5.0,
)
assert role == "thinker"
observation_text = "\n".join(message["content"] for message in messages)
assert "list_dir" in observation_text
assert "file_read" in observation_text
assert "readme contents" in observation_text
return ModelResponse(
role=role,
model="local-main",
content="Readme inspected",
reasoning_content=None,
raw={},
latency_ms=12.0,
)
@pytest.mark.asyncio
async def test_runtime_executes_action_directive_tool_and_finishes_with_observation(tmp_path):
(tmp_path / "note.txt").write_text("hello from tool")
@ -67,9 +123,38 @@ async def test_runtime_executes_action_directive_tool_and_finishes_with_observat
assert tool_finished.payload["result"]["output"] == "hello from tool"
@pytest.mark.asyncio
async def test_runtime_runs_multiple_tool_steps_before_final_answer(tmp_path):
(tmp_path / "README.md").write_text("readme contents")
db_path = str(tmp_path / "duck.sqlite3")
task_store = TaskStore(db_path)
event_store = EventStore(db_path)
loop = RuntimeLoop(task_store, event_store, FakeMultiStepToolModelClient())
result = await loop.run_chat("inspect the workspace readme", str(tmp_path), debug=True)
events = await event_store.list_events(result.task_id)
finished_tools = [
event.payload["tool"] for event in events if event.event_type == "tool_call_finished"
]
assert result.status == "completed"
assert result.final_response == "Readme inspected"
assert finished_tools == ["list_dir", "file_read"]
class FakeApprovalModelClient:
async def chat(self, role, messages):
if role == "action":
if any("tool_observations" in message["content"] for message in messages):
actions = []
else:
actions = [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"reason": "User requested system information",
}
]
return ModelResponse(
role=role,
model="local-main",
@ -78,13 +163,7 @@ class FakeApprovalModelClient:
"kind": "action_directive",
"intent": "run command",
"risk_level": "medium",
"actions": [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"reason": "User requested system information",
}
],
"actions": actions,
}
),
reasoning_content=None,
@ -110,3 +189,197 @@ async def test_runtime_creates_pending_approval_when_tool_requires_it(tmp_path):
assert pending[0].task_id == result.task_id
assert pending[0].normalized_action["tool"] == "shell_exec_safe"
assert any(event.event_type == "tool_approval_requested" for event in events)
class FakeApprovalContinuationModelClient:
def __init__(self):
self.thinker_messages = []
async def chat(self, role, messages):
if role == "action":
if any("tool_observations" in message["content"] for message in messages):
actions = []
else:
actions = [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"reason": "User requested system information",
}
]
return ModelResponse(
role=role,
model="local-main",
content=json.dumps(
{
"kind": "action_directive",
"intent": "run command",
"risk_level": "medium",
"actions": actions,
}
),
reasoning_content=None,
raw={},
latency_ms=5.0,
)
assert role == "thinker"
self.thinker_messages = messages
assert any("tool_observations" in message["content"] for message in messages)
return ModelResponse(
role=role,
model="local-main",
content="uname completed",
reasoning_content="used approved shell command",
raw={},
latency_ms=10.0,
)
class FakeApprovalThenSecondToolModelClient:
async def chat(self, role, messages):
observation_text = "\n".join(message["content"] for message in messages)
if role == "action":
if "tool_observations" in observation_text and "second step content" not in observation_text:
actions = [
{
"tool": "file_read",
"args": {"path": "second.txt"},
"reason": "Read follow-up file after approved command",
}
]
elif "tool_observations" in observation_text:
actions = []
else:
actions = [
{
"tool": "shell_exec_safe",
"args": {"command": "uname -a"},
"reason": "User requested system information",
}
]
return ModelResponse(
role=role,
model="local-main",
content=json.dumps(
{
"kind": "action_directive",
"intent": "approval then follow-up",
"risk_level": "medium",
"actions": actions,
}
),
reasoning_content=None,
raw={},
latency_ms=5.0,
)
assert role == "thinker"
assert "shell_exec_safe" in observation_text
assert "file_read" in observation_text
assert "second step content" in observation_text
return ModelResponse(
role=role,
model="local-main",
content="approved command and second tool completed",
reasoning_content=None,
raw={},
latency_ms=10.0,
)
@pytest.mark.asyncio
async def test_runtime_continues_after_approved_tool_call(tmp_path):
db_path = str(tmp_path / "duck.sqlite3")
task_store = TaskStore(db_path)
event_store = EventStore(db_path)
approvals = ApprovalService(db_path)
model_client = FakeApprovalContinuationModelClient()
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
pending = await approvals.pending()
await approvals.allow_once(pending[0].approval_id)
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
events = await event_store.list_events(result.task_id)
finished = next(event for event in events if event.event_type == "tool_call_finished")
assert result.status == "completed"
assert result.final_response == "uname completed"
assert finished.payload["tool"] == "shell_exec_safe"
assert finished.payload["result"]["ok"] is True
assert "uname" in finished.payload["result"]["metadata"]["command"]
assert any(event.event_type == "task_completed" for event in events)
@pytest.mark.asyncio
async def test_runtime_can_run_followup_tool_after_approval(tmp_path):
(tmp_path / "second.txt").write_text("second step content")
db_path = str(tmp_path / "duck.sqlite3")
task_store = TaskStore(db_path)
event_store = EventStore(db_path)
approvals = ApprovalService(db_path)
loop = RuntimeLoop(
task_store,
event_store,
FakeApprovalThenSecondToolModelClient(),
approval_service=approvals,
)
pending_result = await loop.run_chat("run uname then inspect second file", str(tmp_path), debug=True)
pending = await approvals.pending()
await approvals.allow_once(pending[0].approval_id)
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
events = await event_store.list_events(result.task_id)
finished_tools = [
event.payload["tool"] for event in events if event.event_type == "tool_call_finished"
]
assert result.status == "completed"
assert finished_tools == ["shell_exec_safe", "file_read"]
@pytest.mark.asyncio
async def test_runtime_continues_after_denied_tool_call_without_execution(tmp_path):
db_path = str(tmp_path / "duck.sqlite3")
task_store = TaskStore(db_path)
event_store = EventStore(db_path)
approvals = ApprovalService(db_path)
model_client = FakeApprovalContinuationModelClient()
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
pending = await approvals.pending()
await approvals.deny(pending[0].approval_id)
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
events = await event_store.list_events(result.task_id)
finished = next(event for event in events if event.event_type == "tool_call_finished")
assert result.status == "completed"
assert finished.payload["result"]["ok"] is False
assert finished.payload["result"]["metadata"]["decision"] == "deny"
assert "denied" in finished.payload["result"]["error"].lower()
@pytest.mark.asyncio
async def test_runtime_reuses_allow_forever_for_matching_action(tmp_path):
db_path = str(tmp_path / "duck.sqlite3")
task_store = TaskStore(db_path)
event_store = EventStore(db_path)
approvals = ApprovalService(db_path)
model_client = FakeApprovalContinuationModelClient()
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
first_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
first_pending = await approvals.pending()
await approvals.allow_forever(first_pending[0].approval_id)
await loop.continue_after_approval(first_result.task_id, first_pending[0].approval_id)
second_result = await loop.run_chat("run uname again", str(tmp_path), debug=True)
second_events = await event_store.list_events(second_result.task_id)
assert second_result.status == "completed"
assert second_result.final_response == "uname completed"
assert not any(event.event_type == "tool_approval_requested" for event in second_events)
assert any(event.event_type == "tool_call_finished" for event in second_events)

View File

@ -40,3 +40,39 @@ async def test_tool_gateway_runs_allowed_directive(tmp_path):
assert result.ok is True
assert result.metadata["path"].endswith("a.txt")
@pytest.mark.asyncio
async def test_tool_gateway_lists_workspace_directory(tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "app.py").write_text("print('duck')")
(tmp_path / "README.md").write_text("hello")
gateway = ToolGateway.default(str(tmp_path))
result = await gateway.run_action({"tool": "list_dir", "args": {"path": "."}})
escaped = await gateway.run_action({"tool": "list_dir", "args": {"path": ".."}})
assert result.ok is True
assert "README.md" in result.output
assert "src/" in result.output
assert escaped.ok is False
@pytest.mark.asyncio
async def test_tool_gateway_searches_file_contents(tmp_path):
(tmp_path / "src").mkdir()
(tmp_path / "src" / "app.py").write_text("duck tool gateway\\n")
(tmp_path / "notes.txt").write_text("other content\\n")
gateway = ToolGateway.default(str(tmp_path))
result = await gateway.run_action(
{"tool": "search_files", "args": {"query": "duck tool", "path": "."}}
)
escaped = await gateway.run_action(
{"tool": "search_files", "args": {"query": "duck", "path": ".."}}
)
assert result.ok is True
assert "src/app.py:1:duck tool gateway" in result.output
assert result.metadata["matches"] == 1
assert escaped.ok is False