diff --git a/duck_core/api.py b/duck_core/api.py index 1abad91..41dac97 100644 --- a/duck_core/api.py +++ b/duck_core/api.py @@ -30,6 +30,10 @@ class ChatRequest(BaseModel): debug: bool = False +class ContinueRequest(BaseModel): + approval_id: str + + def create_app() -> FastAPI: settings = get_settings() if settings.api_host == "0.0.0.0": @@ -142,7 +146,7 @@ def create_app() -> FastAPI: content_parts: list[str] = [] try: messages = runtime.context_builder.build_basic_messages(task) - tool_observations = await runtime._run_action_tools( + tool_observations = await runtime._run_action_loop( task.task_id, messages, body.workspace or settings.workspace ) async for tool_event in emit_tool_events(task.task_id, task_event.sequence): @@ -283,6 +287,128 @@ def create_app() -> FastAPI: await event_store.append(task_id, "task_continued", {}) return {"status": "running"} + @app.post("/v1/tasks/{task_id}/continue/stream") + async def continue_task_stream(task_id: str, body: ContinueRequest) -> StreamingResponse: + task = await task_store.get_task(task_id) + if task is None: + raise HTTPException(status_code=404, detail="Task not found") + approval = await approvals.get(body.approval_id) + if approval is None or approval.task_id != task_id: + raise HTTPException(status_code=404, detail="Approval not found for task") + if approval.decision is None: + raise HTTPException(status_code=409, detail="Approval is still pending") + + async def generator(): + reasoning_parts: list[str] = [] + content_parts: list[str] = [] + try: + await task_store.update_status(task_id, "running") + continued_event = await event_store.append( + task_id, + "task_continued", + {"approval_id": body.approval_id, "decision": approval.decision}, + ) + tool_observation = await runtime._run_approved_or_denied_action( + task_id, approval.normalized_action, approval.decision + ) + messages = runtime.context_builder.build_basic_messages(task) + tool_observations = [tool_observation] + if approval.decision != "deny": + tool_observations = await runtime._run_action_loop( + task_id, + messages, + task.workspace, + initial_observations=tool_observations, + ) + async for tool_event in emit_tool_events(task_id, continued_event.sequence): + yield tool_event + if any(observation.get("requires_approval") for observation in tool_observations): + await task_store.waiting_for_approval(task_id) + await event_store.append( + task_id, + "task_waiting_for_approval", + {"observations": tool_observations}, + ) + yield sse( + "done", + { + "task_id": task_id, + "status": "waiting_for_approval", + "final_response": "Waiting for approval.", + "reasoning_content": None, + }, + ) + return + + messages = [ + *messages, + { + "role": "user", + "content": "tool_observations:\n" + + json.dumps(tool_observations, ensure_ascii=False, indent=2), + }, + ] + await event_store.append(task_id, "model_call_started", {"role": "thinker"}) + async for chunk in model_client.stream_chat("thinker", messages): + delta = str(chunk.get("delta") or "") + if chunk.get("type") == "reasoning_delta": + reasoning_parts.append(delta) + yield sse("reasoning_delta", {"task_id": task_id, "delta": delta}) + elif chunk.get("type") == "content_delta": + content_parts.append(delta) + yield sse("content_delta", {"task_id": task_id, "delta": delta}) + + content = "".join(content_parts) + reasoning_content = "".join(reasoning_parts) or None + await event_store.append( + task_id, + "cognition_response", + { + "role": "thinker", + "content": content, + "reasoning_content": reasoning_content, + }, + ) + await event_store.append( + task_id, + "model_call_finished", + { + "role": "thinker", + "model": model_client.get_role_config("thinker").model, + }, + ) + await task_store.complete_task(task_id, content) + await event_store.append( + task_id, + "task_completed", + { + "final_response": content, + "reasoning_content": reasoning_content, + }, + ) + yield sse( + "done", + { + "task_id": task_id, + "status": "completed", + "final_response": content, + "reasoning_content": reasoning_content, + }, + ) + except Exception as exc: + await task_store.fail_task(task_id, str(exc)) + await event_store.append(task_id, "task_failed", {"error": str(exc)}) + yield sse( + "error", + { + "task_id": task_id, + "status": "failed", + "error": str(exc), + }, + ) + + return StreamingResponse(generator(), media_type="text/event-stream") + @app.post("/v1/tasks/{task_id}/cancel") async def cancel_task(task_id: str) -> dict[str, str]: await task_store.cancel_task(task_id) diff --git a/duck_core/approvals/service.py b/duck_core/approvals/service.py index f62eecf..b720477 100644 --- a/duck_core/approvals/service.py +++ b/duck_core/approvals/service.py @@ -93,6 +93,16 @@ class ApprovalService: rows = await cursor.fetchall() return [self._row_to_approval(row) for row in rows] + async def get(self, approval_id: str) -> Approval | None: + await self.init() + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "select * from approvals where approval_id = ?", (approval_id,) + ) + row = await cursor.fetchone() + return self._row_to_approval(row) if row else None + async def allow_once(self, approval_id: str) -> None: await self._decide(approval_id, "resolved", "allow_once") diff --git a/duck_core/runtime_loop.py b/duck_core/runtime_loop.py index ed6b22c..072910a 100644 --- a/duck_core/runtime_loop.py +++ b/duck_core/runtime_loop.py @@ -8,6 +8,7 @@ from duck_core.events.store import EventStore from duck_core.model_client import ModelClient from duck_core.tasks.store import TaskStore from duck_core.tools.gateway import ToolGateway +from duck_core.tools.base import ToolResult @dataclass @@ -26,12 +27,14 @@ class RuntimeLoop: model_client: ModelClient | None = None, context_builder: ContextBuilder | None = None, approval_service: ApprovalService | None = None, + max_tool_iterations: int = 4, ): self.task_store = task_store self.event_store = event_store self.model_client = model_client or ModelClient() self.context_builder = context_builder or ContextBuilder() self.approval_service = approval_service + self.max_tool_iterations = max_tool_iterations async def run_chat( self, message: str, workspace: str | None = None, debug: bool = False @@ -44,7 +47,7 @@ class RuntimeLoop: ) try: messages = self.context_builder.build_basic_messages(task) - tool_observations = await self._run_action_tools(task.task_id, messages, workspace) + tool_observations = await self._run_action_loop(task.task_id, messages, workspace) if any(observation.get("requires_approval") for observation in tool_observations): await self.task_store.waiting_for_approval(task.task_id) await self.event_store.append( @@ -119,8 +122,172 @@ class RuntimeLoop: reasoning_content=None, ) + async def continue_after_approval(self, task_id: str, approval_id: str) -> ChatResult: + if self.approval_service is None: + return ChatResult( + task_id=task_id, + status="failed", + final_response="Approval service is not configured.", + reasoning_content=None, + ) + + task = await self.task_store.get_task(task_id) + approval = await self.approval_service.get(approval_id) + if task is None: + return ChatResult( + task_id=task_id, + status="failed", + final_response="Task not found.", + reasoning_content=None, + ) + if approval is None or approval.task_id != task_id: + return ChatResult( + task_id=task_id, + status="failed", + final_response="Approval not found for task.", + reasoning_content=None, + ) + if approval.decision is None: + return ChatResult( + task_id=task_id, + status="waiting_for_approval", + final_response="Waiting for approval.", + reasoning_content=None, + ) + + await self.task_store.update_status(task_id, "running") + await self.event_store.append( + task_id, + "task_continued", + {"approval_id": approval_id, "decision": approval.decision}, + ) + try: + tool_observation = await self._run_approved_or_denied_action( + task_id, approval.normalized_action, approval.decision + ) + messages = self.context_builder.build_basic_messages(task) + tool_observations = [tool_observation] + if approval.decision != "deny": + tool_observations = await self._run_action_loop( + task_id, + messages, + task.workspace, + initial_observations=tool_observations, + ) + if any(observation.get("requires_approval") for observation in tool_observations): + await self.task_store.waiting_for_approval(task_id) + await self.event_store.append( + task_id, + "task_waiting_for_approval", + {"observations": tool_observations}, + ) + return ChatResult( + task_id=task_id, + status="waiting_for_approval", + final_response="Waiting for approval.", + reasoning_content=None, + ) + messages = [ + *messages, + { + "role": "user", + "content": "tool_observations:\n" + + json.dumps(tool_observations, ensure_ascii=False, indent=2), + }, + ] + await self.event_store.append(task_id, "model_call_started", {"role": "thinker"}) + response = await self.model_client.chat("thinker", messages) + await self.event_store.append( + task_id, + "cognition_response", + { + "role": response.role, + "content": response.content, + "reasoning_content": response.reasoning_content, + }, + ) + await self.event_store.append( + task_id, + "model_call_finished", + { + "role": response.role, + "model": response.model, + "latency_ms": response.latency_ms, + "prompt_tokens": response.prompt_tokens, + "completion_tokens": response.completion_tokens, + "total_tokens": response.total_tokens, + }, + ) + await self.task_store.complete_task(task_id, response.content) + await self.event_store.append( + task_id, + "task_completed", + { + "final_response": response.content, + "reasoning_content": response.reasoning_content, + }, + ) + return ChatResult( + task_id=task_id, + status="completed", + final_response=response.content, + reasoning_content=response.reasoning_content, + ) + except Exception as exc: + await self.task_store.fail_task(task_id, str(exc)) + await self.event_store.append(task_id, "task_failed", {"error": str(exc)}) + return ChatResult( + task_id=task_id, + status="failed", + final_response=str(exc), + reasoning_content=None, + ) + + async def _run_approved_or_denied_action( + self, task_id: str, action: dict[str, Any], decision: str + ) -> dict[str, Any]: + tool_name = str(action.get("tool", "")) + index = await self._approval_action_index(task_id, action) + if decision == "deny": + result = ToolResult( + ok=False, + error="Tool action denied by user.", + metadata={"decision": "deny"}, + ) + else: + gateway = ToolGateway.default((await self.task_store.get_task(task_id)).workspace or ".") + result = await gateway.run_action(action, approved=True) + + result_payload = result.model_dump() + await self.event_store.append( + task_id, + "tool_call_finished", + {"index": index, "tool": tool_name, "result": result_payload}, + ) + return { + "index": index, + "tool": tool_name, + "reason": action.get("reason"), + "decision": decision, + "result": result_payload, + } + + async def _approval_action_index(self, task_id: str, action: dict[str, Any]) -> int: + events = await self.event_store.list_events(task_id) + for event in reversed(events): + if ( + event.event_type == "tool_approval_requested" + and event.payload.get("action") == action + ): + return int(event.payload.get("index") or 1) + return 1 + async def _run_action_tools( - self, task_id: str, messages: list[dict[str, str]], workspace: str | None + self, + task_id: str, + messages: list[dict[str, str]], + workspace: str | None, + start_index: int = 1, ) -> list[dict[str, Any]]: try: await self.event_store.append(task_id, "model_call_started", {"role": "action"}) @@ -141,7 +308,7 @@ class RuntimeLoop: gateway = ToolGateway.default(workspace or ".") observations: list[dict[str, Any]] = [] - for index, action in enumerate(actions, start=1): + for index, action in enumerate(actions, start=start_index): if not isinstance(action, dict): observations.append( {"index": index, "ok": False, "error": "Action must be an object"} @@ -153,7 +320,11 @@ class RuntimeLoop: "tool_call_started", {"index": index, "tool": tool_name, "args": action.get("args") or {}}, ) - result = await gateway.run_action(action) + approved_forever = ( + self.approval_service is not None + and await self.approval_service.is_allowed_forever(action) + ) + result = await gateway.run_action(action, approved=approved_forever) result_payload = result.model_dump() if result.metadata.get("requires_approval"): approval = None @@ -195,3 +366,48 @@ class RuntimeLoop: } ) return observations + + async def _run_action_loop( + self, + task_id: str, + messages: list[dict[str, str]], + workspace: str | None, + initial_observations: list[dict[str, Any]] | None = None, + ) -> list[dict[str, Any]]: + all_observations = list(initial_observations or []) + current_messages = messages + if all_observations: + current_messages = [ + *messages, + { + "role": "user", + "content": "tool_observations:\n" + + json.dumps(all_observations, ensure_ascii=False, indent=2), + }, + ] + for _ in range(self.max_tool_iterations): + observations = await self._run_action_tools( + task_id, + current_messages, + workspace, + start_index=len(all_observations) + 1, + ) + if not observations: + return all_observations + all_observations.extend(observations) + if any(observation.get("requires_approval") for observation in observations): + return all_observations + current_messages = [ + *messages, + { + "role": "user", + "content": "tool_observations:\n" + + json.dumps(all_observations, ensure_ascii=False, indent=2), + }, + ] + await self.event_store.append( + task_id, + "tool_iteration_limit_reached", + {"max_tool_iterations": self.max_tool_iterations}, + ) + return all_observations diff --git a/duck_core/tools/gateway.py b/duck_core/tools/gateway.py index b9df257..b4d0a09 100644 --- a/duck_core/tools/gateway.py +++ b/duck_core/tools/gateway.py @@ -3,6 +3,8 @@ from typing import Any from duck_core.tools.base import Tool, ToolResult from duck_core.tools.file_read import FileReadTool from duck_core.tools.file_write import FileWriteTool +from duck_core.tools.list_dir import ListDirTool +from duck_core.tools.search_files import SearchFilesTool from duck_core.tools.shell_exec_safe import ShellExecSafeTool @@ -16,11 +18,13 @@ class ToolGateway: [ FileReadTool(workspace), FileWriteTool(workspace), + ListDirTool(workspace), + SearchFilesTool(workspace), ShellExecSafeTool(workspace), ] ) - async def run_action(self, action: dict[str, Any]) -> ToolResult: + async def run_action(self, action: dict[str, Any], approved: bool = False) -> ToolResult: tool_name = str(action.get("tool", "")) tool = self.tools.get(tool_name) if tool is None: @@ -28,4 +32,6 @@ class ToolGateway: args = action.get("args") or {} if not isinstance(args, dict): return ToolResult(ok=False, error="Tool args must be an object") + if approved: + args = {**args, "_approved": True} return await tool.run(args) diff --git a/duck_core/tools/list_dir.py b/duck_core/tools/list_dir.py new file mode 100644 index 0000000..db9b504 --- /dev/null +++ b/duck_core/tools/list_dir.py @@ -0,0 +1,42 @@ +from typing import Any + +from duck_core.tools.base import ToolResult +from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path + + +class ListDirTool: + name = "list_dir" + risk_level = "low" + + def __init__(self, workspace: str, max_entries: int = 200): + self.workspace = workspace + self.max_entries = max_entries + + async def run(self, args: dict[str, Any]) -> ToolResult: + raw_path = str(args.get("path") or ".") + try: + path = resolve_workspace_path(self.workspace, raw_path) + except WorkspacePathError as exc: + return ToolResult(ok=False, error=str(exc)) + if not path.exists(): + return ToolResult(ok=False, error=f"Directory not found: {raw_path}") + if not path.is_dir(): + return ToolResult(ok=False, error=f"Not a directory: {raw_path}") + + entries = [] + for child in sorted(path.iterdir(), key=lambda item: (not item.is_dir(), item.name.lower())): + suffix = "/" if child.is_dir() else "" + entries.append(f"{child.name}{suffix}") + if len(entries) >= self.max_entries: + break + + truncated = len(entries) >= self.max_entries + return ToolResult( + ok=True, + output="\n".join(entries), + metadata={ + "path": str(path), + "entries": len(entries), + "truncated": truncated, + }, + ) diff --git a/duck_core/tools/search_files.py b/duck_core/tools/search_files.py new file mode 100644 index 0000000..f3daaed --- /dev/null +++ b/duck_core/tools/search_files.py @@ -0,0 +1,66 @@ +from fnmatch import fnmatch +from typing import Any + +from duck_core.tools.base import ToolResult +from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path + + +class SearchFilesTool: + name = "search_files" + risk_level = "low" + + def __init__(self, workspace: str, max_matches: int = 100, max_file_bytes: int = 1_000_000): + self.workspace = workspace + self.max_matches = max_matches + self.max_file_bytes = max_file_bytes + + async def run(self, args: dict[str, Any]) -> ToolResult: + query = str(args.get("query") or "") + raw_path = str(args.get("path") or ".") + pattern = str(args.get("glob") or "*") + case_sensitive = bool(args.get("case_sensitive", True)) + max_matches = min(int(args.get("max_matches") or self.max_matches), self.max_matches) + if not query: + return ToolResult(ok=False, error="Search query is required") + try: + root = resolve_workspace_path(self.workspace, ".") + path = resolve_workspace_path(self.workspace, raw_path) + except WorkspacePathError as exc: + return ToolResult(ok=False, error=str(exc)) + if not path.exists(): + return ToolResult(ok=False, error=f"Search path not found: {raw_path}") + + needle = query if case_sensitive else query.lower() + matches: list[str] = [] + files_scanned = 0 + candidates = [path] if path.is_file() else path.rglob("*") + for candidate in candidates: + if len(matches) >= max_matches: + break + if not candidate.is_file() or ".git" in candidate.parts: + continue + relative = candidate.relative_to(root).as_posix() + if not fnmatch(candidate.name, pattern) and not fnmatch(relative, pattern): + continue + if candidate.stat().st_size > self.max_file_bytes: + continue + files_scanned += 1 + text = candidate.read_text(errors="replace") + for line_number, line in enumerate(text.splitlines(), start=1): + haystack = line if case_sensitive else line.lower() + if needle in haystack: + matches.append(f"{relative}:{line_number}:{line}") + if len(matches) >= max_matches: + break + + return ToolResult( + ok=True, + output="\n".join(matches), + metadata={ + "path": str(path), + "query": query, + "matches": len(matches), + "files_scanned": files_scanned, + "truncated": len(matches) >= max_matches, + }, + ) diff --git a/duck_core/tools/shell_exec_safe.py b/duck_core/tools/shell_exec_safe.py index a015545..daeb9d4 100644 --- a/duck_core/tools/shell_exec_safe.py +++ b/duck_core/tools/shell_exec_safe.py @@ -57,7 +57,8 @@ class ShellExecSafeTool: async def run(self, args: dict[str, Any]) -> ToolResult: command = str(args.get("command", "")).strip() - allowed, reason = self._is_allowed(command) + approved = bool(args.get("_approved")) + allowed, reason = self._is_allowed(command, approved=approved) if not allowed: return ToolResult(ok=False, error=reason, metadata={"requires_approval": True}) try: @@ -79,13 +80,15 @@ class ShellExecSafeTool: metadata={"returncode": completed.returncode, "command": command}, ) - def _is_allowed(self, command: str) -> tuple[bool, str | None]: + def _is_allowed(self, command: str, approved: bool = False) -> tuple[bool, str | None]: if not command: return False, "Empty command" lowered = command.lower() for blocked in BLOCKLIST: if lowered.startswith(blocked.lower()) or blocked.lower() in lowered: return False, f"Command is blocked: {blocked}" + if approved: + return True, None parts = shlex.split(command) prefix1 = parts[0] if parts else "" prefix2 = " ".join(parts[:2]) diff --git a/duck_core/web/static/app.js b/duck_core/web/static/app.js index 0ab1f13..7535f5a 100644 --- a/duck_core/web/static/app.js +++ b/duck_core/web/static/app.js @@ -165,6 +165,7 @@ function appendApprovalTerminal(article, eventPayload) { const command = formatToolCommand(action.tool || payload.tool, action.args || {}); terminal?.classList.add("is-waiting"); if (terminal && payload.approval_id) terminal.dataset.approvalId = payload.approval_id; + if (terminal && eventPayload.task_id) terminal.dataset.taskId = eventPayload.task_id; if (status) status.textContent = "approval"; if (body) { body.textContent = [ @@ -207,6 +208,7 @@ function inlineApprovalButton(label, action, tone = "") { async function resolveInlineApproval(button) { const terminal = button.closest(".tool-terminal"); const approvalId = terminal?.dataset.approvalId; + const taskId = terminal?.dataset.taskId; const action = button.dataset.inlineApprovalAction; if (!terminal || !approvalId || !action) return; @@ -226,6 +228,9 @@ async function resolveInlineApproval(button) { if (status) status.textContent = decision; if (body) body.textContent = `${command}\n\n${decision}: ${humanApprovalDecision(action)}`; actions?.remove(); + if (taskId) { + await continueAfterInlineApproval(terminal.closest(".message"), taskId, approvalId); + } } function humanApprovalDecision(action) { @@ -320,8 +325,8 @@ function parseSseBlock(block) { return {name: event.name, data: JSON.parse(event.data)}; } -async function streamChat(payload, onEvent) { - const response = await fetch("/v1/chat/stream", { +async function streamSse(url, payload, onEvent) { + const response = await fetch(url, { method: "POST", headers: {"Content-Type": "application/json"}, body: JSON.stringify(payload), @@ -350,6 +355,85 @@ async function streamChat(payload, onEvent) { } } +async function streamChat(payload, onEvent) { + await streamSse("/v1/chat/stream", payload, onEvent); +} + +async function handleAssistantStreamEvent(pending, name, data, context) { + if (data.task_id) context.taskId = data.task_id; + if (name === "task_created") { + context.taskId = data.task_id; + setStatus("#task-status", data.task_id, "warn"); + return; + } + if (name === "reasoning_delta") { + pending.querySelector(".message-meta span").textContent = "reasoning"; + appendInlineReasoning(pending, data.delta || ""); + return; + } + if (name === "tool_call_started") { + pending.querySelector(".message-meta span").textContent = "tool"; + appendToolTerminal(pending, data); + return; + } + if (name === "tool_call_finished") { + pending.querySelector(".message-meta span").textContent = "tool"; + updateToolTerminal(pending, data); + return; + } + if (name === "tool_approval_requested") { + pending.querySelector(".message-meta span").textContent = "approval"; + appendApprovalTerminal(pending, data); + return; + } + if (name === "content_delta") { + if (!context.contentStarted) { + context.contentStarted = true; + setMessagePending(pending, ""); + } + pending.querySelector(".message-meta span").textContent = "answering"; + appendMessageText(pending, data.delta || ""); + return; + } + if (name === "done") { + if (!context.contentStarted) { + setMessagePending(pending, data.final_response || "No final content returned."); + } + pending.querySelector(".message-meta span").textContent = data.status; + setStatus("#task-status", data.task_id, data.status === "completed" ? "ok" : "warn"); + finishInlineReasoning(pending, data.reasoning_content); + await refreshEvents(data.task_id); + return; + } + if (name === "error") { + throw new Error(data.error || "Stream failed."); + } +} + +async function continueAfterInlineApproval(article, taskId, approvalId) { + if (!article || state.running) return; + state.running = true; + document.querySelector("#run").disabled = true; + setStatus("#task-status", taskId, "warn"); + const context = {taskId, contentStarted: false}; + try { + await streamSse( + `/v1/tasks/${encodeURIComponent(taskId)}/continue/stream`, + {approval_id: approvalId}, + async ({name, data}) => handleAssistantStreamEvent(article, name, data, context), + ); + } catch (error) { + setMessagePending(article, error.message); + article.querySelector(".message-meta span").textContent = "failed"; + setStatus("#task-status", "failed", "bad"); + await refreshEvents(taskId); + } finally { + state.running = false; + document.querySelector("#run").disabled = false; + document.querySelector("#message")?.focus(); + } +} + async function sendMessage() { if (state.running) return; const input = document.querySelector("#message"); @@ -362,8 +446,7 @@ async function sendMessage() { addMessage("user", message, "submitted"); input.value = ""; const pending = addMessage("assistant", "", "thinking", {reasoning: true}); - let taskId = ""; - let contentStarted = false; + const context = {taskId: "", contentStarted: false}; try { await streamChat({ @@ -371,61 +454,14 @@ async function sendMessage() { workspace: document.querySelector("#workspace").value, debug: document.querySelector("#debug").checked, }, async ({name, data}) => { - if (data.task_id) taskId = data.task_id; - if (name === "task_created") { - taskId = data.task_id; - setStatus("#task-status", taskId, "warn"); - return; - } - if (name === "reasoning_delta") { - pending.querySelector(".message-meta span").textContent = "reasoning"; - appendInlineReasoning(pending, data.delta || ""); - return; - } - if (name === "tool_call_started") { - pending.querySelector(".message-meta span").textContent = "tool"; - appendToolTerminal(pending, data); - return; - } - if (name === "tool_call_finished") { - pending.querySelector(".message-meta span").textContent = "tool"; - updateToolTerminal(pending, data); - return; - } - if (name === "tool_approval_requested") { - pending.querySelector(".message-meta span").textContent = "approval"; - appendApprovalTerminal(pending, data); - return; - } - if (name === "content_delta") { - if (!contentStarted) { - contentStarted = true; - setMessagePending(pending, ""); - } - pending.querySelector(".message-meta span").textContent = "answering"; - appendMessageText(pending, data.delta || ""); - return; - } - if (name === "done") { - if (!contentStarted) { - setMessagePending(pending, data.final_response || "No final content returned."); - } - pending.querySelector(".message-meta span").textContent = data.status; - setStatus("#task-status", data.task_id, data.status === "completed" ? "ok" : "warn"); - finishInlineReasoning(pending, data.reasoning_content); - await refreshEvents(data.task_id); - return; - } - if (name === "error") { - throw new Error(data.error || "Stream failed."); - } + await handleAssistantStreamEvent(pending, name, data, context); }); } catch (error) { - if (!taskId) input.value = message; + if (!context.taskId) input.value = message; setMessagePending(pending, error.message); pending.querySelector(".message-meta span").textContent = "failed"; setStatus("#task-status", "failed", "bad"); - if (taskId) await refreshEvents(taskId); + if (context.taskId) await refreshEvents(context.taskId); } finally { state.running = false; document.querySelector("#run").disabled = false; diff --git a/prompts/roles/action.md b/prompts/roles/action.md index 41cfab3..33df5ac 100644 --- a/prompts/roles/action.md +++ b/prompts/roles/action.md @@ -8,9 +8,16 @@ Available tools: Args: {"path": "relative/path.txt"} - file_write: write a file inside the current workspace. Args: {"path": "relative/path.txt", "content": "text", "overwrite": false} +- list_dir: list direct children of a directory inside the current workspace. + Args: {"path": "."} +- search_files: search text inside files in the current workspace. + Args: {"query": "text to find", "path": ".", "glob": "*.py"} - shell_exec_safe: run a safe allowlisted shell command in the current workspace. Args: {"command": "pwd"} Return actions=[] when the user can be answered directly without tools. +When tool_observations are already present, request only genuinely missing +follow-up information. Return actions=[] when the observations are sufficient +for the thinker to answer. Use only the listed tools. Keep actions minimal and directly tied to the user's request. Do not invent tool names. diff --git a/tests/smoke/test_api_stream_chat.py b/tests/smoke/test_api_stream_chat.py index 45d3a11..fb0ae95 100644 --- a/tests/smoke/test_api_stream_chat.py +++ b/tests/smoke/test_api_stream_chat.py @@ -1,5 +1,6 @@ from fastapi.testclient import TestClient import json +import re from duck_core.model_client import ModelResponse @@ -56,6 +57,16 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None): assert role == "action" + if any("tool_observations" in message["content"] for message in messages): + actions = [] + else: + actions = [ + { + "tool": "file_read", + "args": {"path": "note.txt"}, + "reason": "User asked for file contents", + } + ] return ModelResponse( role=role, model="local-main", @@ -64,13 +75,7 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo "kind": "action_directive", "intent": "read requested file", "risk_level": "low", - "actions": [ - { - "tool": "file_read", - "args": {"path": "note.txt"}, - "reason": "User asked for file contents", - } - ], + "actions": actions, } ), reasoning_content=None, @@ -101,3 +106,70 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo assert "event: content_delta" in body assert "answer from tool" in body assert "event: done" in body + + +def test_continue_stream_executes_approved_tool_and_streams_answer(tmp_path, monkeypatch): + monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3")) + + async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None): + assert role == "action" + if any("tool_observations" in message["content"] for message in messages): + actions = [] + else: + actions = [ + { + "tool": "shell_exec_safe", + "args": {"command": "uname -a"}, + "reason": "User asked for system information", + } + ] + return ModelResponse( + role=role, + model="local-main", + content=json.dumps( + { + "kind": "action_directive", + "intent": "run command", + "risk_level": "medium", + "actions": actions, + } + ), + reasoning_content=None, + raw={}, + latency_ms=1.0, + ) + + async def fake_stream_chat(self, role, messages): + assert role == "thinker" + observation_message = next(message for message in messages if "tool_observations" in message["content"]) + assert "uname" in observation_message["content"] + yield {"type": "content_delta", "delta": "continued after approval"} + + monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat) + monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat) + client = TestClient(create_app()) + + with client.stream( + "POST", + "/v1/chat/stream", + json={"message": "run uname", "workspace": str(tmp_path), "debug": True}, + ) as response: + initial_body = "".join(response.iter_text()) + task_id = re.search(r'"task_id"\s*:\s*"([^"]+)"', initial_body).group(1) + pending = client.get("/v1/approvals/pending").json() + approval = next(item for item in pending if item["task_id"] == task_id) + client.post(f"/v1/approvals/{approval['approval_id']}/allow_once") + + with client.stream( + "POST", + f"/v1/tasks/{approval['task_id']}/continue/stream", + json={"approval_id": approval["approval_id"]}, + ) as response: + body = "".join(response.iter_text()) + + assert "event: tool_approval_requested" in initial_body + assert response.status_code == 200 + assert "event: tool_call_finished" in body + assert "event: content_delta" in body + assert "continued after approval" in body + assert "event: done" in body diff --git a/tests/smoke/test_chat_api.py b/tests/smoke/test_chat_api.py index edc0b85..3630b16 100644 --- a/tests/smoke/test_chat_api.py +++ b/tests/smoke/test_chat_api.py @@ -73,7 +73,7 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk "actions": [ { "tool": "shell_exec_safe", - "args": {"command": "uname -a"}, + "args": {"command": "hostname --pending-approval-test"}, "reason": "needs shell command", } ], @@ -90,7 +90,11 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk response = client.post("/v1/chat", json={"message": "run uname", "debug": True}) approvals = client.get("/v1/approvals/pending").json() + approval = next( + item for item in approvals if item["task_id"] == response.json()["task_id"] + ) assert response.status_code == 200 assert response.json()["status"] == "waiting_for_approval" - assert approvals[0]["normalized_action"]["tool"] == "shell_exec_safe" + assert approval["normalized_action"]["tool"] == "shell_exec_safe" + assert approval["normalized_action"]["args"]["command"] == "hostname --pending-approval-test" diff --git a/tests/smoke/test_runtime_tools.py b/tests/smoke/test_runtime_tools.py index 9eca0c2..bd5b7d7 100644 --- a/tests/smoke/test_runtime_tools.py +++ b/tests/smoke/test_runtime_tools.py @@ -12,6 +12,16 @@ from duck_core.tasks.store import TaskStore class FakeToolModelClient: async def chat(self, role, messages): if role == "action": + if any("tool_observations" in message["content"] for message in messages): + actions = [] + else: + actions = [ + { + "tool": "file_read", + "args": {"path": "note.txt"}, + "reason": "User asked for file contents", + } + ] return ModelResponse( role=role, model="local-main", @@ -20,13 +30,7 @@ class FakeToolModelClient: "kind": "action_directive", "intent": "read requested file", "risk_level": "low", - "actions": [ - { - "tool": "file_read", - "args": {"path": "note.txt"}, - "reason": "User asked for file contents", - } - ], + "actions": actions, } ), reasoning_content=None, @@ -45,6 +49,58 @@ class FakeToolModelClient: ) +class FakeMultiStepToolModelClient: + async def chat(self, role, messages): + if role == "action": + observation_text = "\n".join(message["content"] for message in messages) + if "tool_observations" not in observation_text: + actions = [ + { + "tool": "list_dir", + "args": {"path": "."}, + "reason": "Find available files", + } + ] + elif "README.md" in observation_text and "readme contents" not in observation_text: + actions = [ + { + "tool": "file_read", + "args": {"path": "README.md"}, + "reason": "Read discovered README", + } + ] + else: + actions = [] + return ModelResponse( + role=role, + model="local-main", + content=json.dumps( + { + "kind": "action_directive", + "intent": "multi-step file inspection", + "risk_level": "low", + "actions": actions, + } + ), + reasoning_content=None, + raw={}, + latency_ms=5.0, + ) + assert role == "thinker" + observation_text = "\n".join(message["content"] for message in messages) + assert "list_dir" in observation_text + assert "file_read" in observation_text + assert "readme contents" in observation_text + return ModelResponse( + role=role, + model="local-main", + content="Readme inspected", + reasoning_content=None, + raw={}, + latency_ms=12.0, + ) + + @pytest.mark.asyncio async def test_runtime_executes_action_directive_tool_and_finishes_with_observation(tmp_path): (tmp_path / "note.txt").write_text("hello from tool") @@ -67,9 +123,38 @@ async def test_runtime_executes_action_directive_tool_and_finishes_with_observat assert tool_finished.payload["result"]["output"] == "hello from tool" +@pytest.mark.asyncio +async def test_runtime_runs_multiple_tool_steps_before_final_answer(tmp_path): + (tmp_path / "README.md").write_text("readme contents") + db_path = str(tmp_path / "duck.sqlite3") + task_store = TaskStore(db_path) + event_store = EventStore(db_path) + loop = RuntimeLoop(task_store, event_store, FakeMultiStepToolModelClient()) + + result = await loop.run_chat("inspect the workspace readme", str(tmp_path), debug=True) + events = await event_store.list_events(result.task_id) + finished_tools = [ + event.payload["tool"] for event in events if event.event_type == "tool_call_finished" + ] + + assert result.status == "completed" + assert result.final_response == "Readme inspected" + assert finished_tools == ["list_dir", "file_read"] + + class FakeApprovalModelClient: async def chat(self, role, messages): if role == "action": + if any("tool_observations" in message["content"] for message in messages): + actions = [] + else: + actions = [ + { + "tool": "shell_exec_safe", + "args": {"command": "uname -a"}, + "reason": "User requested system information", + } + ] return ModelResponse( role=role, model="local-main", @@ -78,13 +163,7 @@ class FakeApprovalModelClient: "kind": "action_directive", "intent": "run command", "risk_level": "medium", - "actions": [ - { - "tool": "shell_exec_safe", - "args": {"command": "uname -a"}, - "reason": "User requested system information", - } - ], + "actions": actions, } ), reasoning_content=None, @@ -110,3 +189,197 @@ async def test_runtime_creates_pending_approval_when_tool_requires_it(tmp_path): assert pending[0].task_id == result.task_id assert pending[0].normalized_action["tool"] == "shell_exec_safe" assert any(event.event_type == "tool_approval_requested" for event in events) + + +class FakeApprovalContinuationModelClient: + def __init__(self): + self.thinker_messages = [] + + async def chat(self, role, messages): + if role == "action": + if any("tool_observations" in message["content"] for message in messages): + actions = [] + else: + actions = [ + { + "tool": "shell_exec_safe", + "args": {"command": "uname -a"}, + "reason": "User requested system information", + } + ] + return ModelResponse( + role=role, + model="local-main", + content=json.dumps( + { + "kind": "action_directive", + "intent": "run command", + "risk_level": "medium", + "actions": actions, + } + ), + reasoning_content=None, + raw={}, + latency_ms=5.0, + ) + assert role == "thinker" + self.thinker_messages = messages + assert any("tool_observations" in message["content"] for message in messages) + return ModelResponse( + role=role, + model="local-main", + content="uname completed", + reasoning_content="used approved shell command", + raw={}, + latency_ms=10.0, + ) + + +class FakeApprovalThenSecondToolModelClient: + async def chat(self, role, messages): + observation_text = "\n".join(message["content"] for message in messages) + if role == "action": + if "tool_observations" in observation_text and "second step content" not in observation_text: + actions = [ + { + "tool": "file_read", + "args": {"path": "second.txt"}, + "reason": "Read follow-up file after approved command", + } + ] + elif "tool_observations" in observation_text: + actions = [] + else: + actions = [ + { + "tool": "shell_exec_safe", + "args": {"command": "uname -a"}, + "reason": "User requested system information", + } + ] + return ModelResponse( + role=role, + model="local-main", + content=json.dumps( + { + "kind": "action_directive", + "intent": "approval then follow-up", + "risk_level": "medium", + "actions": actions, + } + ), + reasoning_content=None, + raw={}, + latency_ms=5.0, + ) + assert role == "thinker" + assert "shell_exec_safe" in observation_text + assert "file_read" in observation_text + assert "second step content" in observation_text + return ModelResponse( + role=role, + model="local-main", + content="approved command and second tool completed", + reasoning_content=None, + raw={}, + latency_ms=10.0, + ) + + +@pytest.mark.asyncio +async def test_runtime_continues_after_approved_tool_call(tmp_path): + db_path = str(tmp_path / "duck.sqlite3") + task_store = TaskStore(db_path) + event_store = EventStore(db_path) + approvals = ApprovalService(db_path) + model_client = FakeApprovalContinuationModelClient() + loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals) + + pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True) + pending = await approvals.pending() + await approvals.allow_once(pending[0].approval_id) + + result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id) + events = await event_store.list_events(result.task_id) + finished = next(event for event in events if event.event_type == "tool_call_finished") + + assert result.status == "completed" + assert result.final_response == "uname completed" + assert finished.payload["tool"] == "shell_exec_safe" + assert finished.payload["result"]["ok"] is True + assert "uname" in finished.payload["result"]["metadata"]["command"] + assert any(event.event_type == "task_completed" for event in events) + + +@pytest.mark.asyncio +async def test_runtime_can_run_followup_tool_after_approval(tmp_path): + (tmp_path / "second.txt").write_text("second step content") + db_path = str(tmp_path / "duck.sqlite3") + task_store = TaskStore(db_path) + event_store = EventStore(db_path) + approvals = ApprovalService(db_path) + loop = RuntimeLoop( + task_store, + event_store, + FakeApprovalThenSecondToolModelClient(), + approval_service=approvals, + ) + + pending_result = await loop.run_chat("run uname then inspect second file", str(tmp_path), debug=True) + pending = await approvals.pending() + await approvals.allow_once(pending[0].approval_id) + + result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id) + events = await event_store.list_events(result.task_id) + finished_tools = [ + event.payload["tool"] for event in events if event.event_type == "tool_call_finished" + ] + + assert result.status == "completed" + assert finished_tools == ["shell_exec_safe", "file_read"] + + +@pytest.mark.asyncio +async def test_runtime_continues_after_denied_tool_call_without_execution(tmp_path): + db_path = str(tmp_path / "duck.sqlite3") + task_store = TaskStore(db_path) + event_store = EventStore(db_path) + approvals = ApprovalService(db_path) + model_client = FakeApprovalContinuationModelClient() + loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals) + + pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True) + pending = await approvals.pending() + await approvals.deny(pending[0].approval_id) + + result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id) + events = await event_store.list_events(result.task_id) + finished = next(event for event in events if event.event_type == "tool_call_finished") + + assert result.status == "completed" + assert finished.payload["result"]["ok"] is False + assert finished.payload["result"]["metadata"]["decision"] == "deny" + assert "denied" in finished.payload["result"]["error"].lower() + + +@pytest.mark.asyncio +async def test_runtime_reuses_allow_forever_for_matching_action(tmp_path): + db_path = str(tmp_path / "duck.sqlite3") + task_store = TaskStore(db_path) + event_store = EventStore(db_path) + approvals = ApprovalService(db_path) + model_client = FakeApprovalContinuationModelClient() + loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals) + + first_result = await loop.run_chat("run uname", str(tmp_path), debug=True) + first_pending = await approvals.pending() + await approvals.allow_forever(first_pending[0].approval_id) + await loop.continue_after_approval(first_result.task_id, first_pending[0].approval_id) + + second_result = await loop.run_chat("run uname again", str(tmp_path), debug=True) + second_events = await event_store.list_events(second_result.task_id) + + assert second_result.status == "completed" + assert second_result.final_response == "uname completed" + assert not any(event.event_type == "tool_approval_requested" for event in second_events) + assert any(event.event_type == "tool_call_finished" for event in second_events) diff --git a/tests/smoke/test_tool_gateway.py b/tests/smoke/test_tool_gateway.py index fb3a596..b277a9d 100644 --- a/tests/smoke/test_tool_gateway.py +++ b/tests/smoke/test_tool_gateway.py @@ -40,3 +40,39 @@ async def test_tool_gateway_runs_allowed_directive(tmp_path): assert result.ok is True assert result.metadata["path"].endswith("a.txt") + + +@pytest.mark.asyncio +async def test_tool_gateway_lists_workspace_directory(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "app.py").write_text("print('duck')") + (tmp_path / "README.md").write_text("hello") + gateway = ToolGateway.default(str(tmp_path)) + + result = await gateway.run_action({"tool": "list_dir", "args": {"path": "."}}) + escaped = await gateway.run_action({"tool": "list_dir", "args": {"path": ".."}}) + + assert result.ok is True + assert "README.md" in result.output + assert "src/" in result.output + assert escaped.ok is False + + +@pytest.mark.asyncio +async def test_tool_gateway_searches_file_contents(tmp_path): + (tmp_path / "src").mkdir() + (tmp_path / "src" / "app.py").write_text("duck tool gateway\\n") + (tmp_path / "notes.txt").write_text("other content\\n") + gateway = ToolGateway.default(str(tmp_path)) + + result = await gateway.run_action( + {"tool": "search_files", "args": {"query": "duck tool", "path": "."}} + ) + escaped = await gateway.run_action( + {"tool": "search_files", "args": {"query": "duck", "path": ".."}} + ) + + assert result.ok is True + assert "src/app.py:1:duck tool gateway" in result.output + assert result.metadata["matches"] == 1 + assert escaped.ok is False