212 lines
7.0 KiB
Python
212 lines
7.0 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from collections import deque
|
|
from typing import Any
|
|
|
|
from app.core.contracts import ExecutionDirective, PlanStep
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ExecutionScheduler:
    """Turn a JSON plan into ``PlanStep`` objects and order them by dependency.

    The class parses plan text (``parse_plan_steps``), checks the dependency
    graph for cycles (``validate_no_cycles``), builds a node/edge graph
    (``build_task_graph``) and reports which steps may run next
    (``get_ready_steps``).
    """

    def __init__(self, retry_limit: int = 2) -> None:
        """Create a scheduler.

        Args:
            retry_limit: Stored as ``_retry_limit``. None of the methods
                defined in this class read it.
        """
        self._retry_limit = retry_limit

    def parse_plan_steps(
        self,
        json_str: str,
        task_id: str | None = None,
    ) -> list[PlanStep]:
        """Parse plan text into a list of ``PlanStep`` objects.

        The text between the first ``{`` and the last ``}`` is decoded as a
        JSON object, so surrounding non-JSON text is ignored. Supported
        shapes:

        * ``{"type": "step", "payload": {"tool": ..., "args": ...}}``
        * ``{"type": "plan", "payload": {"steps": [...]}}``
        * legacy ``{"steps": [...]}`` or ``{"plan": [...]}``
        * any other object, which is treated as a single step

        Args:
            json_str: Raw text that contains the JSON plan.
            task_id: Accepted for interface compatibility; not used here.

        Returns:
            The parsed steps, or ``[]`` when no JSON object is found or the
            plan cannot be decoded or converted into ``PlanStep`` objects.
        """
        if not isinstance(json_str, str):
            return []

        try:
            json_start = json_str.find("{")
            if json_start < 0:
                return []
            json_end = json_str.rfind("}") + 1

            data = json.loads(json_str[json_start:json_end])
            raw_steps = self._extract_raw_steps(data)

            steps = []
            for i, step_data in enumerate(raw_steps):
                # A bare string is shorthand for a "respond" step.
                if isinstance(step_data, str):
                    step_data = {
                        "id": f"step-{i}",
                        "kind": "respond",
                        "text": step_data,
                    }

                if not isinstance(step_data, dict):
                    continue

                step_data.setdefault("id", f"step-{i}")

                # Tool-first: the scheduler receives the tool directly,
                # without transformations. The kind is decided by the
                # presence of a tool name; args are passed through as-is.
                if step_data.get("tool"):
                    step_data["kind"] = "tool"

                step_data.setdefault("kind", "respond")
                step_data.setdefault("tool", None)
                step_data.setdefault("args", {})
                step_data.setdefault("description", "")
                step_data.setdefault("requires_confirmation", False)
                step_data.setdefault("depends_on", [])

                steps.append(PlanStep(**step_data))

            return steps

        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.warning("Plan parsing failed: %s", e)
            return []

    @staticmethod
    def _coerce_step_list(value: Any) -> list[Any]:
        """Return ``value`` as a list of raw steps.

        A non-empty dict or string is wrapped as a single step instead of
        being iterated key-by-key or character-by-character.
        """
        if isinstance(value, list):
            return value
        if isinstance(value, (dict, str)):
            return [value] if value else []
        return []

    @staticmethod
    def _unwrap_step(item: Any) -> Any:
        """Flatten ``{"type": "step", "payload": {...}}``; pass others through."""
        if not (isinstance(item, dict) and item.get("type") == "step"):
            return item
        inner = item.get("payload")
        if not isinstance(inner, dict):
            inner = {}
        return {
            "tool": inner.get("tool"),
            "args": inner.get("args", {}),
            "description": inner.get("description", ""),
            "depends_on": inner.get("depends_on", []),
        }

    @staticmethod
    def _extract_raw_steps(data: Any) -> list[Any]:
        """Pull the list of raw step entries out of a decoded plan object."""
        if not isinstance(data, dict):
            return []

        msg_type = data.get("type", "")

        # Single step format: {"type": "step", "payload": {"tool": ..., "args": ...}}
        if msg_type == "step":
            payload = data.get("payload")
            if not isinstance(payload, dict):
                payload = {}
            return [
                {
                    "id": "step-0",
                    "kind": "tool",
                    "tool": payload.get("tool"),
                    "args": payload.get("args", {}),
                    "description": payload.get("description", ""),
                    "depends_on": payload.get("depends_on", []),
                }
            ]

        # Plan format: {"type": "plan", "payload": {"steps": [...]}}
        if msg_type == "plan":
            payload = data.get("payload")
            if not isinstance(payload, dict):
                payload = {}
            entries = ExecutionScheduler._coerce_step_list(payload.get("steps", []))
            return [ExecutionScheduler._unwrap_step(entry) for entry in entries]

        # Old format compatibility.
        if "steps" in data:
            return ExecutionScheduler._coerce_step_list(data["steps"])
        if "plan" in data:
            return ExecutionScheduler._coerce_step_list(data["plan"])

        # Anything else: the object itself is the only step.
        return [data]

    def validate_no_cycles(self, steps: list[PlanStep]) -> bool:
        """Return ``True`` when the ``depends_on`` relation has no cycle.

        Dependencies on ids that are not among ``steps`` are treated as
        leaves. The search is an iterative depth-first walk, so long
        dependency chains do not hit the interpreter recursion limit.

        Args:
            steps: Steps exposing ``id`` and ``depends_on``.

        Returns:
            ``False`` (after logging a warning) if a cycle is reachable,
            ``True`` otherwise, including for an empty list.
        """
        if not steps:
            return True

        graph: dict[str, set[str]] = {
            step.id: set(step.depends_on) for step in steps
        }

        in_progress, done = 1, 2
        state: dict[str, int] = {}

        for step in steps:
            if step.id in state:
                continue

            state[step.id] = in_progress
            stack = [(step.id, iter(graph.get(step.id, ())))]

            while stack:
                node, pending = stack[-1]
                for dep in pending:
                    dep_state = state.get(dep)
                    if dep_state == in_progress:
                        # Reached a node that is still on the current path.
                        logger.warning("Cycle detected in plan: %s", step.id)
                        return False
                    if dep_state is None:
                        state[dep] = in_progress
                        stack.append((dep, iter(graph.get(dep, ()))))
                        break
                else:
                    # All dependencies of ``node`` have been explored.
                    state[node] = done
                    stack.pop()

        return True

    def build_task_graph(
        self,
        steps: list[PlanStep],
    ) -> dict[str, Any]:
        """Build a node/edge description of the plan.

        Args:
            steps: Steps exposing ``id``, ``kind``, ``tool``, ``args`` and
                ``depends_on``.

        Returns:
            ``{"nodes": [], "edges": []}`` for an empty plan; the same with
            an ``"error"`` entry when a cycle is detected; otherwise a dict
            with ``"nodes"`` (one per step, including its ``depends_on``
            list and a ``ready`` flag that is true when it has no
            dependencies), ``"edges"`` (``{"from": dep, "to": step}``) and
            ``"step_map"`` (id -> step).
        """
        if not steps:
            return {"nodes": [], "edges": []}

        if not self.validate_no_cycles(steps):
            return {"nodes": [], "edges": [], "error": "Cycle detected in plan"}

        nodes = []
        edges = []
        step_map = {s.id: s for s in steps}

        for step in steps:
            nodes.append(
                {
                    "id": step.id,
                    "kind": step.kind,
                    "tool": step.tool,
                    "args": step.args,
                    # Stored so get_ready_steps can gate on dependencies.
                    "depends_on": list(step.depends_on),
                    "ready": len(step.depends_on) == 0,
                }
            )
            edges.extend(
                {"from": dep_id, "to": step.id} for dep_id in step.depends_on
            )

        return {"nodes": nodes, "edges": edges, "step_map": step_map}

    def get_ready_steps(
        self,
        graph: dict[str, Any],
        completed: set[str],
    ) -> list[PlanStep]:
        """Return the steps that are not completed and whose dependencies are.

        A node's dependencies come from its ``depends_on`` entry; when a
        node has no such entry they are derived from ``graph["edges"]``.
        A dependency on an id that is not a node of the graph can never be
        completed, so it is treated as satisfied rather than blocking the
        step forever.

        Args:
            graph: A graph as returned by ``build_task_graph``.
            completed: Ids of the steps that have already finished.

        Returns:
            Ready steps in node order; ``[]`` for an empty or missing graph.
        """
        if not graph or not graph.get("nodes"):
            return []

        step_map: dict[str, PlanStep] = graph.get("step_map", {})
        known_ids = {node["id"] for node in graph["nodes"]}

        # Fallback dependency table for nodes without a "depends_on" entry.
        deps_from_edges: dict[str, list[str]] = {}
        for edge in graph.get("edges", []):
            deps_from_edges.setdefault(edge["to"], []).append(edge["from"])

        ready = []
        for node in graph["nodes"]:
            node_id = node["id"]
            if node_id in completed:
                continue

            step = step_map.get(node_id)
            if not step:
                continue

            deps = node.get("depends_on")
            if deps is None:
                deps = deps_from_edges.get(node_id, [])

            if all(dep in completed or dep not in known_ids for dep in deps):
                ready.append(step)

        return ready

    def next_directive(
        self,
        directive: ExecutionDirective,
    ) -> ExecutionDirective:
        """Return ``directive`` unchanged."""
        return directive