Add tool failure recovery decisions

This commit is contained in:
mirivlad 2026-05-11 00:33:06 +08:00
parent 717e931a5e
commit 1b4f4c836e
5 changed files with 206 additions and 30 deletions

View File

@ -48,6 +48,7 @@ class ExecutionEngine:
memory_policy: MemoryWritePolicy | None = None,
memory_interface: MemoryInterface | None = None,
prompts: dict[str, str] | None = None,
recovery_limit: int = 1,
) -> None:
self._event_bus = event_bus
self._tool_registry = tool_registry
@ -58,6 +59,7 @@ class ExecutionEngine:
self._memory_policy = memory_policy
self._memory_interface = memory_interface
self._prompts = prompts or {}
self._recovery_limit = recovery_limit
def set_critic(self, critic: AsyncCriticAdapter) -> None:
self._critic = critic
@ -229,16 +231,40 @@ class ExecutionEngine:
"status": result.get("status"),
})
# If tool needs permission or failed - return immediately, don't continue execution
if result.get("status") == "failed":
return {
"status": "failed",
"result": {
"error": f"Step {step.id} failed",
"failed_step": step.id,
"step_results": step_results,
},
}
recovery = self._recover_failed_step(
task=task,
step=step,
result=result,
step_results=step_results,
permission_override=permission_override,
secret_override=secret_override,
password_override=password_override,
)
if recovery.get("status") == "awaiting_permission":
return recovery
if recovery.get("status") == "completed":
recovered_result = recovery.get("result")
if recovered_result:
step_results[-1]["result"] = recovered_result
if recovery.get("finish"):
return {
"status": "completed",
"result": {
"message": recovery.get("message", "Recovered from failed step"),
"step_results": step_results,
},
}
else:
return {
"status": "failed",
"result": {
"error": f"Step {step.id} failed",
"failed_step": step.id,
"step_results": step_results,
"recovery": recovery.get("result"),
},
}
requires_execution = directive.payload.get("requires_execution", True)
if requires_execution and self._critic:
@ -260,6 +286,144 @@ class ExecutionEngine:
},
}
def _recover_failed_step(
self,
task: UserTask,
step,
result: dict[str, Any],
step_results: list[dict[str, Any]],
permission_override: PermissionDecision | None = None,
secret_override: str | None = None,
password_override: str | None = None,
) -> dict[str, Any]:
if self._recovery_limit <= 0 or not self._critic:
return {"status": "failed", "result": {"reason": "recovery_unavailable"}}
decision = self._evaluate_recovery(task, step, result, step_results)
action = decision.get("action", "fail")
if action == "continue":
recovered = dict(result)
recovered["status"] = "completed"
recovered["recovery_decision"] = decision
return {"status": "completed", "result": recovered}
if action == "respond":
recovered = dict(result)
recovered["status"] = "completed"
recovered["recovery_decision"] = decision
return {
"status": "completed",
"result": recovered,
"finish": True,
"message": decision.get("message") or decision.get("reason") or "Recovered by responding to user",
}
if action == "retry":
retry_tool = decision.get("tool") or step.tool
retry_args = decision.get("args") or step.args
retry_result = self._execute_tool(
task=task,
directive=ExecutionDirective(
type="tool",
payload={"tool": retry_tool, "args": retry_args},
requires_permission=True,
reason=decision.get("reason", "Recovery retry"),
),
permission_override=permission_override,
secret_override=secret_override,
password_override=password_override,
)
if retry_result.get("status") == "awaiting_permission":
return retry_result
retry_result["recovery_decision"] = decision
if retry_result.get("status") == "completed":
return {"status": "completed", "result": retry_result}
return {"status": "failed", "result": {"decision": decision, "retry_result": retry_result}}
return {"status": "failed", "result": decision}
def _evaluate_recovery(
self,
task: UserTask,
step,
result: dict[str, Any],
step_results: list[dict[str, Any]],
) -> dict[str, Any]:
prompt = self._build_recovery_prompt(task, step, result, step_results)
self._publish(task, CRITIC_CALLED, {"step_id": step.id, "mode": "recovery"})
try:
output = asyncio.run(self._critic.generate(prompt, max_tokens=512))
decision = self._parse_recovery_decision(output)
self._publish(task, CRITIC_RESULT, {
"step_id": step.id,
"mode": "recovery",
"decision": decision,
"raw": output,
})
return decision
except Exception as e:
logger.warning(f"Recovery evaluation failed: {e}")
self._publish(task, CRITIC_RESULT, {
"step_id": step.id,
"mode": "recovery",
"error": str(e),
})
return {"action": "fail", "reason": str(e)}
def _build_recovery_prompt(
self,
task: UserTask,
step,
result: dict[str, Any],
step_results: list[dict[str, Any]],
) -> str:
return f"""You are a recovery controller for an agent runtime.
Decide what to do after a failed tool step. A non-zero exit code is not always fatal.
Interpret the failure in context.
Allowed actions:
- continue: failure is acceptable information; continue the plan.
- retry: try one alternative tool call. Include "tool" and "args".
- respond: stop and answer the user with available information. Include "message".
- fail: real failure; stop the task.
Return ONLY JSON:
{{"action":"continue|retry|respond|fail","reason":"...","tool":"shell_exec","args":{{...}},"message":"..."}}
Task:
{task.input}
Failed step:
id={step.id}
tool={step.tool}
args={json.dumps(step.args, ensure_ascii=False)}
description={step.description}
Failed result:
{json.dumps(result, ensure_ascii=False, indent=2)}
Previous step results:
{json.dumps(step_results, ensure_ascii=False, indent=2)}
"""
def _parse_recovery_decision(self, output: str) -> dict[str, Any]:
try:
json_start = output.find("{")
json_end = output.rfind("}") + 1
if json_start < 0 or json_end <= 0:
return {"action": "fail", "reason": "Recovery output was not JSON"}
data = json.loads(output[json_start:json_end])
action = data.get("action", "fail")
if action not in {"continue", "retry", "respond", "fail"}:
action = "fail"
data["action"] = action
return data
except (json.JSONDecodeError, TypeError, ValueError) as e:
return {"action": "fail", "reason": f"Recovery JSON parse failed: {e}"}
def _get_ready_steps(
self,
graph: dict[str, Any],

View File

@ -133,6 +133,7 @@ class RuntimeController:
memory_policy=self._memory_policy,
memory_interface=self._memory_interface,
prompts=self._prompts,
recovery_limit=runtime_config.tool_retry_limit,
)
self.runtime_loop = RuntimeLoop(

View File

@ -24,12 +24,10 @@ class Tool(BaseTool):
stdin_data=str(stdin_secret) if stdin_secret is not None else None,
)
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
ok = completed.returncode == 0 or grep_no_matches
return ToolResult(
tool=self.name,
ok=ok,
ok=completed.returncode == 0,
output=output,
error=None if ok else f"Command failed with exit code {completed.returncode}",
metadata={"exit_code": completed.returncode, "no_matches": grep_no_matches},
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
metadata={"exit_code": completed.returncode},
)

View File

@ -29,9 +29,6 @@ class ShellExecTool(BaseTool):
)
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
error_output = completed.stderr or completed.stdout
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
ok = completed.returncode == 0 or grep_no_matches
is_sudo_error = (
completed.returncode != 0 and
("permission denied" in error_output.lower() or
@ -42,8 +39,8 @@ class ShellExecTool(BaseTool):
return ToolResult(
tool=self.name,
ok=ok,
ok=completed.returncode == 0,
output=output,
error=None if ok else f"Command failed with exit code {completed.returncode}",
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error, "no_matches": grep_no_matches},
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error},
)

View File

@ -1,7 +1,7 @@
import json
from pathlib import Path
from app.core.contracts import UserTask
from app.core.contracts import ExecutionDirective, UserTask
from app.runtime.runtime_controller import RuntimeController
@ -112,21 +112,37 @@ def test_shell_exec_allows_safe_command(tmp_path: Path) -> None:
assert str(tmp_path) in result["result"]["output"]
def test_shell_exec_treats_grep_no_matches_as_information(tmp_path: Path) -> None:
class _RecoveryCritic:
async def generate(self, prompt: str, max_tokens: int | None = None) -> str:
return '{"action":"continue","reason":"No matches is acceptable information for this exploratory check."}'
def test_failed_shell_step_can_recover_and_continue(tmp_path: Path) -> None:
_write_config_tree(tmp_path)
controller = RuntimeController(base_dir=tmp_path)
result = controller.handle_task(
controller.execution_engine.set_critic(_RecoveryCritic())
controller.execution_engine._recovery_limit = 1
result = controller.execution_engine.execute(
UserTask(
input="run grep with no matches",
context={
"requested_tool": "shell_exec",
"tool_args": {"command": "printf 'abc\\n' | grep definitely_missing"},
input="run grep with no matches and recover",
),
ExecutionDirective(
type="plan",
payload={
"steps": [
{
"id": "1",
"tool": "shell_exec",
"args": {"command": "printf 'abc\\n' | grep definitely_missing"},
"depends_on": [],
}
]
},
)
),
)
assert result["status"] == "completed"
assert result["result"]["metadata"]["exit_code"] == 1
assert result["result"]["metadata"]["no_matches"] is True
failed_result = result["result"]["step_results"][0]["result"]["result"]
assert failed_result["metadata"]["exit_code"] == 1
def test_permission_resolution_can_resume_task(tmp_path: Path) -> None: