Add tool failure recovery decisions
This commit is contained in:
parent
717e931a5e
commit
1b4f4c836e
|
|
@ -48,6 +48,7 @@ class ExecutionEngine:
|
||||||
memory_policy: MemoryWritePolicy | None = None,
|
memory_policy: MemoryWritePolicy | None = None,
|
||||||
memory_interface: MemoryInterface | None = None,
|
memory_interface: MemoryInterface | None = None,
|
||||||
prompts: dict[str, str] | None = None,
|
prompts: dict[str, str] | None = None,
|
||||||
|
recovery_limit: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._event_bus = event_bus
|
self._event_bus = event_bus
|
||||||
self._tool_registry = tool_registry
|
self._tool_registry = tool_registry
|
||||||
|
|
@ -58,6 +59,7 @@ class ExecutionEngine:
|
||||||
self._memory_policy = memory_policy
|
self._memory_policy = memory_policy
|
||||||
self._memory_interface = memory_interface
|
self._memory_interface = memory_interface
|
||||||
self._prompts = prompts or {}
|
self._prompts = prompts or {}
|
||||||
|
self._recovery_limit = recovery_limit
|
||||||
|
|
||||||
def set_critic(self, critic: AsyncCriticAdapter) -> None:
|
def set_critic(self, critic: AsyncCriticAdapter) -> None:
|
||||||
self._critic = critic
|
self._critic = critic
|
||||||
|
|
@ -229,16 +231,40 @@ class ExecutionEngine:
|
||||||
"status": result.get("status"),
|
"status": result.get("status"),
|
||||||
})
|
})
|
||||||
|
|
||||||
# If tool needs permission or failed - return immediately, don't continue execution
|
|
||||||
if result.get("status") == "failed":
|
if result.get("status") == "failed":
|
||||||
return {
|
recovery = self._recover_failed_step(
|
||||||
"status": "failed",
|
task=task,
|
||||||
"result": {
|
step=step,
|
||||||
"error": f"Step {step.id} failed",
|
result=result,
|
||||||
"failed_step": step.id,
|
step_results=step_results,
|
||||||
"step_results": step_results,
|
permission_override=permission_override,
|
||||||
},
|
secret_override=secret_override,
|
||||||
}
|
password_override=password_override,
|
||||||
|
)
|
||||||
|
if recovery.get("status") == "awaiting_permission":
|
||||||
|
return recovery
|
||||||
|
if recovery.get("status") == "completed":
|
||||||
|
recovered_result = recovery.get("result")
|
||||||
|
if recovered_result:
|
||||||
|
step_results[-1]["result"] = recovered_result
|
||||||
|
if recovery.get("finish"):
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"result": {
|
||||||
|
"message": recovery.get("message", "Recovered from failed step"),
|
||||||
|
"step_results": step_results,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"result": {
|
||||||
|
"error": f"Step {step.id} failed",
|
||||||
|
"failed_step": step.id,
|
||||||
|
"step_results": step_results,
|
||||||
|
"recovery": recovery.get("result"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
requires_execution = directive.payload.get("requires_execution", True)
|
requires_execution = directive.payload.get("requires_execution", True)
|
||||||
if requires_execution and self._critic:
|
if requires_execution and self._critic:
|
||||||
|
|
@ -260,6 +286,144 @@ class ExecutionEngine:
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _recover_failed_step(
|
||||||
|
self,
|
||||||
|
task: UserTask,
|
||||||
|
step,
|
||||||
|
result: dict[str, Any],
|
||||||
|
step_results: list[dict[str, Any]],
|
||||||
|
permission_override: PermissionDecision | None = None,
|
||||||
|
secret_override: str | None = None,
|
||||||
|
password_override: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if self._recovery_limit <= 0 or not self._critic:
|
||||||
|
return {"status": "failed", "result": {"reason": "recovery_unavailable"}}
|
||||||
|
|
||||||
|
decision = self._evaluate_recovery(task, step, result, step_results)
|
||||||
|
action = decision.get("action", "fail")
|
||||||
|
|
||||||
|
if action == "continue":
|
||||||
|
recovered = dict(result)
|
||||||
|
recovered["status"] = "completed"
|
||||||
|
recovered["recovery_decision"] = decision
|
||||||
|
return {"status": "completed", "result": recovered}
|
||||||
|
|
||||||
|
if action == "respond":
|
||||||
|
recovered = dict(result)
|
||||||
|
recovered["status"] = "completed"
|
||||||
|
recovered["recovery_decision"] = decision
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"result": recovered,
|
||||||
|
"finish": True,
|
||||||
|
"message": decision.get("message") or decision.get("reason") or "Recovered by responding to user",
|
||||||
|
}
|
||||||
|
|
||||||
|
if action == "retry":
|
||||||
|
retry_tool = decision.get("tool") or step.tool
|
||||||
|
retry_args = decision.get("args") or step.args
|
||||||
|
retry_result = self._execute_tool(
|
||||||
|
task=task,
|
||||||
|
directive=ExecutionDirective(
|
||||||
|
type="tool",
|
||||||
|
payload={"tool": retry_tool, "args": retry_args},
|
||||||
|
requires_permission=True,
|
||||||
|
reason=decision.get("reason", "Recovery retry"),
|
||||||
|
),
|
||||||
|
permission_override=permission_override,
|
||||||
|
secret_override=secret_override,
|
||||||
|
password_override=password_override,
|
||||||
|
)
|
||||||
|
if retry_result.get("status") == "awaiting_permission":
|
||||||
|
return retry_result
|
||||||
|
retry_result["recovery_decision"] = decision
|
||||||
|
if retry_result.get("status") == "completed":
|
||||||
|
return {"status": "completed", "result": retry_result}
|
||||||
|
return {"status": "failed", "result": {"decision": decision, "retry_result": retry_result}}
|
||||||
|
|
||||||
|
return {"status": "failed", "result": decision}
|
||||||
|
|
||||||
|
def _evaluate_recovery(
|
||||||
|
self,
|
||||||
|
task: UserTask,
|
||||||
|
step,
|
||||||
|
result: dict[str, Any],
|
||||||
|
step_results: list[dict[str, Any]],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
prompt = self._build_recovery_prompt(task, step, result, step_results)
|
||||||
|
self._publish(task, CRITIC_CALLED, {"step_id": step.id, "mode": "recovery"})
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = asyncio.run(self._critic.generate(prompt, max_tokens=512))
|
||||||
|
decision = self._parse_recovery_decision(output)
|
||||||
|
self._publish(task, CRITIC_RESULT, {
|
||||||
|
"step_id": step.id,
|
||||||
|
"mode": "recovery",
|
||||||
|
"decision": decision,
|
||||||
|
"raw": output,
|
||||||
|
})
|
||||||
|
return decision
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Recovery evaluation failed: {e}")
|
||||||
|
self._publish(task, CRITIC_RESULT, {
|
||||||
|
"step_id": step.id,
|
||||||
|
"mode": "recovery",
|
||||||
|
"error": str(e),
|
||||||
|
})
|
||||||
|
return {"action": "fail", "reason": str(e)}
|
||||||
|
|
||||||
|
def _build_recovery_prompt(
|
||||||
|
self,
|
||||||
|
task: UserTask,
|
||||||
|
step,
|
||||||
|
result: dict[str, Any],
|
||||||
|
step_results: list[dict[str, Any]],
|
||||||
|
) -> str:
|
||||||
|
return f"""You are a recovery controller for an agent runtime.
|
||||||
|
|
||||||
|
Decide what to do after a failed tool step. A non-zero exit code is not always fatal.
|
||||||
|
Interpret the failure in context.
|
||||||
|
|
||||||
|
Allowed actions:
|
||||||
|
- continue: failure is acceptable information; continue the plan.
|
||||||
|
- retry: try one alternative tool call. Include "tool" and "args".
|
||||||
|
- respond: stop and answer the user with available information. Include "message".
|
||||||
|
- fail: real failure; stop the task.
|
||||||
|
|
||||||
|
Return ONLY JSON:
|
||||||
|
{{"action":"continue|retry|respond|fail","reason":"...","tool":"shell_exec","args":{{...}},"message":"..."}}
|
||||||
|
|
||||||
|
Task:
|
||||||
|
{task.input}
|
||||||
|
|
||||||
|
Failed step:
|
||||||
|
id={step.id}
|
||||||
|
tool={step.tool}
|
||||||
|
args={json.dumps(step.args, ensure_ascii=False)}
|
||||||
|
description={step.description}
|
||||||
|
|
||||||
|
Failed result:
|
||||||
|
{json.dumps(result, ensure_ascii=False, indent=2)}
|
||||||
|
|
||||||
|
Previous step results:
|
||||||
|
{json.dumps(step_results, ensure_ascii=False, indent=2)}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _parse_recovery_decision(self, output: str) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
json_start = output.find("{")
|
||||||
|
json_end = output.rfind("}") + 1
|
||||||
|
if json_start < 0 or json_end <= 0:
|
||||||
|
return {"action": "fail", "reason": "Recovery output was not JSON"}
|
||||||
|
data = json.loads(output[json_start:json_end])
|
||||||
|
action = data.get("action", "fail")
|
||||||
|
if action not in {"continue", "retry", "respond", "fail"}:
|
||||||
|
action = "fail"
|
||||||
|
data["action"] = action
|
||||||
|
return data
|
||||||
|
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||||
|
return {"action": "fail", "reason": f"Recovery JSON parse failed: {e}"}
|
||||||
|
|
||||||
def _get_ready_steps(
|
def _get_ready_steps(
|
||||||
self,
|
self,
|
||||||
graph: dict[str, Any],
|
graph: dict[str, Any],
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,7 @@ class RuntimeController:
|
||||||
memory_policy=self._memory_policy,
|
memory_policy=self._memory_policy,
|
||||||
memory_interface=self._memory_interface,
|
memory_interface=self._memory_interface,
|
||||||
prompts=self._prompts,
|
prompts=self._prompts,
|
||||||
|
recovery_limit=runtime_config.tool_retry_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.runtime_loop = RuntimeLoop(
|
self.runtime_loop = RuntimeLoop(
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,10 @@ class Tool(BaseTool):
|
||||||
stdin_data=str(stdin_secret) if stdin_secret is not None else None,
|
stdin_data=str(stdin_secret) if stdin_secret is not None else None,
|
||||||
)
|
)
|
||||||
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
||||||
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
|
|
||||||
ok = completed.returncode == 0 or grep_no_matches
|
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool=self.name,
|
tool=self.name,
|
||||||
ok=ok,
|
ok=completed.returncode == 0,
|
||||||
output=output,
|
output=output,
|
||||||
error=None if ok else f"Command failed with exit code {completed.returncode}",
|
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
|
||||||
metadata={"exit_code": completed.returncode, "no_matches": grep_no_matches},
|
metadata={"exit_code": completed.returncode},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -29,9 +29,6 @@ class ShellExecTool(BaseTool):
|
||||||
)
|
)
|
||||||
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
output = completed.stdout if completed.returncode == 0 else completed.stderr or completed.stdout
|
||||||
error_output = completed.stderr or completed.stdout
|
error_output = completed.stderr or completed.stdout
|
||||||
grep_no_matches = "grep" in command and completed.returncode == 1 and not completed.stderr
|
|
||||||
ok = completed.returncode == 0 or grep_no_matches
|
|
||||||
|
|
||||||
is_sudo_error = (
|
is_sudo_error = (
|
||||||
completed.returncode != 0 and
|
completed.returncode != 0 and
|
||||||
("permission denied" in error_output.lower() or
|
("permission denied" in error_output.lower() or
|
||||||
|
|
@ -42,8 +39,8 @@ class ShellExecTool(BaseTool):
|
||||||
|
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool=self.name,
|
tool=self.name,
|
||||||
ok=ok,
|
ok=completed.returncode == 0,
|
||||||
output=output,
|
output=output,
|
||||||
error=None if ok else f"Command failed with exit code {completed.returncode}",
|
error=None if completed.returncode == 0 else f"Command failed with exit code {completed.returncode}",
|
||||||
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error, "no_matches": grep_no_matches},
|
metadata={"exit_code": completed.returncode, "needs_sudo": is_sudo_error},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from app.core.contracts import UserTask
|
from app.core.contracts import ExecutionDirective, UserTask
|
||||||
from app.runtime.runtime_controller import RuntimeController
|
from app.runtime.runtime_controller import RuntimeController
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -112,21 +112,37 @@ def test_shell_exec_allows_safe_command(tmp_path: Path) -> None:
|
||||||
assert str(tmp_path) in result["result"]["output"]
|
assert str(tmp_path) in result["result"]["output"]
|
||||||
|
|
||||||
|
|
||||||
def test_shell_exec_treats_grep_no_matches_as_information(tmp_path: Path) -> None:
|
class _RecoveryCritic:
|
||||||
|
async def generate(self, prompt: str, max_tokens: int | None = None) -> str:
|
||||||
|
return '{"action":"continue","reason":"No matches is acceptable information for this exploratory check."}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_failed_shell_step_can_recover_and_continue(tmp_path: Path) -> None:
|
||||||
_write_config_tree(tmp_path)
|
_write_config_tree(tmp_path)
|
||||||
controller = RuntimeController(base_dir=tmp_path)
|
controller = RuntimeController(base_dir=tmp_path)
|
||||||
result = controller.handle_task(
|
controller.execution_engine.set_critic(_RecoveryCritic())
|
||||||
|
controller.execution_engine._recovery_limit = 1
|
||||||
|
result = controller.execution_engine.execute(
|
||||||
UserTask(
|
UserTask(
|
||||||
input="run grep with no matches",
|
input="run grep with no matches and recover",
|
||||||
context={
|
),
|
||||||
"requested_tool": "shell_exec",
|
ExecutionDirective(
|
||||||
"tool_args": {"command": "printf 'abc\\n' | grep definitely_missing"},
|
type="plan",
|
||||||
|
payload={
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"id": "1",
|
||||||
|
"tool": "shell_exec",
|
||||||
|
"args": {"command": "printf 'abc\\n' | grep definitely_missing"},
|
||||||
|
"depends_on": [],
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
assert result["status"] == "completed"
|
assert result["status"] == "completed"
|
||||||
assert result["result"]["metadata"]["exit_code"] == 1
|
failed_result = result["result"]["step_results"][0]["result"]["result"]
|
||||||
assert result["result"]["metadata"]["no_matches"] is True
|
assert failed_result["metadata"]["exit_code"] == 1
|
||||||
|
|
||||||
|
|
||||||
def test_permission_resolution_can_resume_task(tmp_path: Path) -> None:
|
def test_permission_resolution_can_resume_task(tmp_path: Path) -> None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue