Add approval continuation and multi-step tools
This commit is contained in:
parent
2d3a047548
commit
a4b7ef034a
128
duck_core/api.py
128
duck_core/api.py
|
|
@ -30,6 +30,10 @@ class ChatRequest(BaseModel):
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ContinueRequest(BaseModel):
    """Request body for the task-continuation endpoint."""

    # Identifier of the approval whose decision the continuation should act on.
    approval_id: str
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
if settings.api_host == "0.0.0.0":
|
if settings.api_host == "0.0.0.0":
|
||||||
|
|
@ -142,7 +146,7 @@ def create_app() -> FastAPI:
|
||||||
content_parts: list[str] = []
|
content_parts: list[str] = []
|
||||||
try:
|
try:
|
||||||
messages = runtime.context_builder.build_basic_messages(task)
|
messages = runtime.context_builder.build_basic_messages(task)
|
||||||
tool_observations = await runtime._run_action_tools(
|
tool_observations = await runtime._run_action_loop(
|
||||||
task.task_id, messages, body.workspace or settings.workspace
|
task.task_id, messages, body.workspace or settings.workspace
|
||||||
)
|
)
|
||||||
async for tool_event in emit_tool_events(task.task_id, task_event.sequence):
|
async for tool_event in emit_tool_events(task.task_id, task_event.sequence):
|
||||||
|
|
@ -283,6 +287,128 @@ def create_app() -> FastAPI:
|
||||||
await event_store.append(task_id, "task_continued", {})
|
await event_store.append(task_id, "task_continued", {})
|
||||||
return {"status": "running"}
|
return {"status": "running"}
|
||||||
|
|
||||||
|
@app.post("/v1/tasks/{task_id}/continue/stream")
async def continue_task_stream(task_id: str, body: ContinueRequest) -> StreamingResponse:
    """Resume a task after an approval decision and stream the outcome as SSE.

    Raises:
        HTTPException: 404 when the task does not exist or the approval is
            missing / belongs to another task; 409 when the approval has no
            decision yet.

    The streamed generator applies the decision to the stored action, optionally
    runs further tool iterations, and then either stops with a
    ``waiting_for_approval`` "done" event or streams the "thinker" model answer
    followed by a ``completed`` "done" event. Any exception raised inside the
    generator marks the task failed and is reported as an "error" event.
    """
    task = await task_store.get_task(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    approval = await approvals.get(body.approval_id)
    if approval is None or approval.task_id != task_id:
        raise HTTPException(status_code=404, detail="Approval not found for task")
    if approval.decision is None:
        raise HTTPException(status_code=409, detail="Approval is still pending")
    # NOTE(review): nothing visible here checks whether this approval was already
    # continued once; confirm that replaying the request cannot re-run the action.

    async def generator():
        # Deltas are accumulated so the full texts can be persisted at the end.
        reasoning_parts: list[str] = []
        content_parts: list[str] = []
        try:
            await task_store.update_status(task_id, "running")
            continued_event = await event_store.append(
                task_id,
                "task_continued",
                {"approval_id": body.approval_id, "decision": approval.decision},
            )
            # Produces the observation for the approved (or denied) action.
            tool_observation = await runtime._run_approved_or_denied_action(
                task_id, approval.normalized_action, approval.decision
            )
            messages = runtime.context_builder.build_basic_messages(task)
            tool_observations = [tool_observation]
            # A denial skips further tool iterations; the model only sees the denial.
            if approval.decision != "deny":
                tool_observations = await runtime._run_action_loop(
                    task_id,
                    messages,
                    task.workspace,
                    initial_observations=tool_observations,
                )
            # presumably replays tool events recorded after the "task_continued"
            # event — verify against emit_tool_events.
            async for tool_event in emit_tool_events(task_id, continued_event.sequence):
                yield tool_event
            if any(observation.get("requires_approval") for observation in tool_observations):
                # Another action needs approval: park the task and end the stream.
                await task_store.waiting_for_approval(task_id)
                await event_store.append(
                    task_id,
                    "task_waiting_for_approval",
                    {"observations": tool_observations},
                )
                yield sse(
                    "done",
                    {
                        "task_id": task_id,
                        "status": "waiting_for_approval",
                        "final_response": "Waiting for approval.",
                        "reasoning_content": None,
                    },
                )
                return

            # Hand the collected observations to the model as one extra user message.
            messages = [
                *messages,
                {
                    "role": "user",
                    "content": "tool_observations:\n"
                    + json.dumps(tool_observations, ensure_ascii=False, indent=2),
                },
            ]
            await event_store.append(task_id, "model_call_started", {"role": "thinker"})
            async for chunk in model_client.stream_chat("thinker", messages):
                delta = str(chunk.get("delta") or "")
                # Chunks of any other type are ignored.
                if chunk.get("type") == "reasoning_delta":
                    reasoning_parts.append(delta)
                    yield sse("reasoning_delta", {"task_id": task_id, "delta": delta})
                elif chunk.get("type") == "content_delta":
                    content_parts.append(delta)
                    yield sse("content_delta", {"task_id": task_id, "delta": delta})

            content = "".join(content_parts)
            # An empty reasoning transcript is stored as None rather than "".
            reasoning_content = "".join(reasoning_parts) or None
            await event_store.append(
                task_id,
                "cognition_response",
                {
                    "role": "thinker",
                    "content": content,
                    "reasoning_content": reasoning_content,
                },
            )
            await event_store.append(
                task_id,
                "model_call_finished",
                {
                    "role": "thinker",
                    "model": model_client.get_role_config("thinker").model,
                },
            )
            await task_store.complete_task(task_id, content)
            await event_store.append(
                task_id,
                "task_completed",
                {
                    "final_response": content,
                    "reasoning_content": reasoning_content,
                },
            )
            yield sse(
                "done",
                {
                    "task_id": task_id,
                    "status": "completed",
                    "final_response": content,
                    "reasoning_content": reasoning_content,
                },
            )
        except Exception as exc:
            # Persist the failure and report it on the stream instead of raising.
            await task_store.fail_task(task_id, str(exc))
            await event_store.append(task_id, "task_failed", {"error": str(exc)})
            yield sse(
                "error",
                {
                    "task_id": task_id,
                    "status": "failed",
                    "error": str(exc),
                },
            )

    return StreamingResponse(generator(), media_type="text/event-stream")
|
||||||
|
|
||||||
@app.post("/v1/tasks/{task_id}/cancel")
|
@app.post("/v1/tasks/{task_id}/cancel")
|
||||||
async def cancel_task(task_id: str) -> dict[str, str]:
|
async def cancel_task(task_id: str) -> dict[str, str]:
|
||||||
await task_store.cancel_task(task_id)
|
await task_store.cancel_task(task_id)
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,16 @@ class ApprovalService:
|
||||||
rows = await cursor.fetchall()
|
rows = await cursor.fetchall()
|
||||||
return [self._row_to_approval(row) for row in rows]
|
return [self._row_to_approval(row) for row in rows]
|
||||||
|
|
||||||
|
async def get(self, approval_id: str) -> Approval | None:
    """Look up one approval by its id.

    Returns the converted approval, or ``None`` when no row matches.
    """
    await self.init()
    lookup_sql = "select * from approvals where approval_id = ?"
    async with aiosqlite.connect(self.db_path) as connection:
        # Row factory lets _row_to_approval read columns by name.
        connection.row_factory = aiosqlite.Row
        result = await connection.execute(lookup_sql, (approval_id,))
        record = await result.fetchone()
        if not record:
            return None
        return self._row_to_approval(record)
|
||||||
|
|
||||||
async def allow_once(self, approval_id: str) -> None:
|
async def allow_once(self, approval_id: str) -> None:
|
||||||
await self._decide(approval_id, "resolved", "allow_once")
|
await self._decide(approval_id, "resolved", "allow_once")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from duck_core.events.store import EventStore
|
||||||
from duck_core.model_client import ModelClient
|
from duck_core.model_client import ModelClient
|
||||||
from duck_core.tasks.store import TaskStore
|
from duck_core.tasks.store import TaskStore
|
||||||
from duck_core.tools.gateway import ToolGateway
|
from duck_core.tools.gateway import ToolGateway
|
||||||
|
from duck_core.tools.base import ToolResult
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -26,12 +27,14 @@ class RuntimeLoop:
|
||||||
model_client: ModelClient | None = None,
|
model_client: ModelClient | None = None,
|
||||||
context_builder: ContextBuilder | None = None,
|
context_builder: ContextBuilder | None = None,
|
||||||
approval_service: ApprovalService | None = None,
|
approval_service: ApprovalService | None = None,
|
||||||
|
max_tool_iterations: int = 4,
|
||||||
):
|
):
|
||||||
self.task_store = task_store
|
self.task_store = task_store
|
||||||
self.event_store = event_store
|
self.event_store = event_store
|
||||||
self.model_client = model_client or ModelClient()
|
self.model_client = model_client or ModelClient()
|
||||||
self.context_builder = context_builder or ContextBuilder()
|
self.context_builder = context_builder or ContextBuilder()
|
||||||
self.approval_service = approval_service
|
self.approval_service = approval_service
|
||||||
|
self.max_tool_iterations = max_tool_iterations
|
||||||
|
|
||||||
async def run_chat(
|
async def run_chat(
|
||||||
self, message: str, workspace: str | None = None, debug: bool = False
|
self, message: str, workspace: str | None = None, debug: bool = False
|
||||||
|
|
@ -44,7 +47,7 @@ class RuntimeLoop:
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
messages = self.context_builder.build_basic_messages(task)
|
messages = self.context_builder.build_basic_messages(task)
|
||||||
tool_observations = await self._run_action_tools(task.task_id, messages, workspace)
|
tool_observations = await self._run_action_loop(task.task_id, messages, workspace)
|
||||||
if any(observation.get("requires_approval") for observation in tool_observations):
|
if any(observation.get("requires_approval") for observation in tool_observations):
|
||||||
await self.task_store.waiting_for_approval(task.task_id)
|
await self.task_store.waiting_for_approval(task.task_id)
|
||||||
await self.event_store.append(
|
await self.event_store.append(
|
||||||
|
|
@ -119,8 +122,172 @@ class RuntimeLoop:
|
||||||
reasoning_content=None,
|
reasoning_content=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def continue_after_approval(self, task_id: str, approval_id: str) -> ChatResult:
    """Resume a task once its pending approval has been decided.

    Returns a ``ChatResult`` whose status is:
      * ``"failed"`` when no approval service is configured, the task is
        missing, the approval is missing or belongs to another task, or any
        exception is raised while continuing;
      * ``"waiting_for_approval"`` when the approval is still undecided or a
        follow-up action requires approval;
      * ``"completed"`` with the "thinker" model's answer otherwise.
    """
    if self.approval_service is None:
        return ChatResult(
            task_id=task_id,
            status="failed",
            final_response="Approval service is not configured.",
            reasoning_content=None,
        )

    task = await self.task_store.get_task(task_id)
    approval = await self.approval_service.get(approval_id)
    if task is None:
        return ChatResult(
            task_id=task_id,
            status="failed",
            final_response="Task not found.",
            reasoning_content=None,
        )
    # The approval must exist and be attached to this very task.
    if approval is None or approval.task_id != task_id:
        return ChatResult(
            task_id=task_id,
            status="failed",
            final_response="Approval not found for task.",
            reasoning_content=None,
        )
    if approval.decision is None:
        return ChatResult(
            task_id=task_id,
            status="waiting_for_approval",
            final_response="Waiting for approval.",
            reasoning_content=None,
        )

    # NOTE(review): these two writes sit outside the try block, so an error here
    # propagates to the caller instead of marking the task failed — confirm intended.
    await self.task_store.update_status(task_id, "running")
    await self.event_store.append(
        task_id,
        "task_continued",
        {"approval_id": approval_id, "decision": approval.decision},
    )
    try:
        tool_observation = await self._run_approved_or_denied_action(
            task_id, approval.normalized_action, approval.decision
        )
        messages = self.context_builder.build_basic_messages(task)
        tool_observations = [tool_observation]
        # A denial skips further tool iterations; the model only sees the denial.
        if approval.decision != "deny":
            tool_observations = await self._run_action_loop(
                task_id,
                messages,
                task.workspace,
                initial_observations=tool_observations,
            )
        if any(observation.get("requires_approval") for observation in tool_observations):
            # Another action needs approval: park the task again.
            await self.task_store.waiting_for_approval(task_id)
            await self.event_store.append(
                task_id,
                "task_waiting_for_approval",
                {"observations": tool_observations},
            )
            return ChatResult(
                task_id=task_id,
                status="waiting_for_approval",
                final_response="Waiting for approval.",
                reasoning_content=None,
            )
        # Hand the collected observations to the model as one extra user message.
        messages = [
            *messages,
            {
                "role": "user",
                "content": "tool_observations:\n"
                + json.dumps(tool_observations, ensure_ascii=False, indent=2),
            },
        ]
        await self.event_store.append(task_id, "model_call_started", {"role": "thinker"})
        response = await self.model_client.chat("thinker", messages)
        await self.event_store.append(
            task_id,
            "cognition_response",
            {
                "role": response.role,
                "content": response.content,
                "reasoning_content": response.reasoning_content,
            },
        )
        await self.event_store.append(
            task_id,
            "model_call_finished",
            {
                "role": response.role,
                "model": response.model,
                "latency_ms": response.latency_ms,
                "prompt_tokens": response.prompt_tokens,
                "completion_tokens": response.completion_tokens,
                "total_tokens": response.total_tokens,
            },
        )
        await self.task_store.complete_task(task_id, response.content)
        await self.event_store.append(
            task_id,
            "task_completed",
            {
                "final_response": response.content,
                "reasoning_content": response.reasoning_content,
            },
        )
        return ChatResult(
            task_id=task_id,
            status="completed",
            final_response=response.content,
            reasoning_content=response.reasoning_content,
        )
    except Exception as exc:
        # Persist the failure and report it in the result instead of raising.
        await self.task_store.fail_task(task_id, str(exc))
        await self.event_store.append(task_id, "task_failed", {"error": str(exc)})
        return ChatResult(
            task_id=task_id,
            status="failed",
            final_response=str(exc),
            reasoning_content=None,
        )
|
||||||
|
|
||||||
|
async def _run_approved_or_denied_action(
    self, task_id: str, action: dict[str, Any], decision: str
) -> dict[str, Any]:
    """Apply an approval decision to a stored action and record the outcome.

    Args:
        task_id: Task the action belongs to.
        action: The action dict that was submitted for approval.
        decision: The approval decision; ``"deny"`` produces a failed result
            without running anything, any other value runs the action with
            ``approved=True``.

    Returns:
        An observation dict with ``index``, ``tool``, ``reason``, ``decision``
        and the dumped tool ``result``.

    Raises:
        LookupError: If the action should run but the task no longer exists.
    """
    tool_name = str(action.get("tool", ""))
    index = await self._approval_action_index(task_id, action)
    if decision == "deny":
        result = ToolResult(
            ok=False,
            error="Tool action denied by user.",
            metadata={"decision": "deny"},
        )
    else:
        # NOTE(review): every decision other than "deny" executes the action;
        # confirm no other non-approving decision value can reach this point.
        task = await self.task_store.get_task(task_id)
        if task is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise LookupError(f"Task not found: {task_id}")
        gateway = ToolGateway.default(task.workspace or ".")
        result = await gateway.run_action(action, approved=True)

    result_payload = result.model_dump()
    await self.event_store.append(
        task_id,
        "tool_call_finished",
        {"index": index, "tool": tool_name, "result": result_payload},
    )
    return {
        "index": index,
        "tool": tool_name,
        "reason": action.get("reason"),
        "decision": decision,
        "result": result_payload,
    }
|
||||||
|
|
||||||
|
async def _approval_action_index(self, task_id: str, action: dict[str, Any]) -> int:
    """Return the index recorded for the latest approval request of ``action``.

    Scans the task's events newest-first for a ``tool_approval_requested``
    event whose payload action equals ``action``. Falls back to 1 when no such
    event exists or its stored index is missing or falsy.
    """
    history = await self.event_store.list_events(task_id)
    latest_request = next(
        (
            entry
            for entry in reversed(history)
            if entry.event_type == "tool_approval_requested"
            and entry.payload.get("action") == action
        ),
        None,
    )
    if latest_request is None:
        return 1
    return int(latest_request.payload.get("index") or 1)
|
||||||
|
|
||||||
async def _run_action_tools(
|
async def _run_action_tools(
|
||||||
self, task_id: str, messages: list[dict[str, str]], workspace: str | None
|
self,
|
||||||
|
task_id: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
workspace: str | None,
|
||||||
|
start_index: int = 1,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
try:
|
try:
|
||||||
await self.event_store.append(task_id, "model_call_started", {"role": "action"})
|
await self.event_store.append(task_id, "model_call_started", {"role": "action"})
|
||||||
|
|
@ -141,7 +308,7 @@ class RuntimeLoop:
|
||||||
|
|
||||||
gateway = ToolGateway.default(workspace or ".")
|
gateway = ToolGateway.default(workspace or ".")
|
||||||
observations: list[dict[str, Any]] = []
|
observations: list[dict[str, Any]] = []
|
||||||
for index, action in enumerate(actions, start=1):
|
for index, action in enumerate(actions, start=start_index):
|
||||||
if not isinstance(action, dict):
|
if not isinstance(action, dict):
|
||||||
observations.append(
|
observations.append(
|
||||||
{"index": index, "ok": False, "error": "Action must be an object"}
|
{"index": index, "ok": False, "error": "Action must be an object"}
|
||||||
|
|
@ -153,7 +320,11 @@ class RuntimeLoop:
|
||||||
"tool_call_started",
|
"tool_call_started",
|
||||||
{"index": index, "tool": tool_name, "args": action.get("args") or {}},
|
{"index": index, "tool": tool_name, "args": action.get("args") or {}},
|
||||||
)
|
)
|
||||||
result = await gateway.run_action(action)
|
approved_forever = (
|
||||||
|
self.approval_service is not None
|
||||||
|
and await self.approval_service.is_allowed_forever(action)
|
||||||
|
)
|
||||||
|
result = await gateway.run_action(action, approved=approved_forever)
|
||||||
result_payload = result.model_dump()
|
result_payload = result.model_dump()
|
||||||
if result.metadata.get("requires_approval"):
|
if result.metadata.get("requires_approval"):
|
||||||
approval = None
|
approval = None
|
||||||
|
|
@ -195,3 +366,48 @@ class RuntimeLoop:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return observations
|
return observations
|
||||||
|
|
||||||
|
async def _run_action_loop(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
workspace: str | None,
|
||||||
|
initial_observations: list[dict[str, Any]] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
all_observations = list(initial_observations or [])
|
||||||
|
current_messages = messages
|
||||||
|
if all_observations:
|
||||||
|
current_messages = [
|
||||||
|
*messages,
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "tool_observations:\n"
|
||||||
|
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
for _ in range(self.max_tool_iterations):
|
||||||
|
observations = await self._run_action_tools(
|
||||||
|
task_id,
|
||||||
|
current_messages,
|
||||||
|
workspace,
|
||||||
|
start_index=len(all_observations) + 1,
|
||||||
|
)
|
||||||
|
if not observations:
|
||||||
|
return all_observations
|
||||||
|
all_observations.extend(observations)
|
||||||
|
if any(observation.get("requires_approval") for observation in observations):
|
||||||
|
return all_observations
|
||||||
|
current_messages = [
|
||||||
|
*messages,
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "tool_observations:\n"
|
||||||
|
+ json.dumps(all_observations, ensure_ascii=False, indent=2),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
await self.event_store.append(
|
||||||
|
task_id,
|
||||||
|
"tool_iteration_limit_reached",
|
||||||
|
{"max_tool_iterations": self.max_tool_iterations},
|
||||||
|
)
|
||||||
|
return all_observations
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from typing import Any
|
||||||
from duck_core.tools.base import Tool, ToolResult
|
from duck_core.tools.base import Tool, ToolResult
|
||||||
from duck_core.tools.file_read import FileReadTool
|
from duck_core.tools.file_read import FileReadTool
|
||||||
from duck_core.tools.file_write import FileWriteTool
|
from duck_core.tools.file_write import FileWriteTool
|
||||||
|
from duck_core.tools.list_dir import ListDirTool
|
||||||
|
from duck_core.tools.search_files import SearchFilesTool
|
||||||
from duck_core.tools.shell_exec_safe import ShellExecSafeTool
|
from duck_core.tools.shell_exec_safe import ShellExecSafeTool
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -16,11 +18,13 @@ class ToolGateway:
|
||||||
[
|
[
|
||||||
FileReadTool(workspace),
|
FileReadTool(workspace),
|
||||||
FileWriteTool(workspace),
|
FileWriteTool(workspace),
|
||||||
|
ListDirTool(workspace),
|
||||||
|
SearchFilesTool(workspace),
|
||||||
ShellExecSafeTool(workspace),
|
ShellExecSafeTool(workspace),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_action(self, action: dict[str, Any]) -> ToolResult:
|
async def run_action(self, action: dict[str, Any], approved: bool = False) -> ToolResult:
|
||||||
tool_name = str(action.get("tool", ""))
|
tool_name = str(action.get("tool", ""))
|
||||||
tool = self.tools.get(tool_name)
|
tool = self.tools.get(tool_name)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
|
|
@ -28,4 +32,6 @@ class ToolGateway:
|
||||||
args = action.get("args") or {}
|
args = action.get("args") or {}
|
||||||
if not isinstance(args, dict):
|
if not isinstance(args, dict):
|
||||||
return ToolResult(ok=False, error="Tool args must be an object")
|
return ToolResult(ok=False, error="Tool args must be an object")
|
||||||
|
if approved:
|
||||||
|
args = {**args, "_approved": True}
|
||||||
return await tool.run(args)
|
return await tool.run(args)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from duck_core.tools.base import ToolResult
|
||||||
|
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
|
||||||
|
|
||||||
|
|
||||||
|
class ListDirTool:
|
||||||
|
name = "list_dir"
|
||||||
|
risk_level = "low"
|
||||||
|
|
||||||
|
def __init__(self, workspace: str, max_entries: int = 200):
|
||||||
|
self.workspace = workspace
|
||||||
|
self.max_entries = max_entries
|
||||||
|
|
||||||
|
async def run(self, args: dict[str, Any]) -> ToolResult:
|
||||||
|
raw_path = str(args.get("path") or ".")
|
||||||
|
try:
|
||||||
|
path = resolve_workspace_path(self.workspace, raw_path)
|
||||||
|
except WorkspacePathError as exc:
|
||||||
|
return ToolResult(ok=False, error=str(exc))
|
||||||
|
if not path.exists():
|
||||||
|
return ToolResult(ok=False, error=f"Directory not found: {raw_path}")
|
||||||
|
if not path.is_dir():
|
||||||
|
return ToolResult(ok=False, error=f"Not a directory: {raw_path}")
|
||||||
|
|
||||||
|
entries = []
|
||||||
|
for child in sorted(path.iterdir(), key=lambda item: (not item.is_dir(), item.name.lower())):
|
||||||
|
suffix = "/" if child.is_dir() else ""
|
||||||
|
entries.append(f"{child.name}{suffix}")
|
||||||
|
if len(entries) >= self.max_entries:
|
||||||
|
break
|
||||||
|
|
||||||
|
truncated = len(entries) >= self.max_entries
|
||||||
|
return ToolResult(
|
||||||
|
ok=True,
|
||||||
|
output="\n".join(entries),
|
||||||
|
metadata={
|
||||||
|
"path": str(path),
|
||||||
|
"entries": len(entries),
|
||||||
|
"truncated": truncated,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,66 @@
|
||||||
|
from fnmatch import fnmatch
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from duck_core.tools.base import ToolResult
|
||||||
|
from duck_core.tools.paths import WorkspacePathError, resolve_workspace_path
|
||||||
|
|
||||||
|
|
||||||
|
class SearchFilesTool:
|
||||||
|
name = "search_files"
|
||||||
|
risk_level = "low"
|
||||||
|
|
||||||
|
def __init__(self, workspace: str, max_matches: int = 100, max_file_bytes: int = 1_000_000):
|
||||||
|
self.workspace = workspace
|
||||||
|
self.max_matches = max_matches
|
||||||
|
self.max_file_bytes = max_file_bytes
|
||||||
|
|
||||||
|
async def run(self, args: dict[str, Any]) -> ToolResult:
|
||||||
|
query = str(args.get("query") or "")
|
||||||
|
raw_path = str(args.get("path") or ".")
|
||||||
|
pattern = str(args.get("glob") or "*")
|
||||||
|
case_sensitive = bool(args.get("case_sensitive", True))
|
||||||
|
max_matches = min(int(args.get("max_matches") or self.max_matches), self.max_matches)
|
||||||
|
if not query:
|
||||||
|
return ToolResult(ok=False, error="Search query is required")
|
||||||
|
try:
|
||||||
|
root = resolve_workspace_path(self.workspace, ".")
|
||||||
|
path = resolve_workspace_path(self.workspace, raw_path)
|
||||||
|
except WorkspacePathError as exc:
|
||||||
|
return ToolResult(ok=False, error=str(exc))
|
||||||
|
if not path.exists():
|
||||||
|
return ToolResult(ok=False, error=f"Search path not found: {raw_path}")
|
||||||
|
|
||||||
|
needle = query if case_sensitive else query.lower()
|
||||||
|
matches: list[str] = []
|
||||||
|
files_scanned = 0
|
||||||
|
candidates = [path] if path.is_file() else path.rglob("*")
|
||||||
|
for candidate in candidates:
|
||||||
|
if len(matches) >= max_matches:
|
||||||
|
break
|
||||||
|
if not candidate.is_file() or ".git" in candidate.parts:
|
||||||
|
continue
|
||||||
|
relative = candidate.relative_to(root).as_posix()
|
||||||
|
if not fnmatch(candidate.name, pattern) and not fnmatch(relative, pattern):
|
||||||
|
continue
|
||||||
|
if candidate.stat().st_size > self.max_file_bytes:
|
||||||
|
continue
|
||||||
|
files_scanned += 1
|
||||||
|
text = candidate.read_text(errors="replace")
|
||||||
|
for line_number, line in enumerate(text.splitlines(), start=1):
|
||||||
|
haystack = line if case_sensitive else line.lower()
|
||||||
|
if needle in haystack:
|
||||||
|
matches.append(f"{relative}:{line_number}:{line}")
|
||||||
|
if len(matches) >= max_matches:
|
||||||
|
break
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
ok=True,
|
||||||
|
output="\n".join(matches),
|
||||||
|
metadata={
|
||||||
|
"path": str(path),
|
||||||
|
"query": query,
|
||||||
|
"matches": len(matches),
|
||||||
|
"files_scanned": files_scanned,
|
||||||
|
"truncated": len(matches) >= max_matches,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
@ -57,7 +57,8 @@ class ShellExecSafeTool:
|
||||||
|
|
||||||
async def run(self, args: dict[str, Any]) -> ToolResult:
|
async def run(self, args: dict[str, Any]) -> ToolResult:
|
||||||
command = str(args.get("command", "")).strip()
|
command = str(args.get("command", "")).strip()
|
||||||
allowed, reason = self._is_allowed(command)
|
approved = bool(args.get("_approved"))
|
||||||
|
allowed, reason = self._is_allowed(command, approved=approved)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
return ToolResult(ok=False, error=reason, metadata={"requires_approval": True})
|
return ToolResult(ok=False, error=reason, metadata={"requires_approval": True})
|
||||||
try:
|
try:
|
||||||
|
|
@ -79,13 +80,15 @@ class ShellExecSafeTool:
|
||||||
metadata={"returncode": completed.returncode, "command": command},
|
metadata={"returncode": completed.returncode, "command": command},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_allowed(self, command: str) -> tuple[bool, str | None]:
|
def _is_allowed(self, command: str, approved: bool = False) -> tuple[bool, str | None]:
|
||||||
if not command:
|
if not command:
|
||||||
return False, "Empty command"
|
return False, "Empty command"
|
||||||
lowered = command.lower()
|
lowered = command.lower()
|
||||||
for blocked in BLOCKLIST:
|
for blocked in BLOCKLIST:
|
||||||
if lowered.startswith(blocked.lower()) or blocked.lower() in lowered:
|
if lowered.startswith(blocked.lower()) or blocked.lower() in lowered:
|
||||||
return False, f"Command is blocked: {blocked}"
|
return False, f"Command is blocked: {blocked}"
|
||||||
|
if approved:
|
||||||
|
return True, None
|
||||||
parts = shlex.split(command)
|
parts = shlex.split(command)
|
||||||
prefix1 = parts[0] if parts else ""
|
prefix1 = parts[0] if parts else ""
|
||||||
prefix2 = " ".join(parts[:2])
|
prefix2 = " ".join(parts[:2])
|
||||||
|
|
|
||||||
|
|
@ -165,6 +165,7 @@ function appendApprovalTerminal(article, eventPayload) {
|
||||||
const command = formatToolCommand(action.tool || payload.tool, action.args || {});
|
const command = formatToolCommand(action.tool || payload.tool, action.args || {});
|
||||||
terminal?.classList.add("is-waiting");
|
terminal?.classList.add("is-waiting");
|
||||||
if (terminal && payload.approval_id) terminal.dataset.approvalId = payload.approval_id;
|
if (terminal && payload.approval_id) terminal.dataset.approvalId = payload.approval_id;
|
||||||
|
if (terminal && eventPayload.task_id) terminal.dataset.taskId = eventPayload.task_id;
|
||||||
if (status) status.textContent = "approval";
|
if (status) status.textContent = "approval";
|
||||||
if (body) {
|
if (body) {
|
||||||
body.textContent = [
|
body.textContent = [
|
||||||
|
|
@ -207,6 +208,7 @@ function inlineApprovalButton(label, action, tone = "") {
|
||||||
async function resolveInlineApproval(button) {
|
async function resolveInlineApproval(button) {
|
||||||
const terminal = button.closest(".tool-terminal");
|
const terminal = button.closest(".tool-terminal");
|
||||||
const approvalId = terminal?.dataset.approvalId;
|
const approvalId = terminal?.dataset.approvalId;
|
||||||
|
const taskId = terminal?.dataset.taskId;
|
||||||
const action = button.dataset.inlineApprovalAction;
|
const action = button.dataset.inlineApprovalAction;
|
||||||
if (!terminal || !approvalId || !action) return;
|
if (!terminal || !approvalId || !action) return;
|
||||||
|
|
||||||
|
|
@ -226,6 +228,9 @@ async function resolveInlineApproval(button) {
|
||||||
if (status) status.textContent = decision;
|
if (status) status.textContent = decision;
|
||||||
if (body) body.textContent = `${command}\n\n${decision}: ${humanApprovalDecision(action)}`;
|
if (body) body.textContent = `${command}\n\n${decision}: ${humanApprovalDecision(action)}`;
|
||||||
actions?.remove();
|
actions?.remove();
|
||||||
|
if (taskId) {
|
||||||
|
await continueAfterInlineApproval(terminal.closest(".message"), taskId, approvalId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function humanApprovalDecision(action) {
|
function humanApprovalDecision(action) {
|
||||||
|
|
@ -320,8 +325,8 @@ function parseSseBlock(block) {
|
||||||
return {name: event.name, data: JSON.parse(event.data)};
|
return {name: event.name, data: JSON.parse(event.data)};
|
||||||
}
|
}
|
||||||
|
|
||||||
async function streamChat(payload, onEvent) {
|
async function streamSse(url, payload, onEvent) {
|
||||||
const response = await fetch("/v1/chat/stream", {
|
const response = await fetch(url, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {"Content-Type": "application/json"},
|
headers: {"Content-Type": "application/json"},
|
||||||
body: JSON.stringify(payload),
|
body: JSON.stringify(payload),
|
||||||
|
|
@ -350,6 +355,85 @@ async function streamChat(payload, onEvent) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Send a chat request to the streaming chat endpoint.
 *
 * Thin wrapper over streamSse: passes `payload` and `onEvent` through
 * unchanged with the URL fixed to "/v1/chat/stream", and resolves once
 * streamSse has finished.
 */
async function streamChat(payload, onEvent) {
  await streamSse("/v1/chat/stream", payload, onEvent);
}
|
||||||
|
|
||||||
|
async function handleAssistantStreamEvent(pending, name, data, context) {
|
||||||
|
if (data.task_id) context.taskId = data.task_id;
|
||||||
|
if (name === "task_created") {
|
||||||
|
context.taskId = data.task_id;
|
||||||
|
setStatus("#task-status", data.task_id, "warn");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "reasoning_delta") {
|
||||||
|
pending.querySelector(".message-meta span").textContent = "reasoning";
|
||||||
|
appendInlineReasoning(pending, data.delta || "");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "tool_call_started") {
|
||||||
|
pending.querySelector(".message-meta span").textContent = "tool";
|
||||||
|
appendToolTerminal(pending, data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "tool_call_finished") {
|
||||||
|
pending.querySelector(".message-meta span").textContent = "tool";
|
||||||
|
updateToolTerminal(pending, data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "tool_approval_requested") {
|
||||||
|
pending.querySelector(".message-meta span").textContent = "approval";
|
||||||
|
appendApprovalTerminal(pending, data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "content_delta") {
|
||||||
|
if (!context.contentStarted) {
|
||||||
|
context.contentStarted = true;
|
||||||
|
setMessagePending(pending, "");
|
||||||
|
}
|
||||||
|
pending.querySelector(".message-meta span").textContent = "answering";
|
||||||
|
appendMessageText(pending, data.delta || "");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "done") {
|
||||||
|
if (!context.contentStarted) {
|
||||||
|
setMessagePending(pending, data.final_response || "No final content returned.");
|
||||||
|
}
|
||||||
|
pending.querySelector(".message-meta span").textContent = data.status;
|
||||||
|
setStatus("#task-status", data.task_id, data.status === "completed" ? "ok" : "warn");
|
||||||
|
finishInlineReasoning(pending, data.reasoning_content);
|
||||||
|
await refreshEvents(data.task_id);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (name === "error") {
|
||||||
|
throw new Error(data.error || "Stream failed.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function continueAfterInlineApproval(article, taskId, approvalId) {
|
||||||
|
if (!article || state.running) return;
|
||||||
|
state.running = true;
|
||||||
|
document.querySelector("#run").disabled = true;
|
||||||
|
setStatus("#task-status", taskId, "warn");
|
||||||
|
const context = {taskId, contentStarted: false};
|
||||||
|
try {
|
||||||
|
await streamSse(
|
||||||
|
`/v1/tasks/${encodeURIComponent(taskId)}/continue/stream`,
|
||||||
|
{approval_id: approvalId},
|
||||||
|
async ({name, data}) => handleAssistantStreamEvent(article, name, data, context),
|
||||||
|
);
|
||||||
|
} catch (error) {
|
||||||
|
setMessagePending(article, error.message);
|
||||||
|
article.querySelector(".message-meta span").textContent = "failed";
|
||||||
|
setStatus("#task-status", "failed", "bad");
|
||||||
|
await refreshEvents(taskId);
|
||||||
|
} finally {
|
||||||
|
state.running = false;
|
||||||
|
document.querySelector("#run").disabled = false;
|
||||||
|
document.querySelector("#message")?.focus();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function sendMessage() {
|
async function sendMessage() {
|
||||||
if (state.running) return;
|
if (state.running) return;
|
||||||
const input = document.querySelector("#message");
|
const input = document.querySelector("#message");
|
||||||
|
|
@ -362,8 +446,7 @@ async function sendMessage() {
|
||||||
addMessage("user", message, "submitted");
|
addMessage("user", message, "submitted");
|
||||||
input.value = "";
|
input.value = "";
|
||||||
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
|
const pending = addMessage("assistant", "", "thinking", {reasoning: true});
|
||||||
let taskId = "";
|
const context = {taskId: "", contentStarted: false};
|
||||||
let contentStarted = false;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await streamChat({
|
await streamChat({
|
||||||
|
|
@ -371,61 +454,14 @@ async function sendMessage() {
|
||||||
workspace: document.querySelector("#workspace").value,
|
workspace: document.querySelector("#workspace").value,
|
||||||
debug: document.querySelector("#debug").checked,
|
debug: document.querySelector("#debug").checked,
|
||||||
}, async ({name, data}) => {
|
}, async ({name, data}) => {
|
||||||
if (data.task_id) taskId = data.task_id;
|
await handleAssistantStreamEvent(pending, name, data, context);
|
||||||
if (name === "task_created") {
|
|
||||||
taskId = data.task_id;
|
|
||||||
setStatus("#task-status", taskId, "warn");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "reasoning_delta") {
|
|
||||||
pending.querySelector(".message-meta span").textContent = "reasoning";
|
|
||||||
appendInlineReasoning(pending, data.delta || "");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "tool_call_started") {
|
|
||||||
pending.querySelector(".message-meta span").textContent = "tool";
|
|
||||||
appendToolTerminal(pending, data);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "tool_call_finished") {
|
|
||||||
pending.querySelector(".message-meta span").textContent = "tool";
|
|
||||||
updateToolTerminal(pending, data);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "tool_approval_requested") {
|
|
||||||
pending.querySelector(".message-meta span").textContent = "approval";
|
|
||||||
appendApprovalTerminal(pending, data);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "content_delta") {
|
|
||||||
if (!contentStarted) {
|
|
||||||
contentStarted = true;
|
|
||||||
setMessagePending(pending, "");
|
|
||||||
}
|
|
||||||
pending.querySelector(".message-meta span").textContent = "answering";
|
|
||||||
appendMessageText(pending, data.delta || "");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "done") {
|
|
||||||
if (!contentStarted) {
|
|
||||||
setMessagePending(pending, data.final_response || "No final content returned.");
|
|
||||||
}
|
|
||||||
pending.querySelector(".message-meta span").textContent = data.status;
|
|
||||||
setStatus("#task-status", data.task_id, data.status === "completed" ? "ok" : "warn");
|
|
||||||
finishInlineReasoning(pending, data.reasoning_content);
|
|
||||||
await refreshEvents(data.task_id);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (name === "error") {
|
|
||||||
throw new Error(data.error || "Stream failed.");
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (!taskId) input.value = message;
|
if (!context.taskId) input.value = message;
|
||||||
setMessagePending(pending, error.message);
|
setMessagePending(pending, error.message);
|
||||||
pending.querySelector(".message-meta span").textContent = "failed";
|
pending.querySelector(".message-meta span").textContent = "failed";
|
||||||
setStatus("#task-status", "failed", "bad");
|
setStatus("#task-status", "failed", "bad");
|
||||||
if (taskId) await refreshEvents(taskId);
|
if (context.taskId) await refreshEvents(context.taskId);
|
||||||
} finally {
|
} finally {
|
||||||
state.running = false;
|
state.running = false;
|
||||||
document.querySelector("#run").disabled = false;
|
document.querySelector("#run").disabled = false;
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,16 @@ Available tools:
|
||||||
Args: {"path": "relative/path.txt"}
|
Args: {"path": "relative/path.txt"}
|
||||||
- file_write: write a file inside the current workspace.
|
- file_write: write a file inside the current workspace.
|
||||||
Args: {"path": "relative/path.txt", "content": "text", "overwrite": false}
|
Args: {"path": "relative/path.txt", "content": "text", "overwrite": false}
|
||||||
|
- list_dir: list direct children of a directory inside the current workspace.
|
||||||
|
Args: {"path": "."}
|
||||||
|
- search_files: search text inside files in the current workspace.
|
||||||
|
Args: {"query": "text to find", "path": ".", "glob": "*.py"}
|
||||||
- shell_exec_safe: run a safe allowlisted shell command in the current workspace.
|
- shell_exec_safe: run a safe allowlisted shell command in the current workspace.
|
||||||
Args: {"command": "pwd"}
|
Args: {"command": "pwd"}
|
||||||
|
|
||||||
Return actions=[] when the user can be answered directly without tools.
|
Return actions=[] when the user can be answered directly without tools.
|
||||||
|
When tool_observations are already present, request only genuinely missing
|
||||||
|
follow-up information. Return actions=[] when the observations are sufficient
|
||||||
|
for the thinker to answer.
|
||||||
Use only the listed tools. Keep actions minimal and directly tied to the user's
|
Use only the listed tools. Keep actions minimal and directly tied to the user's
|
||||||
request. Do not invent tool names.
|
request. Do not invent tool names.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
from duck_core.model_client import ModelResponse
|
from duck_core.model_client import ModelResponse
|
||||||
|
|
||||||
|
|
@ -56,6 +57,16 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
||||||
|
|
||||||
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
|
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
|
||||||
assert role == "action"
|
assert role == "action"
|
||||||
|
if any("tool_observations" in message["content"] for message in messages):
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "file_read",
|
||||||
|
"args": {"path": "note.txt"},
|
||||||
|
"reason": "User asked for file contents",
|
||||||
|
}
|
||||||
|
]
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
role=role,
|
role=role,
|
||||||
model="local-main",
|
model="local-main",
|
||||||
|
|
@ -64,13 +75,7 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
||||||
"kind": "action_directive",
|
"kind": "action_directive",
|
||||||
"intent": "read requested file",
|
"intent": "read requested file",
|
||||||
"risk_level": "low",
|
"risk_level": "low",
|
||||||
"actions": [
|
"actions": actions,
|
||||||
{
|
|
||||||
"tool": "file_read",
|
|
||||||
"args": {"path": "note.txt"},
|
|
||||||
"reason": "User asked for file contents",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
reasoning_content=None,
|
reasoning_content=None,
|
||||||
|
|
@ -101,3 +106,70 @@ def test_stream_chat_endpoint_executes_tool_before_streaming_answer(tmp_path, mo
|
||||||
assert "event: content_delta" in body
|
assert "event: content_delta" in body
|
||||||
assert "answer from tool" in body
|
assert "answer from tool" in body
|
||||||
assert "event: done" in body
|
assert "event: done" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_continue_stream_executes_approved_tool_and_streams_answer(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("DUCK_DB_PATH", str(tmp_path / "duck.sqlite3"))
|
||||||
|
|
||||||
|
async def fake_chat(self, role, messages, temperature=None, max_output_tokens=None, response_format=None):
|
||||||
|
assert role == "action"
|
||||||
|
if any("tool_observations" in message["content"] for message in messages):
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "shell_exec_safe",
|
||||||
|
"args": {"command": "uname -a"},
|
||||||
|
"reason": "User asked for system information",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"kind": "action_directive",
|
||||||
|
"intent": "run command",
|
||||||
|
"risk_level": "medium",
|
||||||
|
"actions": actions,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fake_stream_chat(self, role, messages):
|
||||||
|
assert role == "thinker"
|
||||||
|
observation_message = next(message for message in messages if "tool_observations" in message["content"])
|
||||||
|
assert "uname" in observation_message["content"]
|
||||||
|
yield {"type": "content_delta", "delta": "continued after approval"}
|
||||||
|
|
||||||
|
monkeypatch.setattr("duck_core.model_client.ModelClient.chat", fake_chat)
|
||||||
|
monkeypatch.setattr("duck_core.model_client.ModelClient.stream_chat", fake_stream_chat)
|
||||||
|
client = TestClient(create_app())
|
||||||
|
|
||||||
|
with client.stream(
|
||||||
|
"POST",
|
||||||
|
"/v1/chat/stream",
|
||||||
|
json={"message": "run uname", "workspace": str(tmp_path), "debug": True},
|
||||||
|
) as response:
|
||||||
|
initial_body = "".join(response.iter_text())
|
||||||
|
task_id = re.search(r'"task_id"\s*:\s*"([^"]+)"', initial_body).group(1)
|
||||||
|
pending = client.get("/v1/approvals/pending").json()
|
||||||
|
approval = next(item for item in pending if item["task_id"] == task_id)
|
||||||
|
client.post(f"/v1/approvals/{approval['approval_id']}/allow_once")
|
||||||
|
|
||||||
|
with client.stream(
|
||||||
|
"POST",
|
||||||
|
f"/v1/tasks/{approval['task_id']}/continue/stream",
|
||||||
|
json={"approval_id": approval["approval_id"]},
|
||||||
|
) as response:
|
||||||
|
body = "".join(response.iter_text())
|
||||||
|
|
||||||
|
assert "event: tool_approval_requested" in initial_body
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "event: tool_call_finished" in body
|
||||||
|
assert "event: content_delta" in body
|
||||||
|
assert "continued after approval" in body
|
||||||
|
assert "event: done" in body
|
||||||
|
|
|
||||||
|
|
@ -73,7 +73,7 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
|
||||||
"actions": [
|
"actions": [
|
||||||
{
|
{
|
||||||
"tool": "shell_exec_safe",
|
"tool": "shell_exec_safe",
|
||||||
"args": {"command": "uname -a"},
|
"args": {"command": "hostname --pending-approval-test"},
|
||||||
"reason": "needs shell command",
|
"reason": "needs shell command",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
@ -90,7 +90,11 @@ def test_chat_api_exposes_pending_approval_from_runtime_tool_gate(tmp_path, monk
|
||||||
|
|
||||||
response = client.post("/v1/chat", json={"message": "run uname", "debug": True})
|
response = client.post("/v1/chat", json={"message": "run uname", "debug": True})
|
||||||
approvals = client.get("/v1/approvals/pending").json()
|
approvals = client.get("/v1/approvals/pending").json()
|
||||||
|
approval = next(
|
||||||
|
item for item in approvals if item["task_id"] == response.json()["task_id"]
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "waiting_for_approval"
|
assert response.json()["status"] == "waiting_for_approval"
|
||||||
assert approvals[0]["normalized_action"]["tool"] == "shell_exec_safe"
|
assert approval["normalized_action"]["tool"] == "shell_exec_safe"
|
||||||
|
assert approval["normalized_action"]["args"]["command"] == "hostname --pending-approval-test"
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,16 @@ from duck_core.tasks.store import TaskStore
|
||||||
class FakeToolModelClient:
|
class FakeToolModelClient:
|
||||||
async def chat(self, role, messages):
|
async def chat(self, role, messages):
|
||||||
if role == "action":
|
if role == "action":
|
||||||
|
if any("tool_observations" in message["content"] for message in messages):
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "file_read",
|
||||||
|
"args": {"path": "note.txt"},
|
||||||
|
"reason": "User asked for file contents",
|
||||||
|
}
|
||||||
|
]
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
role=role,
|
role=role,
|
||||||
model="local-main",
|
model="local-main",
|
||||||
|
|
@ -20,13 +30,7 @@ class FakeToolModelClient:
|
||||||
"kind": "action_directive",
|
"kind": "action_directive",
|
||||||
"intent": "read requested file",
|
"intent": "read requested file",
|
||||||
"risk_level": "low",
|
"risk_level": "low",
|
||||||
"actions": [
|
"actions": actions,
|
||||||
{
|
|
||||||
"tool": "file_read",
|
|
||||||
"args": {"path": "note.txt"},
|
|
||||||
"reason": "User asked for file contents",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
reasoning_content=None,
|
reasoning_content=None,
|
||||||
|
|
@ -45,6 +49,58 @@ class FakeToolModelClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMultiStepToolModelClient:
|
||||||
|
async def chat(self, role, messages):
|
||||||
|
if role == "action":
|
||||||
|
observation_text = "\n".join(message["content"] for message in messages)
|
||||||
|
if "tool_observations" not in observation_text:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "list_dir",
|
||||||
|
"args": {"path": "."},
|
||||||
|
"reason": "Find available files",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif "README.md" in observation_text and "readme contents" not in observation_text:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "file_read",
|
||||||
|
"args": {"path": "README.md"},
|
||||||
|
"reason": "Read discovered README",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
actions = []
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"kind": "action_directive",
|
||||||
|
"intent": "multi-step file inspection",
|
||||||
|
"risk_level": "low",
|
||||||
|
"actions": actions,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=5.0,
|
||||||
|
)
|
||||||
|
assert role == "thinker"
|
||||||
|
observation_text = "\n".join(message["content"] for message in messages)
|
||||||
|
assert "list_dir" in observation_text
|
||||||
|
assert "file_read" in observation_text
|
||||||
|
assert "readme contents" in observation_text
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content="Readme inspected",
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=12.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runtime_executes_action_directive_tool_and_finishes_with_observation(tmp_path):
|
async def test_runtime_executes_action_directive_tool_and_finishes_with_observation(tmp_path):
|
||||||
(tmp_path / "note.txt").write_text("hello from tool")
|
(tmp_path / "note.txt").write_text("hello from tool")
|
||||||
|
|
@ -67,9 +123,38 @@ async def test_runtime_executes_action_directive_tool_and_finishes_with_observat
|
||||||
assert tool_finished.payload["result"]["output"] == "hello from tool"
|
assert tool_finished.payload["result"]["output"] == "hello from tool"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_runs_multiple_tool_steps_before_final_answer(tmp_path):
|
||||||
|
(tmp_path / "README.md").write_text("readme contents")
|
||||||
|
db_path = str(tmp_path / "duck.sqlite3")
|
||||||
|
task_store = TaskStore(db_path)
|
||||||
|
event_store = EventStore(db_path)
|
||||||
|
loop = RuntimeLoop(task_store, event_store, FakeMultiStepToolModelClient())
|
||||||
|
|
||||||
|
result = await loop.run_chat("inspect the workspace readme", str(tmp_path), debug=True)
|
||||||
|
events = await event_store.list_events(result.task_id)
|
||||||
|
finished_tools = [
|
||||||
|
event.payload["tool"] for event in events if event.event_type == "tool_call_finished"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.final_response == "Readme inspected"
|
||||||
|
assert finished_tools == ["list_dir", "file_read"]
|
||||||
|
|
||||||
|
|
||||||
class FakeApprovalModelClient:
|
class FakeApprovalModelClient:
|
||||||
async def chat(self, role, messages):
|
async def chat(self, role, messages):
|
||||||
if role == "action":
|
if role == "action":
|
||||||
|
if any("tool_observations" in message["content"] for message in messages):
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "shell_exec_safe",
|
||||||
|
"args": {"command": "uname -a"},
|
||||||
|
"reason": "User requested system information",
|
||||||
|
}
|
||||||
|
]
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
role=role,
|
role=role,
|
||||||
model="local-main",
|
model="local-main",
|
||||||
|
|
@ -78,13 +163,7 @@ class FakeApprovalModelClient:
|
||||||
"kind": "action_directive",
|
"kind": "action_directive",
|
||||||
"intent": "run command",
|
"intent": "run command",
|
||||||
"risk_level": "medium",
|
"risk_level": "medium",
|
||||||
"actions": [
|
"actions": actions,
|
||||||
{
|
|
||||||
"tool": "shell_exec_safe",
|
|
||||||
"args": {"command": "uname -a"},
|
|
||||||
"reason": "User requested system information",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
reasoning_content=None,
|
reasoning_content=None,
|
||||||
|
|
@ -110,3 +189,197 @@ async def test_runtime_creates_pending_approval_when_tool_requires_it(tmp_path):
|
||||||
assert pending[0].task_id == result.task_id
|
assert pending[0].task_id == result.task_id
|
||||||
assert pending[0].normalized_action["tool"] == "shell_exec_safe"
|
assert pending[0].normalized_action["tool"] == "shell_exec_safe"
|
||||||
assert any(event.event_type == "tool_approval_requested" for event in events)
|
assert any(event.event_type == "tool_approval_requested" for event in events)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeApprovalContinuationModelClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.thinker_messages = []
|
||||||
|
|
||||||
|
async def chat(self, role, messages):
|
||||||
|
if role == "action":
|
||||||
|
if any("tool_observations" in message["content"] for message in messages):
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "shell_exec_safe",
|
||||||
|
"args": {"command": "uname -a"},
|
||||||
|
"reason": "User requested system information",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"kind": "action_directive",
|
||||||
|
"intent": "run command",
|
||||||
|
"risk_level": "medium",
|
||||||
|
"actions": actions,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=5.0,
|
||||||
|
)
|
||||||
|
assert role == "thinker"
|
||||||
|
self.thinker_messages = messages
|
||||||
|
assert any("tool_observations" in message["content"] for message in messages)
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content="uname completed",
|
||||||
|
reasoning_content="used approved shell command",
|
||||||
|
raw={},
|
||||||
|
latency_ms=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeApprovalThenSecondToolModelClient:
|
||||||
|
async def chat(self, role, messages):
|
||||||
|
observation_text = "\n".join(message["content"] for message in messages)
|
||||||
|
if role == "action":
|
||||||
|
if "tool_observations" in observation_text and "second step content" not in observation_text:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "file_read",
|
||||||
|
"args": {"path": "second.txt"},
|
||||||
|
"reason": "Read follow-up file after approved command",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif "tool_observations" in observation_text:
|
||||||
|
actions = []
|
||||||
|
else:
|
||||||
|
actions = [
|
||||||
|
{
|
||||||
|
"tool": "shell_exec_safe",
|
||||||
|
"args": {"command": "uname -a"},
|
||||||
|
"reason": "User requested system information",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"kind": "action_directive",
|
||||||
|
"intent": "approval then follow-up",
|
||||||
|
"risk_level": "medium",
|
||||||
|
"actions": actions,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=5.0,
|
||||||
|
)
|
||||||
|
assert role == "thinker"
|
||||||
|
assert "shell_exec_safe" in observation_text
|
||||||
|
assert "file_read" in observation_text
|
||||||
|
assert "second step content" in observation_text
|
||||||
|
return ModelResponse(
|
||||||
|
role=role,
|
||||||
|
model="local-main",
|
||||||
|
content="approved command and second tool completed",
|
||||||
|
reasoning_content=None,
|
||||||
|
raw={},
|
||||||
|
latency_ms=10.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_continues_after_approved_tool_call(tmp_path):
|
||||||
|
db_path = str(tmp_path / "duck.sqlite3")
|
||||||
|
task_store = TaskStore(db_path)
|
||||||
|
event_store = EventStore(db_path)
|
||||||
|
approvals = ApprovalService(db_path)
|
||||||
|
model_client = FakeApprovalContinuationModelClient()
|
||||||
|
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
|
||||||
|
|
||||||
|
pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
|
||||||
|
pending = await approvals.pending()
|
||||||
|
await approvals.allow_once(pending[0].approval_id)
|
||||||
|
|
||||||
|
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
|
||||||
|
events = await event_store.list_events(result.task_id)
|
||||||
|
finished = next(event for event in events if event.event_type == "tool_call_finished")
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.final_response == "uname completed"
|
||||||
|
assert finished.payload["tool"] == "shell_exec_safe"
|
||||||
|
assert finished.payload["result"]["ok"] is True
|
||||||
|
assert "uname" in finished.payload["result"]["metadata"]["command"]
|
||||||
|
assert any(event.event_type == "task_completed" for event in events)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_can_run_followup_tool_after_approval(tmp_path):
|
||||||
|
(tmp_path / "second.txt").write_text("second step content")
|
||||||
|
db_path = str(tmp_path / "duck.sqlite3")
|
||||||
|
task_store = TaskStore(db_path)
|
||||||
|
event_store = EventStore(db_path)
|
||||||
|
approvals = ApprovalService(db_path)
|
||||||
|
loop = RuntimeLoop(
|
||||||
|
task_store,
|
||||||
|
event_store,
|
||||||
|
FakeApprovalThenSecondToolModelClient(),
|
||||||
|
approval_service=approvals,
|
||||||
|
)
|
||||||
|
|
||||||
|
pending_result = await loop.run_chat("run uname then inspect second file", str(tmp_path), debug=True)
|
||||||
|
pending = await approvals.pending()
|
||||||
|
await approvals.allow_once(pending[0].approval_id)
|
||||||
|
|
||||||
|
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
|
||||||
|
events = await event_store.list_events(result.task_id)
|
||||||
|
finished_tools = [
|
||||||
|
event.payload["tool"] for event in events if event.event_type == "tool_call_finished"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert finished_tools == ["shell_exec_safe", "file_read"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_continues_after_denied_tool_call_without_execution(tmp_path):
|
||||||
|
db_path = str(tmp_path / "duck.sqlite3")
|
||||||
|
task_store = TaskStore(db_path)
|
||||||
|
event_store = EventStore(db_path)
|
||||||
|
approvals = ApprovalService(db_path)
|
||||||
|
model_client = FakeApprovalContinuationModelClient()
|
||||||
|
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
|
||||||
|
|
||||||
|
pending_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
|
||||||
|
pending = await approvals.pending()
|
||||||
|
await approvals.deny(pending[0].approval_id)
|
||||||
|
|
||||||
|
result = await loop.continue_after_approval(pending_result.task_id, pending[0].approval_id)
|
||||||
|
events = await event_store.list_events(result.task_id)
|
||||||
|
finished = next(event for event in events if event.event_type == "tool_call_finished")
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert finished.payload["result"]["ok"] is False
|
||||||
|
assert finished.payload["result"]["metadata"]["decision"] == "deny"
|
||||||
|
assert "denied" in finished.payload["result"]["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_reuses_allow_forever_for_matching_action(tmp_path):
|
||||||
|
db_path = str(tmp_path / "duck.sqlite3")
|
||||||
|
task_store = TaskStore(db_path)
|
||||||
|
event_store = EventStore(db_path)
|
||||||
|
approvals = ApprovalService(db_path)
|
||||||
|
model_client = FakeApprovalContinuationModelClient()
|
||||||
|
loop = RuntimeLoop(task_store, event_store, model_client, approval_service=approvals)
|
||||||
|
|
||||||
|
first_result = await loop.run_chat("run uname", str(tmp_path), debug=True)
|
||||||
|
first_pending = await approvals.pending()
|
||||||
|
await approvals.allow_forever(first_pending[0].approval_id)
|
||||||
|
await loop.continue_after_approval(first_result.task_id, first_pending[0].approval_id)
|
||||||
|
|
||||||
|
second_result = await loop.run_chat("run uname again", str(tmp_path), debug=True)
|
||||||
|
second_events = await event_store.list_events(second_result.task_id)
|
||||||
|
|
||||||
|
assert second_result.status == "completed"
|
||||||
|
assert second_result.final_response == "uname completed"
|
||||||
|
assert not any(event.event_type == "tool_approval_requested" for event in second_events)
|
||||||
|
assert any(event.event_type == "tool_call_finished" for event in second_events)
|
||||||
|
|
|
||||||
|
|
@ -40,3 +40,39 @@ async def test_tool_gateway_runs_allowed_directive(tmp_path):
|
||||||
|
|
||||||
assert result.ok is True
|
assert result.ok is True
|
||||||
assert result.metadata["path"].endswith("a.txt")
|
assert result.metadata["path"].endswith("a.txt")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_gateway_lists_workspace_directory(tmp_path):
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
(tmp_path / "src" / "app.py").write_text("print('duck')")
|
||||||
|
(tmp_path / "README.md").write_text("hello")
|
||||||
|
gateway = ToolGateway.default(str(tmp_path))
|
||||||
|
|
||||||
|
result = await gateway.run_action({"tool": "list_dir", "args": {"path": "."}})
|
||||||
|
escaped = await gateway.run_action({"tool": "list_dir", "args": {"path": ".."}})
|
||||||
|
|
||||||
|
assert result.ok is True
|
||||||
|
assert "README.md" in result.output
|
||||||
|
assert "src/" in result.output
|
||||||
|
assert escaped.ok is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_gateway_searches_file_contents(tmp_path):
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
(tmp_path / "src" / "app.py").write_text("duck tool gateway\\n")
|
||||||
|
(tmp_path / "notes.txt").write_text("other content\\n")
|
||||||
|
gateway = ToolGateway.default(str(tmp_path))
|
||||||
|
|
||||||
|
result = await gateway.run_action(
|
||||||
|
{"tool": "search_files", "args": {"query": "duck tool", "path": "."}}
|
||||||
|
)
|
||||||
|
escaped = await gateway.run_action(
|
||||||
|
{"tool": "search_files", "args": {"query": "duck", "path": ".."}}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.ok is True
|
||||||
|
assert "src/app.py:1:duck tool gateway" in result.output
|
||||||
|
assert result.metadata["matches"] == 1
|
||||||
|
assert escaped.ok is False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue