Add approval continuation and multi-step tools
This commit is contained in:
parent
2d3a047548
commit
a4b7ef034a
128
duck_core/api.py
128
duck_core/api.py
|
|
@ -30,6 +30,10 @@ class ChatRequest(BaseModel):
|
|||
debug: bool = False
|
||||
|
||||
|
||||
class ContinueRequest(BaseModel):
    """Request body for the task continue-stream endpoint."""

    # Identifier of the approval whose recorded decision the task resumes from.
    approval_id: str
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
if settings.api_host == "0.0.0.0":
|
||||
|
|
@ -142,7 +146,7 @@ def create_app() -> FastAPI:
|
|||
content_parts: list[str] = []
|
||||
try:
|
||||
messages = runtime.context_builder.build_basic_messages(task)
|
||||
tool_observations = await runtime._run_action_tools(
|
||||
tool_observations = await runtime._run_action_loop(
|
||||
task.task_id, messages, body.workspace or settings.workspace
|
||||
)
|
||||
async for tool_event in emit_tool_events(task.task_id, task_event.sequence):
|
||||
|
|
@ -283,6 +287,128 @@ def create_app() -> FastAPI:
|
|||
await event_store.append(task_id, "task_continued", {})
|
||||
return {"status": "running"}
|
||||
|
||||
@app.post("/v1/tasks/{task_id}/continue/stream")
async def continue_task_stream(task_id: str, body: ContinueRequest) -> StreamingResponse:
    """Resume a task after its approval was decided and stream the outcome.

    Responds with 404 when the task does not exist or the approval does not
    belong to it, and with 409 while the approval has no decision yet.
    Otherwise the decided action is executed (or recorded as denied), further
    tool iterations run unless the decision is "deny", and the thinker's
    answer is streamed back as server-sent events.
    """
    task = await task_store.get_task(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    approval = await approvals.get(body.approval_id)
    belongs_to_task = approval is not None and approval.task_id == task_id
    if not belongs_to_task:
        raise HTTPException(status_code=404, detail="Approval not found for task")
    if approval.decision is None:
        raise HTTPException(status_code=409, detail="Approval is still pending")

    async def generator():
        # Deltas are accumulated so the full texts can be persisted at the end.
        thoughts: list[str] = []
        answer: list[str] = []
        try:
            await task_store.update_status(task_id, "running")
            resumed = await event_store.append(
                task_id,
                "task_continued",
                {"approval_id": body.approval_id, "decision": approval.decision},
            )
            first_observation = await runtime._run_approved_or_denied_action(
                task_id, approval.normalized_action, approval.decision
            )
            base_messages = runtime.context_builder.build_basic_messages(task)
            if approval.decision == "deny":
                # A denial ends tool use; the thinker only sees the refusal.
                observations = [first_observation]
            else:
                observations = await runtime._run_action_loop(
                    task_id,
                    base_messages,
                    task.workspace,
                    initial_observations=[first_observation],
                )

            # Replay the tool events recorded since the continuation started.
            async for tool_event in emit_tool_events(task_id, resumed.sequence):
                yield tool_event

            needs_approval = any(item.get("requires_approval") for item in observations)
            if needs_approval:
                # A follow-up action is gated again: park the task and stop.
                await task_store.waiting_for_approval(task_id)
                await event_store.append(
                    task_id,
                    "task_waiting_for_approval",
                    {"observations": observations},
                )
                yield sse(
                    "done",
                    {
                        "task_id": task_id,
                        "status": "waiting_for_approval",
                        "final_response": "Waiting for approval.",
                        "reasoning_content": None,
                    },
                )
                return

            observation_message = {
                "role": "user",
                "content": "tool_observations:\n"
                + json.dumps(observations, ensure_ascii=False, indent=2),
            }
            thinker_messages = [*base_messages, observation_message]
            await event_store.append(task_id, "model_call_started", {"role": "thinker"})
            async for chunk in model_client.stream_chat("thinker", thinker_messages):
                delta = str(chunk.get("delta") or "")
                kind = chunk.get("type")
                if kind == "reasoning_delta":
                    thoughts.append(delta)
                    yield sse("reasoning_delta", {"task_id": task_id, "delta": delta})
                elif kind == "content_delta":
                    answer.append(delta)
                    yield sse("content_delta", {"task_id": task_id, "delta": delta})

            content = "".join(answer)
            # An empty reasoning trace is reported as None rather than "".
            reasoning_content = "".join(thoughts) or None
            await event_store.append(
                task_id,
                "cognition_response",
                {
                    "role": "thinker",
                    "content": content,
                    "reasoning_content": reasoning_content,
                },
            )
            await event_store.append(
                task_id,
                "model_call_finished",
                {
                    "role": "thinker",
                    "model": model_client.get_role_config("thinker").model,
                },
            )
            await task_store.complete_task(task_id, content)
            await event_store.append(
                task_id,
                "task_completed",
                {
                    "final_response": content,
                    "reasoning_content": reasoning_content,
                },
            )
            yield sse(
                "done",
                {
                    "task_id": task_id,
                    "status": "completed",
                    "final_response": content,
                    "reasoning_content": reasoning_content,
                },
            )
        except Exception as exc:
            # Any failure marks the task failed and is reported in-stream.
            await task_store.fail_task(task_id, str(exc))
            await event_store.append(task_id, "task_failed", {"error": str(exc)})
            yield sse(
                "error",
                {
                    "task_id": task_id,
                    "status": "failed",
                    "error": str(exc),
                },
            )

    return StreamingResponse(generator(), media_type="text/event-stream")
|
||||
|
||||
@app.post("/v1/tasks/{task_id}/cancel")
|
||||
async def cancel_task(task_id: str) -> dict[str, str]:
|
||||
await task_store.cancel_task(task_id)
|
||||
|
|
|
|||
|
|
@ -93,6 +93,16 @@ class ApprovalService:
|
|||
rows = await cursor.fetchall()
|
||||
return [self._row_to_approval(row) for row in rows]
|
||||
|
||||
async def get(self, approval_id: str) -> Approval | None:
    """Look up a single approval by id; return ``None`` when no row matches."""
    await self.init()
    query = "select * from approvals where approval_id = ?"
    async with aiosqlite.connect(self.db_path) as connection:
        # Row objects are what _row_to_approval receives.
        connection.row_factory = aiosqlite.Row
        result = await connection.execute(query, (approval_id,))
        record = await result.fetchone()
        if not record:
            return None
        return self._row_to_approval(record)
|
||||
|
||||
async def allow_once(self, approval_id: str) -> None:
|
||||
await self._decide(approval_id, "resolved", "allow_once")
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from duck_core.events.store import EventStore
|
|||
from duck_core.model_client import ModelClient
|
||||
from duck_core.tasks.store import TaskStore
|
||||
from duck_core.tools.gateway import ToolGateway
|
||||
from duck_core.tools.base import ToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -26,12 +27,14 @@ class RuntimeLoop:
|
|||
model_client: ModelClient | None = None,
|
||||
context_builder: ContextBuilder | None = None,
|
||||
approval_service: ApprovalService | None = None,
|
||||
max_tool_iterations: int = 4,
|
||||
):
|
||||
self.task_store = task_store
|
||||
self.event_store = event_store
|
||||
self.model_client = model_client or ModelClient()
|
||||
self.context_builder = context_builder or ContextBuilder()
|
||||
self.approval_service = approval_service
|
||||
self.max_tool_iterations = max_tool_iterations
|
||||
|
||||
async def run_chat(
|
||||
self, message: str, workspace: str | None = None, debug: bool = False
|
||||
|
|
@ -44,7 +47,7 @@ class RuntimeLoop:
|
|||
)
|
||||
try:
|
||||
messages = self.context_builder.build_basic_messages(task)
|
||||
tool_observations = await self._run_action_tools(task.task_id, messages, workspace)
|
||||
tool_observations = await self._run_action_loop(task.task_id, messages, workspace)
|
||||
if any(observation.get("requires_approval") for observation in tool_observations):
|
||||
await self.task_store.waiting_for_approval(task.task_id)
|
||||
await self.event_store.append(
|
||||
|
|
@ -119,8 +122,172 @@ class RuntimeLoop:
|
|||
reasoning_content=None,
|
||||
)
|
||||
|
||||
async def continue_after_approval(self, task_id: str, approval_id: str) -> ChatResult:
    """Resume a task from a decided approval and return the final result.

    Returns a "failed" result when no approval service is configured, the
    task is unknown, or the approval does not belong to the task, and a
    "waiting_for_approval" result while the approval is undecided or a
    follow-up action needs approval. Otherwise the decided action is
    executed (or recorded as denied), further tool iterations run unless
    the decision is "deny", and the thinker produces the final answer.
    """

    def outcome(status: str, text: str) -> ChatResult:
        # Every result without model output has this shape.
        return ChatResult(
            task_id=task_id,
            status=status,
            final_response=text,
            reasoning_content=None,
        )

    service = self.approval_service
    if service is None:
        return outcome("failed", "Approval service is not configured.")

    task = await self.task_store.get_task(task_id)
    approval = await service.get(approval_id)
    if task is None:
        return outcome("failed", "Task not found.")
    if approval is None or approval.task_id != task_id:
        return outcome("failed", "Approval not found for task.")
    if approval.decision is None:
        return outcome("waiting_for_approval", "Waiting for approval.")

    await self.task_store.update_status(task_id, "running")
    await self.event_store.append(
        task_id,
        "task_continued",
        {"approval_id": approval_id, "decision": approval.decision},
    )
    try:
        first_observation = await self._run_approved_or_denied_action(
            task_id, approval.normalized_action, approval.decision
        )
        base_messages = self.context_builder.build_basic_messages(task)
        if approval.decision == "deny":
            # A denial ends tool use; the thinker only sees the refusal.
            observations = [first_observation]
        else:
            observations = await self._run_action_loop(
                task_id,
                base_messages,
                task.workspace,
                initial_observations=[first_observation],
            )

        if any(item.get("requires_approval") for item in observations):
            # A follow-up action is gated again: park the task and stop.
            await self.task_store.waiting_for_approval(task_id)
            await self.event_store.append(
                task_id,
                "task_waiting_for_approval",
                {"observations": observations},
            )
            return outcome("waiting_for_approval", "Waiting for approval.")

        observation_message = {
            "role": "user",
            "content": "tool_observations:\n"
            + json.dumps(observations, ensure_ascii=False, indent=2),
        }
        thinker_messages = [*base_messages, observation_message]
        await self.event_store.append(task_id, "model_call_started", {"role": "thinker"})
        response = await self.model_client.chat("thinker", thinker_messages)
        await self.event_store.append(
            task_id,
            "cognition_response",
            {
                "role": response.role,
                "content": response.content,
                "reasoning_content": response.reasoning_content,
            },
        )
        await self.event_store.append(
            task_id,
            "model_call_finished",
            {
                "role": response.role,
                "model": response.model,
                "latency_ms": response.latency_ms,
                "prompt_tokens": response.prompt_tokens,
                "completion_tokens": response.completion_tokens,
                "total_tokens": response.total_tokens,
            },
        )
        await self.task_store.complete_task(task_id, response.content)
        await self.event_store.append(
            task_id,
            "task_completed",
            {
                "final_response": response.content,
                "reasoning_content": response.reasoning_content,
            },
        )
        return ChatResult(
            task_id=task_id,
            status="completed",
            final_response=response.content,
            reasoning_content=response.reasoning_content,
        )
    except Exception as exc:
        # Record the failure on the task and report it to the caller.
        await self.task_store.fail_task(task_id, str(exc))
        await self.event_store.append(task_id, "task_failed", {"error": str(exc)})
        return outcome("failed", str(exc))
|
||||
|
||||
async def _run_approved_or_denied_action(
    self, task_id: str, action: dict[str, Any], decision: str
) -> dict[str, Any]:
    """Apply a decided approval to its action and record the outcome.

    For ``decision == "deny"`` no tool runs; a failed ``ToolResult`` carrying
    the denial is synthesised instead. For any other decision the action is
    executed through a gateway bound to the task's workspace with
    ``approved=True``. In both cases a ``tool_call_finished`` event is
    appended.

    Args:
        task_id: Task the action belongs to.
        action: The action dict that was submitted for approval.
        decision: The recorded approval decision.

    Returns:
        An observation dict with the keys ``index``, ``tool``, ``reason``,
        ``decision`` and ``result`` (the dumped tool result).

    Raises:
        LookupError: If the task no longer exists when the approved action
            is about to run.
    """
    tool_name = str(action.get("tool", ""))
    # Reuse the index of the original approval request so events line up.
    index = await self._approval_action_index(task_id, action)
    if decision == "deny":
        result = ToolResult(
            ok=False,
            error="Tool action denied by user.",
            metadata={"decision": "deny"},
        )
    else:
        task = await self.task_store.get_task(task_id)
        if task is None:
            # Previously this surfaced as an opaque AttributeError on None.
            raise LookupError(f"Task not found: {task_id}")
        gateway = ToolGateway.default(task.workspace or ".")
        result = await gateway.run_action(action, approved=True)

    result_payload = result.model_dump()
    await self.event_store.append(
        task_id,
        "tool_call_finished",
        {"index": index, "tool": tool_name, "result": result_payload},
    )
    return {
        "index": index,
        "tool": tool_name,
        "reason": action.get("reason"),
        "decision": decision,
        "result": result_payload,
    }
|
||||
|
||||
async def _approval_action_index(self, task_id: str, action: dict[str, Any]) -> int:
|
||||
events = await self.event_store.list_events(task_id)
|
||||
for event in reversed(events):
|
||||
if (
|
||||
event.event_type == "tool_approval_requested"
|
||||
and event.payload.get("action") == action
|
||||
):
|
||||
return int(event.payload.get("index") or 1)
|
||||
return 1
|
||||
|
||||
async def _run_action_tools(
|
||||
self, task_id: str, messages: list[dict[str, str]], workspace: str | None
|
||||
self,
|
||||
task_id: str,
|
||||
messages: list[dict[str, str]],
|
||||
workspace: str | None,
|
||||
start_index: int = 1,
|
||||
) -> list[dict[str, Any]]:
|
||||
try:
|
||||
await self.event_store.append(task_id, "model_call_started", {"role": "action"})
|
||||
|
|
@ -141,7 +308,7 @@ class RuntimeLoop:
|
|||
|
||||
gateway = ToolGateway.default(workspace or ".")
|
||||
observations: list[dict[str, Any]] = []
|
||||
for index, action in enumerate(actions, start=1):
|
||||
for index, action in enumerate(actions, start=start_index):
|
||||
if not isinstance(action, dict):
|
||||
observations.append(
|
||||
{"index": index, "ok": False, "error": "Action must be an object"}
|
||||
|
|
@ -153,7 +320,11 @@ class RuntimeLoop:
|
|||
"tool_call_started",
|
||||
{"index": index, "tool": tool_name, "args": action.get("args") or {}},
|
||||
)
|
||||
result = await gateway.run_action(action)
|
||||
approved_forever = (
|
||||
self.approval_service is not None
|
||||
and await self.approval_service.is_allowed_forever(action)
|
||||
)
|
||||
result = await gateway.run_action(action, approved=approved_forever)
|
||||
result_payload = result.model_dump()
|
||||
if result.metadata.get("requires_approval"):
|
||||
approval = None
|
||||
|
|
@ -195,3 +366,48 @@ class RuntimeLoop:
|
|||
}
|
||||
)
|
||||
return observations
|
||||
|
||||
async def _run_action_loop(
|
||||
self,
|
||||
task_id: str,
|
||||
messages: list[dict[str, str]],
|
||||
workspace: str | None,
|
||||
initial_observations: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
all_observations = list(initial_observations or [])
|
||||
current_messages = messages
|
||||
if all_observations:
|
||||
current_messages = [
|
||||
*messages,
|
||||
{
|
||||
"role": "user",
|
||||
"content": "tool_observations:\n"
|
||||
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
|
||||
},
|
||||
]
|
||||
for _ in range(self.max_tool_iterations):
|
||||
observations = await self._run_action_tools(
|
||||
task_id,
|
||||
current_messages,
|
||||
workspace,
|
||||
start_index=len(all_observations) + 1,
|
||||
)
|
||||
if not observations:
|
||||
return all_observations
|
||||
all_observations.extend(observations)
|
||||
if any(observation.get("requires_approval") for observation in observations):
|
||||
return all_observations
|
||||
current_messages = [
|
||||
*messages,
|
||||
{
|
||||
"role": "user",
|
||||
"content": "tool_observations:\n"
|
||||
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
|
||||
},
|
||||
]
|
||||
await self.event_store.append(
|
||||
task_id,
|
||||
"tool_iteration_limit_reached",
|
||||
{"max_tool_iterations": self.max_tool_iterations},
|
||||
)
|
||||
return all_observations
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
|||
from duck_core.tools.base import Tool, ToolResult
|
||||
from duck_core.tools.file_read import FileReadTool
|
||||
from duck_core.tools.file_write import FileWriteTool
|
||||
from duck_core.tools.list_dir import ListDirTool
|
||||
from duck_core.tools.search_files import SearchFilesTool
|
||||
from duck_core.tools.shell_exec_safe import ShellExecSafeTool
|
||||
|
||||
|
||||
|
|
@ -16,11 +18,13 @@ class ToolGateway:
|
|||
[
|
||||
FileReadTool(workspace),
|
||||
FileWriteTool(workspace),
|
||||
ListDirTool(workspace),
|
||||
SearchFilesTool(workspace),
|
||||
ShellExecSafeTool(workspace),
|
||||
]
|
||||
)
|
||||
|
||||
async def run_action(self, action: dict[str, Any]) -> ToolResult:
|
||||
async def run_action(self, action: dict[str, Any], approved: bool = False) -> ToolResult:
|
||||
tool_name = str(action.get("tool", ""))
|
||||
tool = self.tools.get(tool_name)
|
||||
if tool is None:
|
||||
|
|
@ -28,4 +32,6 @@ class ToolGateway:
|
|||
args = action.get("args") or {}
|
||||
if not isinstance(args, dict):
|
||||
return ToolResult(ok=False, error="Tool args must be an object")
|
||||
if approved:
|
||||
args = {**args, "_approved": True}
|
||||
return await tool.run(args)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
from typing import Any
|
||||
|
||||
from duck_core.tools.base import ToolResult
|
||||
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
|
||||
|
||||
|
||||
class ListDirTool:
    """List the direct children of a directory inside the workspace.

    Directories are listed first (with a trailing ``/``), then files, each
    group ordered case-insensitively by name. At most ``max_entries`` names
    are returned.
    """

    name = "list_dir"
    risk_level = "low"

    def __init__(self, workspace: str, max_entries: int = 200):
        self.workspace = workspace
        self.max_entries = max_entries

    async def run(self, args: dict[str, Any]) -> ToolResult:
        """List ``args["path"]`` (default ``"."``) relative to the workspace.

        Returns a failed ``ToolResult`` when the path is rejected by
        ``resolve_workspace_path``, does not exist, is not a directory, or
        cannot be read. On success ``output`` is the newline-joined listing
        and ``metadata`` holds ``path``, ``entries`` (number returned) and
        ``truncated`` (True only when entries were actually left out).
        """
        raw_path = str(args.get("path") or ".")
        try:
            path = resolve_workspace_path(self.workspace, raw_path)
        except WorkspacePathError as exc:
            return ToolResult(ok=False, error=str(exc))
        if not path.exists():
            return ToolResult(ok=False, error=f"Directory not found: {raw_path}")
        if not path.is_dir():
            return ToolResult(ok=False, error=f"Not a directory: {raw_path}")

        try:
            # Stat each child once; the flag drives both ordering and suffix.
            listing = [(child.is_dir(), child.name) for child in path.iterdir()]
        except OSError as exc:
            # e.g. permission denied: report it instead of crashing the run.
            return ToolResult(ok=False, error=f"Cannot list directory: {raw_path}: {exc}")
        listing.sort(key=lambda item: (not item[0], item[1].lower()))

        limit = max(self.max_entries, 0)
        # Compare against the full listing so a directory holding exactly
        # `limit` children is not reported as truncated.
        truncated = len(listing) > limit
        entries = [f"{name}/" if is_dir else name for is_dir, name in listing[:limit]]
        return ToolResult(
            ok=True,
            output="\n".join(entries),
            metadata={
                "path": str(path),
                "entries": len(entries),
                "truncated": truncated,
            },
        )
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
from fnmatch import fnmatch
|
||||
from typing import Any
|
||||
|
||||
from duck_core.tools.base import ToolResult
|
||||
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
|
||||
|
||||
|
||||
class SearchFilesTool:
    """Substring search over the lines of files inside the workspace.

    Files under a ``.git`` path component, files whose name or
    workspace-relative path does not match the glob, and files larger than
    ``max_file_bytes`` are skipped. At most ``max_matches`` lines are
    returned.
    """

    name = "search_files"
    risk_level = "low"

    def __init__(self, workspace: str, max_matches: int = 100, max_file_bytes: int = 1_000_000):
        self.workspace = workspace
        self.max_matches = max_matches
        self.max_file_bytes = max_file_bytes

    def _resolve_limit(self, requested: Any) -> int:
        """Return the per-call match limit, never above ``self.max_matches``.

        Missing, non-numeric or non-positive requests fall back to the
        configured cap instead of raising or silently returning no results.
        """
        try:
            limit = int(requested or self.max_matches)
        except (TypeError, ValueError, OverflowError):
            return self.max_matches
        if limit < 1:
            return self.max_matches
        return min(limit, self.max_matches)

    async def run(self, args: dict[str, Any]) -> ToolResult:
        """Search for ``args["query"]`` and return ``path:line:text`` matches.

        Recognised args: ``query`` (required), ``path`` (default ``"."``),
        ``glob`` (default ``"*"``), ``case_sensitive`` (default True) and
        ``max_matches``. Returns a failed ``ToolResult`` when the query is
        empty, the path is rejected by ``resolve_workspace_path``, or the
        path does not exist. In the metadata, ``truncated`` means the match
        limit was reached (more matches may exist) and ``files_unreadable``
        counts files skipped because they could not be read.
        """
        query = str(args.get("query") or "")
        raw_path = str(args.get("path") or ".")
        pattern = str(args.get("glob") or "*")
        case_sensitive = bool(args.get("case_sensitive", True))
        max_matches = self._resolve_limit(args.get("max_matches"))
        if not query:
            return ToolResult(ok=False, error="Search query is required")
        try:
            root = resolve_workspace_path(self.workspace, ".")
            path = resolve_workspace_path(self.workspace, raw_path)
        except WorkspacePathError as exc:
            return ToolResult(ok=False, error=str(exc))
        if not path.exists():
            return ToolResult(ok=False, error=f"Search path not found: {raw_path}")

        needle = query if case_sensitive else query.lower()
        matches: list[str] = []
        files_scanned = 0
        files_unreadable = 0
        candidates = [path] if path.is_file() else path.rglob("*")
        # NOTE(review): is_file() follows symlinks, so a link inside the
        # workspace that points outside it would still be read - confirm
        # whether resolve_workspace_path's guarantees should apply per file.
        for candidate in candidates:
            if len(matches) >= max_matches:
                break
            if not candidate.is_file() or ".git" in candidate.parts:
                continue
            relative = candidate.relative_to(root).as_posix()
            if not fnmatch(candidate.name, pattern) and not fnmatch(relative, pattern):
                continue
            try:
                if candidate.stat().st_size > self.max_file_bytes:
                    continue
                # Explicit UTF-8 keeps results independent of the host locale;
                # undecodable bytes are replaced rather than raising.
                text = candidate.read_text(encoding="utf-8", errors="replace")
            except OSError:
                # One unreadable file must not abort the whole search.
                files_unreadable += 1
                continue
            files_scanned += 1
            for line_number, line in enumerate(text.splitlines(), start=1):
                haystack = line if case_sensitive else line.lower()
                if needle in haystack:
                    matches.append(f"{relative}:{line_number}:{line}")
                    if len(matches) >= max_matches:
                        break

        return ToolResult(
            ok=True,
            output="\n".join(matches),
            metadata={
                "path": str(path),
                "query": query,
                "matches": len(matches),
                "files_scanned": files_scanned,
                "files_unreadable": files_unreadable,
                "truncated": len(matches) >= max_matches,
            },
        )
|
||||
|
|
@ -57,7 +57,8 @@ class ShellExecSafeTool:
|
|||
|
||||
async def run(self, args: dict[str, Any]) -> ToolResult:
|
||||
command = str(args.get("command", "")).strip()
|
||||
allowed, reason = self._is_allowed(command)
|
||||
approved = bool(args.get("_approved"))
|
||||
allowed, reason = self._is_allowed(command, approved=approved)
|
||||
if not allowed:
|
||||
return ToolResult(ok=False, error=reason, metadata={"requires_approval": True})
|
||||
try:
|
||||
|
|
@ -79,13 +80,15 @@ class ShellExecSafeTool:
|
|||
metadata={"returncode": completed.returncode, "command": command},
|
||||
)
|
||||
|
||||
def _is_allowed(self, command: str) -> tuple[bool, str | None]:
|
||||
def _is_allowed(self, command: str, approved: bool = False) -> tuple[bool, str | None]:
|
||||
if not command:
|
||||
return False, "Empty command"
|
||||
lowered = command.lower()
|
||||
for blocked in BLOCKLIST:
|
||||
if lowered.startswith(blocked.lower()) or blocked.lower() in lowered:
|
||||
return False, f"Command is blocked: {blocked}"
|
||||
if approved:
|
||||
return True, None
|
||||
parts = shlex.split(command)
|
||||
prefix1 = parts[0] if parts else ""
|
||||
prefix2 = " ".join(parts[:2])
|
||||
|
|
|
|||
|
|
@ -165,6 +165,7 @@ function appendApprovalTerminal(article, eventPayload) {
|
|||
const command = formatToolCommand(action.tool || payload.tool, action.args || {});
|
||||
terminal?.classList.add("is-waiting");
|
||||
if (terminal && payload.approval_id) terminal.dataset.approvalId = payload.approval_id;
|
||||
if (terminal && eventPayload.task_id) terminal.dataset.taskId = eventPayload.task_id;
|
||||
if (status) status.textContent = "approval";
|
||||
if (body) {
|
||||
body.textContent = [
|
||||
|
|
@ -207,6 +208,7 @@ function inlineApprovalButton(label, action, tone = "") {
|
|||
async function resolveInlineApproval(button) {
|
||||
const terminal = button.closest(".tool-terminal");
|
||||
const approvalId = terminal?.dataset.approvalId;
|
||||
const taskId = terminal?.dataset.taskId;
|
||||
const action = button.dataset.inlineApprovalAction;
|
||||
if (!terminal || !approvalId || !action) return;
|
||||
|
||||
|
|
@ -226,6 +228,9 @@ async function resolveInlineApproval(button) {
|
|||
if (status) status.textContent = decision;
|
||||
if (body) body.textContent = `${command}\n\n${decision}: ${humanApprovalDecision(action)}`;
|
||||
actions?.remove();
|
||||
if (taskId) {
|
||||
await continueAfterInlineApproval(terminal.closest(".message"), taskId, approvalId);
|
||||
}
|
||||
}
|
||||
|
||||
function humanApprovalDecision(action) {
|
||||
|
|
@ -320,8 +325,8 @@ function parseSseBlock(block) {
|
|||
return {name: event.name, data: JSON.parse(event.data)};
|
||||
}
|
||||
|
||||
async function streamChat(payload, onEvent) {
|
||||
const response = await fetch("/v1/chat/stream", {
|
||||
async function streamSse(url, payload, onEvent) {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {"Content-Type": "application/json"},
|
||||
body: JSON.stringify(payload),
|
||||
|
|
@ -350,31 +355,15 @@ async function streamChat(payload, onEvent) {
|
|||
}
|
||||
}
|
||||
|
||||
async function sendMessage() {
|
||||
if (state.running) return;
|
||||
const input = document.querySelector("#message");
|
||||
const message = input.value.trim();
|
||||
if (!message) return;
|
||||
// Start a chat run: delegates to streamSse with the /v1/chat/stream endpoint,
// handing `payload` and the `onEvent` callback through unchanged.
async function streamChat(payload, onEvent) {
  await streamSse("/v1/chat/stream", payload, onEvent);
}
|
||||
|
||||
state.running = true;
|
||||
document.querySelector("#run").disabled = true;
|
||||
setStatus("#task-status", "running", "warn");
|
||||
addMessage("user", message, "submitted");
|
||||
input.value = "";
|
||||
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
|
||||
let taskId = "";
|
||||
let contentStarted = false;
|
||||
|
||||
try {
|
||||
await streamChat({
|
||||
message,
|
||||
workspace: document.querySelector("#workspace").value,
|
||||
debug: document.querySelector("#debug").checked,
|
||||
}, async ({name, data}) => {
|
||||
if (data.task_id) taskId = data.task_id;
|
||||
async function handleAssistantStreamEvent(pending, name, data, context) {
|
||||
if (data.task_id) context.taskId = data.task_id;
|
||||
if (name === "task_created") {
|
||||
taskId = data.task_id;
|
||||
setStatus("#task-status", taskId, "warn");
|
||||
context.taskId = data.task_id;
|
||||
setStatus("#task-status", data.task_id, "warn");
|
||||
return;
|
||||
}
|
||||
if (name === "reasoning_delta") {
|
||||
|
|
@ -398,8 +387,8 @@ async function sendMessage() {
|
|||
return;
|
||||
}
|
||||
if (name === "content_delta") {
|
||||
if (!contentStarted) {
|
||||
contentStarted = true;
|
||||
if (!context.contentStarted) {
|
||||
context.contentStarted = true;
|
||||
setMessagePending(pending, "");
|
||||
}
|
||||
pending.querySelector(".message-meta span").textContent = "answering";
|
||||
|
|
@ -407,7 +396,7 @@ async function sendMessage() {
|
|||
return;
|
||||
}
|
||||
if (name === "done") {
|
||||
if (!contentStarted) {
|
||||
if (!context.contentStarted) {
|
||||
setMessagePending(pending, data.final_response || "No final content returned.");
|
||||
}
|
||||
pending.querySelector(".message-meta span").textContent = data.status;
|
||||
|
|
@ -419,13 +408,60 @@ async function sendMessage() {
|
|||
if (name === "error") {
|
||||
throw new Error(data.error || "Stream failed.");
|
||||
}
|
||||
}
|
||||
|
||||
// Resume a task after an inline approval was resolved: streams the task's
// continue endpoint into the existing assistant message `article`.
// Does nothing when there is no article or another run is already active.
async function continueAfterInlineApproval(article, taskId, approvalId) {
  const blocked = !article || state.running;
  if (blocked) return;

  // Lock the UI for the duration of the stream.
  state.running = true;
  document.querySelector("#run").disabled = true;
  setStatus("#task-status", taskId, "warn");

  const streamContext = {taskId, contentStarted: false};
  const endpoint = `/v1/tasks/${encodeURIComponent(taskId)}/continue/stream`;
  const forwardEvent = async (event) =>
    handleAssistantStreamEvent(article, event.name, event.data, streamContext);

  try {
    await streamSse(endpoint, {approval_id: approvalId}, forwardEvent);
  } catch (error) {
    // Show the failure in place of the pending answer, then reload events.
    setMessagePending(article, error.message);
    article.querySelector(".message-meta span").textContent = "failed";
    setStatus("#task-status", "failed", "bad");
    await refreshEvents(taskId);
  } finally {
    // Always unlock the UI, whether the stream succeeded or failed.
    state.running = false;
    document.querySelector("#run").disabled = false;
    document.querySelector("#message")?.focus();
  }
}
|
||||
|
||||
async function sendMessage() {
|
||||
if (state.running) return;
|
||||
const input = document.querySelector("#message");
|
||||
const message = input.value.trim();
|
||||
if (!message) return;
|
||||
|
||||
state.running = true;
|
||||
document.querySelector("#run").disabled = true;
|
||||
setStatus("#task-status", "running", "warn");
|
||||
addMessage("user", message, "submitted");
|
||||
input.value = "";
|
||||
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
|
||||
const context = {taskId: "", contentStarted: false};
|
||||
|
||||
try {
|
||||
await streamChat({
|
||||
message,
|
||||
workspace: document.querySelector("#workspace").value,
|
||||
debug: document.querySelector("#debug").checked,
|
||||
}, async ({name, data}) => {
|
||||
await handleAssistantStreamEvent(pending, name, data, context);
|
||||
});
|
||||
} catch (error) {
|
||||
if (!taskId) input.value = message;
|
||||
if (!context.taskId) input.value = message;
|
||||
setMessagePending(pending, error.message);
|
||||
pending.querySelector(".message-meta span").textContent = "failed";
|
||||
setStatus("#task-status", "failed", "bad");
|
||||
if (taskId) await refreshEvents(taskId);
|
||||
if (context.taskId) await refreshEvents(context.taskId);
|
||||
} finally {
|
||||
state.running = false;
|
||||
document.querySelector("#run").disabled = false;
|
||||
|
|
|
|||
|
|
@ -8,9 +8,16 @@ Available tools:
|
|||
Args: {"path": "relative/path.txt"}
|
||||
- file_write: write a file inside the current workspace.
|
||||
Args: {"path": "relative/path.txt", "content": "text", "overwrite": false}
|
||||
- list_dir: list direct children of a directory inside the current workspace.
|
||||
Args: {"path": "."}
|
||||
- search_files: search text inside files in the current workspace.
|
||||
Args: {"query": "text to find", "path": ".", "glob": "*.py"}
|
||||
- shell_exec_safe: run a safe allowlisted shell command in the current workspace.
|
||||
Args: {"command": "pwd"}
|
||||
|
||||
Return actions=[] when the user can be answered directly without tools.
|
||||
When tool_observations are already present, request only genuinely missing
|
||||
follow-up information. Return actions=[] when the observations are sufficient
|
||||
for the thinker to answer.
|
||||
Use only the listed tools. Keep actions minimal and directly tied to the user's
|
||||
request. Do not invent tool names.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from fastapi.testclient import TestClient
|
||||
import json
|
||||
import re
|
||||
|
||||
from duck_core.model_client import ModelResponse
|
||||
|
||||
|
|
@ -56,6 +57,16 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
|||
|
||||
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
|
||||
assert role == "action"
|
||||
if any("tool_observations" in message["content"] for message in messages):
|
||||
actions = []
|
||||
else:
|
||||
actions = [
|
||||
{
|
||||
"tool": "file_read",
|
||||
"args": {"path": "note.txt"},
|
||||
"reason": "User asked for file contents",
|
||||
}
|
||||
]
|
||||
return ModelResponse(
|
||||
role=role,
|
||||
model="local-main",
|
||||
|
|
@ -64,13 +75,7 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
|||
"kind": "action_directive",
|
||||
"intent": "read requested file",
|
||||
"risk_level": "low",
|
||||
"actions": [
|
||||
{
|
||||
"tool": "file_read",
|
||||
"args": {"path": "note.txt"},
|
||||
"reason": "User asked for file contents",
|
||||
}
|
||||
],
|
||||
"actions": actions,
|
||||
}
|
||||
),
|
||||
reasoning_content=None,
|
||||
|
|
@ -101,3 +106,70 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
|||
assert "event: content_delta" in body
|
||||
assert "answer from tool" in body
|
||||
assert "event: done" in body
|
||||
|
||||
|
||||
def test_continue_stream_executes_approved_tool_and_streams_answer(tmp_path, monkeypatch):
    """Approve a gated shell command, then resume the task through the streaming endpoint.

    The first request to ``/v1/chat/stream`` is expected to stop at an approval
    request. After ``allow_once`` is posted for that approval, the
    ``/continue/stream`` endpoint should emit the finished tool call followed by
    the streamed answer and a final ``done`` event.
    """
    monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3"))

    async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
        # Only the action role goes through the non-streaming chat call here.
        assert role == "action"
        already_observed = any("tool_observations" in message["content"] for message in messages)
        # Ask for the shell command once; after an observation exists, ask for nothing.
        requested = (
            []
            if already_observed
            else [
                {
                    "tool": "shell_exec_safe",
                    "args": {"command": "uname -a"},
                    "reason": "User asked for system information",
                }
            ]
        )
        directive = {
            "kind": "action_directive",
            "intent": "run command",
            "risk_level": "medium",
            "actions": requested,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(directive),
            reasoning_content=None,
            raw={},
            latency_ms=1.0,
        )

    async def fake_stream_chat(self, role, messages):
        assert role == "thinker"
        # The thinker must receive the observation produced by the approved command.
        carrying_observations = (
            message for message in messages if "tool_observations" in message["content"]
        )
        observation_message = next(carrying_observations)
        assert "uname" in observation_message["content"]
        yield {"type": "content_delta", "delta": "continued after approval"}

    monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat)
    monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat)
    client = TestClient(create_app())

    chat_payload = {"message": "run uname", "workspace": str(tmp_path), "debug": True}
    with client.stream("POST", "/v1/chat/stream", json=chat_payload) as first_response:
        initial_body = "".join(first_response.iter_text())

    # Recover the task id from the streamed text, then find its pending approval.
    task_id = re.search(r'"task_id"\s*:\s*"([^"]+)"', initial_body).group(1)
    pending = client.get("/v1/approvals/pending").json()
    approval = next(item for item in pending if item["task_id"] == task_id)
    client.post(f"/v1/approvals/{approval['approval_id']}/allow_once")

    continue_url = f"/v1/tasks/{approval['task_id']}/continue/stream"
    continue_payload = {"approval_id": approval["approval_id"]}
    with client.stream("POST", continue_url, json=continue_payload) as continued_response:
        body = "".join(continued_response.iter_text())

    assert "event: tool_approval_requested" in initial_body
    assert continued_response.status_code == 200
    assert "event: tool_call_finished" in body
    assert "event: content_delta" in body
    assert "continued after approval" in body
    assert "event: done" in body
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
|
|||
"actions": [
|
||||
{
|
||||
"tool": "shell_exec_safe",
|
||||
"args": {"command": "uname -a"},
|
||||
"args": {"command": "hostname --pending-approval-test"},
|
||||
"reason": "needs shell command",
|
||||
}
|
||||
],
|
||||
|
|
@ -90,7 +90,11 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
|
|||
|
||||
response = client.post("/v1/chat", json={"message": "run uname", "debug": True})
|
||||
approvals = client.get("/v1/approvals/pending").json()
|
||||
approval = next(
|
||||
item for item in approvals if item["task_id"] == response.json()["task_id"]
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "waiting_for_approval"
|
||||
assert approvals[0]["normalized_action"]["tool"] == "shell_exec_safe"
|
||||
assert approval["normalized_action"]["tool"] == "shell_exec_safe"
|
||||
assert approval["normalized_action"]["args"]["command"] == "hostname --pending-approval-test"
|
||||
|
|
|
|||
|
|
@ -12,6 +12,16 @@ from duck_core.tasks.store import TaskStore
|
|||
class FakeToolModelClient:
|
||||
async def chat(self, role, messages):
|
||||
if role == "action":
|
||||
if any("tool_observations" in message["content"] for message in messages):
|
||||
actions = []
|
||||
else:
|
||||
actions = [
|
||||
{
|
||||
"tool": "file_read",
|
||||
"args": {"path": "note.txt"},
|
||||
"reason": "User asked for file contents",
|
||||
}
|
||||
]
|
||||
return ModelResponse(
|
||||
role=role,
|
||||
model="local-main",
|
||||
|
|
@ -20,13 +30,7 @@ class FakeToolModelClient:
|
|||
"kind": "action_directive",
|
||||
"intent": "read requested file",
|
||||
"risk_level": "low",
|
||||
"actions": [
|
||||
{
|
||||
"tool": "file_read",
|
||||
"args": {"path": "note.txt"},
|
||||
"reason": "User asked for file contents",
|
||||
}
|
||||
],
|
||||
"actions": actions,
|
||||
}
|
||||
),
|
||||
reasoning_content=None,
|
||||
|
|
@ -45,6 +49,58 @@ class FakeToolModelClient:
|
|||
)
|
||||
|
||||
|
||||
class FakeMultiStepToolModelClient:
    """Scripted model client that needs two tool rounds before answering.

    As the action role it first requests ``list_dir``, then ``file_read`` of
    ``README.md`` once that name shows up in the messages, and finally returns
    no actions. As the thinker role it checks that both tool names and the file
    contents reached it, then returns a fixed answer.
    """

    async def chat(self, role, messages):
        if role != "action":
            assert role == "thinker"
            transcript = "\n".join(message["content"] for message in messages)
            # Both tool rounds must be visible to the thinker.
            assert "list_dir" in transcript
            assert "file_read" in transcript
            assert "readme contents" in transcript
            return ModelResponse(
                role=role,
                model="local-main",
                content="Readme inspected",
                reasoning_content=None,
                raw={},
                latency_ms=12.0,
            )

        transcript = "\n".join(message["content"] for message in messages)
        if "tool_observations" not in transcript:
            # Round one: nothing observed yet, so list the workspace.
            step = [
                {
                    "tool": "list_dir",
                    "args": {"path": "."},
                    "reason": "Find available files",
                }
            ]
        elif "README.md" in transcript and "readme contents" not in transcript:
            # Round two: the listing mentioned README.md but its text is not yet seen.
            step = [
                {
                    "tool": "file_read",
                    "args": {"path": "README.md"},
                    "reason": "Read discovered README",
                }
            ]
        else:
            step = []

        payload = {
            "kind": "action_directive",
            "intent": "multi-step file inspection",
            "risk_level": "low",
            "actions": step,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(payload),
            reasoning_content=None,
            raw={},
            latency_ms=5.0,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_executes_action_directive_tool_and_finishes_with_observation(tmp_path):
|
||||
(tmp_path / "note.txt").write_text("hello from tool")
|
||||
|
|
@ -67,9 +123,38 @@ async def test_runtime_executes_action_directive_tool_and_finishes_with_observat
|
|||
assert tool_finished.payload["result"]["output"] == "hello from tool"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_runs_multiple_tool_steps_before_final_answer(tmp_path):
    """The runtime should run ``list_dir`` then ``file_read`` before completing."""
    (tmp_path / "README.md").write_text("readme contents")
    database = str(tmp_path / "duck.sqlite3")
    tasks = TaskStore(database)
    recorder = EventStore(database)
    runtime = RuntimeLoop(tasks, recorder, FakeMultiStepToolModelClient())

    outcome = await runtime.run_chat("inspect the workspace readme", str(tmp_path), debug=True)
    recorded = await recorder.list_events(outcome.task_id)

    # Collect the tool names in the order their calls finished.
    finished_tools = []
    for event in recorded:
        if event.event_type == "tool_call_finished":
            finished_tools.append(event.payload["tool"])

    assert outcome.status == "completed"
    assert outcome.final_response == "Readme inspected"
    assert finished_tools == ["list_dir", "file_read"]
|
||||
|
||||
|
||||
class FakeApprovalModelClient:
|
||||
async def chat(self, role, messages):
|
||||
if role == "action":
|
||||
if any("tool_observations" in message["content"] for message in messages):
|
||||
actions = []
|
||||
else:
|
||||
actions = [
|
||||
{
|
||||
"tool": "shell_exec_safe",
|
||||
"args": {"command": "uname -a"},
|
||||
"reason": "User requested system information",
|
||||
}
|
||||
]
|
||||
return ModelResponse(
|
||||
role=role,
|
||||
model="local-main",
|
||||
|
|
@ -78,13 +163,7 @@ class FakeApprovalModelClient:
|
|||
"kind": "action_directive",
|
||||
"intent": "run command",
|
||||
"risk_level": "medium",
|
||||
"actions": [
|
||||
{
|
||||
"tool": "shell_exec_safe",
|
||||
"args": {"command": "uname -a"},
|
||||
"reason": "User requested system information",
|
||||
}
|
||||
],
|
||||
"actions": actions,
|
||||
}
|
||||
),
|
||||
reasoning_content=None,
|
||||
|
|
@ -110,3 +189,197 @@ async def test_runtime_creates_pending_approval_when_tool_requires_it(tmp_path):
|
|||
assert pending[0].task_id == result.task_id
|
||||
assert pending[0].normalized_action["tool"] == "shell_exec_safe"
|
||||
assert any(event.event_type == "tool_approval_requested" for event in events)
|
||||
|
||||
|
||||
class FakeApprovalContinuationModelClient:
    """Scripted model client for the approve-then-continue flow.

    The action role requests ``shell_exec_safe`` with ``uname -a`` until a
    message containing ``tool_observations`` is present, after which it
    requests nothing. The thinker role records the messages it was given and
    returns a fixed answer.
    """

    def __init__(self):
        # Last message list handed to the thinker role; empty until it is called.
        self.thinker_messages = []

    async def chat(self, role, messages):
        if role != "action":
            assert role == "thinker"
            self.thinker_messages = messages
            # The thinker is only expected to run once a tool observation exists.
            assert any("tool_observations" in message["content"] for message in messages)
            return ModelResponse(
                role=role,
                model="local-main",
                content="uname completed",
                reasoning_content="used approved shell command",
                raw={},
                latency_ms=10.0,
            )

        has_observation = any("tool_observations" in message["content"] for message in messages)
        requested = (
            []
            if has_observation
            else [
                {
                    "tool": "shell_exec_safe",
                    "args": {"command": "uname -a"},
                    "reason": "User requested system information",
                }
            ]
        )
        directive = {
            "kind": "action_directive",
            "intent": "run command",
            "risk_level": "medium",
            "actions": requested,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(directive),
            reasoning_content=None,
            raw={},
            latency_ms=5.0,
        )
|
||||
|
||||
|
||||
class FakeApprovalThenSecondToolModelClient:
    """Scripted model client that follows an approved command with a file read.

    The action role asks for ``shell_exec_safe`` first, then for ``file_read``
    of ``second.txt`` once any observation exists, and stops when the text
    ``second step content`` appears in the messages. The thinker role checks
    that both tools and the file text are visible, then returns a fixed answer.
    """

    async def chat(self, role, messages):
        transcript = "\n".join(message["content"] for message in messages)

        if role != "action":
            assert role == "thinker"
            assert "shell_exec_safe" in transcript
            assert "file_read" in transcript
            assert "second step content" in transcript
            return ModelResponse(
                role=role,
                model="local-main",
                content="approved command and second tool completed",
                reasoning_content=None,
                raw={},
                latency_ms=10.0,
            )

        if "tool_observations" not in transcript:
            # Nothing has run yet: start with the command that needs approval.
            step = [
                {
                    "tool": "shell_exec_safe",
                    "args": {"command": "uname -a"},
                    "reason": "User requested system information",
                }
            ]
        elif "second step content" in transcript:
            # The follow-up file has already been read.
            step = []
        else:
            step = [
                {
                    "tool": "file_read",
                    "args": {"path": "second.txt"},
                    "reason": "Read follow-up file after approved command",
                }
            ]

        payload = {
            "kind": "action_directive",
            "intent": "approval then follow-up",
            "risk_level": "medium",
            "actions": step,
        }
        return ModelResponse(
            role=role,
            model="local-main",
            content=json.dumps(payload),
            reasoning_content=None,
            raw={},
            latency_ms=5.0,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_continues_after_approved_tool_call(tmp_path):
    """After ``allow_once``, continuing the task runs the shell tool and completes."""
    database = str(tmp_path / "duck.sqlite3")
    tasks = TaskStore(database)
    recorder = EventStore(database)
    approval_service = ApprovalService(database)
    scripted_model = FakeApprovalContinuationModelClient()
    runtime = RuntimeLoop(tasks, recorder, scripted_model, approval_service=approval_service)

    # The first run is expected to leave one approval pending; approve it once.
    waiting = await runtime.run_chat("run uname", str(tmp_path), debug=True)
    queue = await approval_service.pending()
    await approval_service.allow_once(queue[0].approval_id)

    outcome = await runtime.continue_after_approval(waiting.task_id, queue[0].approval_id)
    recorded = await recorder.list_events(outcome.task_id)
    first_finished = next(
        event for event in recorded if event.event_type == "tool_call_finished"
    )
    tool_result = first_finished.payload["result"]

    assert outcome.status == "completed"
    assert outcome.final_response == "uname completed"
    assert first_finished.payload["tool"] == "shell_exec_safe"
    assert tool_result["ok"] is True
    assert "uname" in tool_result["metadata"]["command"]
    assert any(event.event_type == "task_completed" for event in recorded)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_can_run_followup_tool_after_approval(tmp_path):
    """Continuing after approval may run further tools before completing."""
    (tmp_path / "second.txt").write_text("second step content")
    database = str(tmp_path / "duck.sqlite3")
    tasks = TaskStore(database)
    recorder = EventStore(database)
    approval_service = ApprovalService(database)
    runtime = RuntimeLoop(
        tasks,
        recorder,
        FakeApprovalThenSecondToolModelClient(),
        approval_service=approval_service,
    )

    waiting = await runtime.run_chat(
        "run uname then inspect second file", str(tmp_path), debug=True
    )
    queue = await approval_service.pending()
    await approval_service.allow_once(queue[0].approval_id)

    outcome = await runtime.continue_after_approval(waiting.task_id, queue[0].approval_id)
    recorded = await recorder.list_events(outcome.task_id)

    # Tool names in the order their calls finished.
    finished_tools = []
    for event in recorded:
        if event.event_type == "tool_call_finished":
            finished_tools.append(event.payload["tool"])

    assert outcome.status == "completed"
    assert finished_tools == ["shell_exec_safe", "file_read"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_continues_after_denied_tool_call_without_execution(tmp_path):
    """A denied approval still completes the task, with a failed tool result recorded."""
    database = str(tmp_path / "duck.sqlite3")
    tasks = TaskStore(database)
    recorder = EventStore(database)
    approval_service = ApprovalService(database)
    scripted_model = FakeApprovalContinuationModelClient()
    runtime = RuntimeLoop(tasks, recorder, scripted_model, approval_service=approval_service)

    waiting = await runtime.run_chat("run uname", str(tmp_path), debug=True)
    queue = await approval_service.pending()
    await approval_service.deny(queue[0].approval_id)

    outcome = await runtime.continue_after_approval(waiting.task_id, queue[0].approval_id)
    recorded = await recorder.list_events(outcome.task_id)
    first_finished = next(
        event for event in recorded if event.event_type == "tool_call_finished"
    )
    tool_result = first_finished.payload["result"]

    assert outcome.status == "completed"
    # The recorded result must reflect the denial rather than a successful run.
    assert tool_result["ok"] is False
    assert tool_result["metadata"]["decision"] == "deny"
    assert "denied" in tool_result["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_reuses_allow_forever_for_matching_action(tmp_path):
    """An ``allow_forever`` decision lets a later identical action run without a new approval."""
    database = str(tmp_path / "duck.sqlite3")
    tasks = TaskStore(database)
    recorder = EventStore(database)
    approval_service = ApprovalService(database)
    scripted_model = FakeApprovalContinuationModelClient()
    runtime = RuntimeLoop(tasks, recorder, scripted_model, approval_service=approval_service)

    # First chat: approve the command permanently and finish that task.
    initial = await runtime.run_chat("run uname", str(tmp_path), debug=True)
    initial_queue = await approval_service.pending()
    await approval_service.allow_forever(initial_queue[0].approval_id)
    await runtime.continue_after_approval(initial.task_id, initial_queue[0].approval_id)

    # Second chat: the same action should run straight through.
    repeat = await runtime.run_chat("run uname again", str(tmp_path), debug=True)
    repeat_events = await recorder.list_events(repeat.task_id)
    seen_types = [event.event_type for event in repeat_events]

    assert repeat.status == "completed"
    assert repeat.final_response == "uname completed"
    assert not any(kind == "tool_approval_requested" for kind in seen_types)
    assert any(kind == "tool_call_finished" for kind in seen_types)
|
||||
|
|
|
|||
|
|
@ -40,3 +40,39 @@ async def test_tool_gateway_runs_allowed_directive(tmp_path):
|
|||
|
||||
assert result.ok is True
|
||||
assert result.metadata["path"].endswith("a.txt")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_gateway_lists_workspace_directory(tmp_path):
    """``list_dir`` lists the workspace root and fails for a path above it."""
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    (source_dir / "app.py").write_text("print('duck')")
    (tmp_path / "README.md").write_text("hello")
    gateway = ToolGateway.default(str(tmp_path))

    inside = await gateway.run_action({"tool": "list_dir", "args": {"path": "."}})
    # ".." points outside the workspace directory the gateway was built with.
    outside = await gateway.run_action({"tool": "list_dir", "args": {"path": ".."}})

    assert inside.ok is True
    assert "README.md" in inside.output
    assert "src/" in inside.output
    assert outside.ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_gateway_searches_file_contents(tmp_path):
    """``search_files`` reports the matching line and fails for a path above the workspace."""
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    # NOTE(review): "\\n" writes a backslash followed by "n", not a newline;
    # the substring assertion below passes either way. Confirm this is intended.
    (source_dir / "app.py").write_text("duck tool gateway\\n")
    (tmp_path / "notes.txt").write_text("other content\\n")
    gateway = ToolGateway.default(str(tmp_path))

    in_workspace = {"tool": "search_files", "args": {"query": "duck tool", "path": "."}}
    above_workspace = {"tool": "search_files", "args": {"query": "duck", "path": ".."}}
    found = await gateway.run_action(in_workspace)
    rejected = await gateway.run_action(above_workspace)

    assert found.ok is True
    assert "src/app.py:1:duck tool gateway" in found.output
    assert found.metadata["matches"] == 1
    assert rejected.ok is False
|
||||
|
|
|
|||
Loading…
Reference in New Issue