Add tool failure recovery decisions
This commit is contained in:
parent
717e931a5e
commit
1b4f4c836e
|
|
@ -48,6 +48,7 @@ class ExecutionEngine:
|
|||
memory_policy: MemoryWritePolicy | None = None,
|
||||
memory_interface: MemoryInterface | None = None,
|
||||
prompts: dict[str, str] | None = None,
|
||||
recovery_limit: int = 1,
|
||||
) -> None:
|
||||
self._event_bus = event_bus
|
||||
self._tool_registry = tool_registry
|
||||
|
|
@ -58,6 +59,7 @@ class ExecutionEngine:
|
|||
self._memory_policy = memory_policy
|
||||
self._memory_interface = memory_interface
|
||||
self._prompts = prompts or {}
|
||||
self._recovery_limit = recovery_limit
|
||||
|
||||
def set_critic(self, critic: AsyncCriticAdapter) -> None:
|
||||
self._critic = critic
|
||||
|
|
@ -229,14 +231,38 @@ class ExecutionEngine:
|
|||
"status": result.get("status"),
|
||||
})
|
||||
|
||||
# If tool needs permission or failed - return immediately, don't continue execution
|
||||
if result.get("status") == "failed":
|
||||
recovery = self._recover_failed_step(
|
||||
task=task,
|
||||
step=step,
|
||||
result=result,
|
||||
step_results=step_results,
|
||||
permission_override=permission_override,
|
||||
secret_override=secret_override,
|
||||
password_override=password_override,
|
||||
)
|
||||
if recovery.get("status") == "awaiting_permission":
|
||||
return recovery
|
||||
if recovery.get("status") == "completed":
|
||||
recovered_result = recovery.get("result")
|
||||
if recovered_result:
|
||||
step_results[-1]["result"] = recovered_result
|
||||
if recovery.get("finish"):
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": {
|
||||
"message": recovery.get("message", "Recovered from failed step"),
|
||||
"step_results": step_results,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "failed",
|
||||
"result": {
|
||||
"error": f"Step {step.id} failed",
|
||||
"failed_step": step.id,
|
||||
"step_results": step_results,
|
||||
"recovery": recovery.get("result"),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -260,6 +286,144 @@ class ExecutionEngine:
|
|||
},
|
||||
}
|
||||
|
||||
def _recover_failed_step(
|
||||
self,
|
||||
task: UserTask,
|
||||
step,
|
||||
result: dict[str, Any],
|
||||
step_results: list[dict[str, Any]],
|
||||
permission_override: PermissionDecision | None = None,
|
||||
secret_override: str | None = None,
|
||||
password_override: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
if self._recovery_limit <= 0 or not self._critic:
|
||||
return {"status": "failed", "result": {"reason": "recovery_unavailable"}}
|
||||
|
||||
decision = self._evaluate_recovery(task, step, result, step_results)
|
||||
action = decision.get("action", "fail")
|
||||
|
||||
if action == "continue":
|
||||
recovered = dict(result)
|
||||
recovered["status"] = "completed"
|
||||
recovered["recovery_decision"] = decision
|
||||
return {"status": "completed", "result": recovered}
|
||||
|
||||
if action == "respond":
|
||||
recovered = dict(result)
|
||||
recovered["status"] = "completed"
|
||||
recovered["recovery_decision"] = decision
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": recovered,
|
||||
"finish": True,
|
||||
"message": decision.get("message") or decision.get("reason") or "Recovered by responding to user",
|
||||
}
|
||||
|
||||
if action == "retry":
|
||||
retry_tool = decision.get("tool") or step.tool
|
||||
retry_args = decision.get("args") or step.args
|
||||
retry_result = self._execute_tool(
|
||||
task=task,
|
||||
directive=ExecutionDirective(
|
||||
type="tool",
|
||||
payload={"tool": retry_tool, "args": retry_args},
|
||||
requires_permission=True,
|
||||
reason=decision.get("reason", "Recovery retry"),
|
||||
),
|
||||
permission_override=permission_override,
|
||||
secret_override=secret_override,
|
||||
password_override=password_override,
|
||||
)
|
||||
if retry_result.get("status") == "awaiting_permission":
|
||||
return retry_result
|
||||
retry_result["recovery_decision"] = decision
|
||||
if retry_result.get("status") == "completed":
|
||||
return {"status": "completed", "result": retry_result}
|
||||
return {"status": "failed", "result": {"decision": decision, "retry_result": retry_result}}
|
||||
|
||||
return {"status": "failed", "result": decision}
|
||||
|
||||
def _evaluate_recovery(
|
||||
self,
|
||||
task: UserTask,
|
||||
step,
|
||||
result: dict[str, Any],
|
||||
step_results: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
prompt = self._build_recovery_prompt(task, step, result, step_results)
|
||||
self._publish(task, CRITIC_CALLED, {"step_id": step.id, "mode": "recovery"})
|
||||
|
||||
try:
|
||||
output = asyncio.run(self._critic.generate(prompt, max_tokens=512))
|
||||
decision = self._parse_recovery_decision(output)
|
||||
self._publish(task, CRITIC_RESULT, {
|
||||
"step_id": step.id,
|
||||
"mode": "recovery",
|
||||
"decision": decision,
|
||||
"raw": output,
|
||||
})
|
||||
return decision
|
||||
except Exception as e:
|
||||
logger.warning(f"Recovery evaluation failed: {e}")
|
||||
self._publish(task, CRITIC_RESULT, {
|
||||
"step_id": step.id,
|
||||
"mode": "recovery",
|
||||
"error": str(e),
|
||||
})
|
||||
return {"action": "fail", "reason": str(e)}
|
||||
|
||||
def _build_recovery_prompt(
|
||||
self,
|
||||
task: UserTask,
|
||||
step,
|
||||
result: dict[str, Any],
|
||||
step_results: list[dict[str, Any]],
|
||||
) -> str:
|
||||
return f"""You are a recovery controller for an agent runtime.
|
||||
|
||||
Decide what to do after a failed tool step. A non-zero exit code is not always fatal.
|
||||
Interpret the failure in context.
|
||||
|
||||
Allowed actions:
|
||||
- continue: failure is acceptable information; continue the plan.
|
||||
- retry: try one alternative tool call. Include "tool" and "args".
|
||||
- respond: stop and answer the user with available information. Include "message".
|
||||
- fail: real failure; stop the task.
|
||||
|
||||
Return ONLY JSON:
|
||||
{{"action":"continue|retry|respond|fail","reason":"...","tool":"shell_exec","args":{{...}},"message":"..."}}
|
||||
|
||||
Task:
|
||||
{task.input}
|
||||
|
||||
Failed step:
|
||||
id={step.id}
|
||||
tool={step.tool}
|
||||
args={json.dumps(step.args, ensure_ascii=False)}
|
||||
description={step.description}
|
||||
|
||||
Failed result:
|
||||
{json.dumps(result, ensure_ascii=False, indent=2)}
|
||||
|
||||
Previous step results:
|
||||
{json.dumps(step_results, ensure_ascii=False, indent=2)}
|
||||
"""
|
||||
|
||||
def _parse_recovery_decision(self, output: str) -> dict[str, Any]:
|
||||
try:
|
||||
json_start = output.find("{")
|
||||
json_end = output.rfind("}") + 1
|
||||
if json_start < 0 or json_end <= 0:
|
||||
return {"action": "fail", "reason": "Recovery output was not JSON"}
|
||||
data = json.loads(output[json_start:json_end])
|
||||
action = data.get("action", "fail")
|
||||
if action not in {"continue", "retry", "respond", "fail"}:
|
||||
action = "fail"
|
||||
data["action"] = action
|
||||
return data
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
return {"action": "fail", "reason": f"Recovery JSON parse failed: {e}"}
|
||||
|
||||
def _get_ready_steps(
|
||||
self,
|
||||
graph: dict[str, Any],
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ class RuntimeController:
|
|||
memory_policy=self._memory_policy,
|
||||
memory_interface=self._memory_interface,
|
||||
prompts=self._prompts,
|
||||
recovery_limit=runtime_config.tool_retry_limit,
|
||||
)
|
||||
|
||||
self.runtime_loop = RuntimeLoop(
|
||||
|
|
|
|||
|
|
@ -24,12 +24,10 @@ class Tool(BaseTool):
|
|||
stdin_data=str(stdin_secret) if stdin_secret is not None else None,
|
||||
)
|
||||
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
||||
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
|
||||
ok = completed.returncode == 0 or grep_no_matches
|
||||
return ToolResult(
|
||||
tool=self.name,
|
||||
ok=ok,
|
||||
ok=completed.returncode == 0,
|
||||
output=output,
|
||||
error=None if ok else f"Command failed with exit code {completed.returncode}",
|
||||
metadata={"exit_code": completed.returncode, "no_matches": grep_no_matches},
|
||||
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
|
||||
metadata={"exit_code": completed.returncode},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,9 +29,6 @@ class ShellExecTool(BaseTool):
|
|||
)
|
||||
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
||||
error_output = completed.stderr or completed.stdout
|
||||
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
|
||||
ok = completed.returncode == 0 or grep_no_matches
|
||||
|
||||
is_sudo_error = (
|
||||
completed.returncode != 0 and
|
||||
("permission denied" in error_output.lower() or
|
||||
|
|
@ -42,8 +39,8 @@ class ShellExecTool(BaseTool):
|
|||
|
||||
return ToolResult(
|
||||
tool=self.name,
|
||||
ok=ok,
|
||||
ok=completed.returncode == 0,
|
||||
output=output,
|
||||
error=None if ok else f"Command failed with exit code {completed.returncode}",
|
||||
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error, "no_matches": grep_no_matches},
|
||||
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
|
||||
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.contracts import UserTask
|
||||
from app.core.contracts import ExecutionDirective, UserTask
|
||||
from app.runtime.runtime_controller import RuntimeController
|
||||
|
||||
|
||||
|
|
@ -112,21 +112,37 @@ def test_shell_exec_allows_safe_command(tmp_path: Path) -> None:
|
|||
assert str(tmp_path) in result["result"]["output"]
|
||||
|
||||
|
||||
def test_shell_exec_treats_grep_no_matches_as_information(tmp_path: Path) -> None:
|
||||
class _RecoveryCritic:
|
||||
async def generate(self, prompt: str, max_tokens: int | None = None) -> str:
|
||||
return '{"action":"continue","reason":"No matches is acceptable information for this exploratory check."}'
|
||||
|
||||
|
||||
def test_failed_shell_step_can_recover_and_continue(tmp_path: Path) -> None:
|
||||
_write_config_tree(tmp_path)
|
||||
controller = RuntimeController(base_dir=tmp_path)
|
||||
result = controller.handle_task(
|
||||
controller.execution_engine.set_critic(_RecoveryCritic())
|
||||
controller.execution_engine._recovery_limit = 1
|
||||
result = controller.execution_engine.execute(
|
||||
UserTask(
|
||||
input="run grep with no matches",
|
||||
context={
|
||||
"requested_tool": "shell_exec",
|
||||
"tool_args": {"command": "printf 'abc\\n' | grep definitely_missing"},
|
||||
input="run grep with no matches and recover",
|
||||
),
|
||||
ExecutionDirective(
|
||||
type="plan",
|
||||
payload={
|
||||
"steps": [
|
||||
{
|
||||
"id": "1",
|
||||
"tool": "shell_exec",
|
||||
"args": {"command": "printf 'abc\\n' | grep definitely_missing"},
|
||||
"depends_on": [],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
assert result["status"] == "completed"
|
||||
assert result["result"]["metadata"]["exit_code"] == 1
|
||||
assert result["result"]["metadata"]["no_matches"] is True
|
||||
failed_result = result["result"]["step_results"][0]["result"]["result"]
|
||||
assert failed_result["metadata"]["exit_code"] == 1
|
||||
|
||||
|
||||
def test_permission_resolution_can_resume_task(tmp_path: Path) -> None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue