510 lines
20 KiB
Python
510 lines
20 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
|
||
from app.core.contracts import ExecutionDirective
|
||
from app.core.intent_parser import IntentParser
|
||
from app.events.event_bus import EventBus
|
||
from app.events.event_types import (
|
||
ORCHESTRATOR_CALLED,
|
||
ORCHESTRATOR_FALLBACK_USED,
|
||
ORCHESTRATOR_RETRY,
|
||
ORCHESTRATOR_RESULT,
|
||
ORCHESTRATOR_UNAVAILABLE,
|
||
THINKER_CALLED,
|
||
THINKER_RESULT,
|
||
JSON_COMPILER_CALLED,
|
||
JSON_COMPILER_RESULT,
|
||
)
|
||
from app.models.async_adapters import AsyncOrchestratorAdapter
|
||
|
||
# Module-level logger; AsyncRouter writes model outputs, retry/fix attempts and
# classification fallbacks to it via logger.info / logger.warning.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AsyncRouter:
    """Async router using Thinker + JSON Compiler pipeline.

    The Thinker model produces a free-text plan, and the JSON Compiler model turns
    that plan into an ``ExecutionDirective`` encoded as JSON. An optional sys_util
    model is asked to repair compiler output that fails validation.
    """

    # Tools whose directives must always be confirmed by the user. Shared by the
    # explicit-tool path in ``decide`` and by ``_guard_rail_check`` so the two
    # lists cannot drift apart.
    _DANGEROUS_TOOLS = frozenset({"shell_exec", "file_write", "file_delete"})

    # Line prefixes (compared case-insensitively) that mark a direct answer in
    # Thinker output.
    _RESPONSE_PREFIXES = ("ответ:", "response:")

    # Substrings suggesting the Thinker drafted an execution plan; used to detect
    # conversation-mode violations.
    _EXECUTION_MARKERS = ("шаг", "step", "выполнить", "execute", "shell", "команда")

    # Labels accepted from the intent classifier.
    _INTENT_LABELS = frozenset({"execution", "conversation", "clarification_needed"})

    # Optional class-level override for the sys_util repair prompt; consulted when
    # ``prompts`` has no non-empty "sys_util" entry.
    SYS_UTIL_PROMPT = None

    # Last-resort repair prompt when neither ``prompts`` nor SYS_UTIL_PROMPT
    # provides one.
    _DEFAULT_SYS_UTIL_PROMPT = (
        "You are a STRICT JSON repair engine. "
        "Your job is ONLY to fix invalid JSON syntax. "
        "You MUST output valid JSON or nothing else."
    )

    def __init__(
        self,
        thinker: AsyncOrchestratorAdapter | None = None,
        json_compiler: AsyncOrchestratorAdapter | None = None,
        intent_parser: IntentParser | None = None,
        prompts: dict[str, str] | None = None,
        event_bus: EventBus | None = None,
        tool_registry=None,
        retry_limit: int = 2,
        debug: bool = False,
        log_length: int = 500,
        json_fix_retry_limit: int = 2,
        json_fix_use_sys_util: bool = True,
        intent_classifier: str = "thinker",
    ) -> None:
        """Store the pipeline models and settings.

        Args:
            thinker: Model producing the free-text plan.
            json_compiler: Model converting the plan into directive JSON.
            intent_parser: Parser used for fallback directives; a default
                ``IntentParser()`` is created when omitted.
            prompts: Base prompts keyed by "thinker", "json_compiler", "sys_util".
            event_bus: Receives runtime events (see ``_emit_event``).
            tool_registry: Object exposing ``list_schemas()``; its schemas are
                embedded in the Thinker prompt.
            retry_limit: Extra Thinker attempts after the first one.
            debug: Stored but not read by this class.
            log_length: Maximum characters of model output written to the log;
                values <= 0 disable clipping.
            json_fix_retry_limit: Extra JSON Compiler attempts per Thinker plan.
            json_fix_use_sys_util: Whether invalid compiler output is sent to the
                sys_util model for repair.
            intent_classifier: "orchestrator" classifies intent with the
                orchestrator model (when one is set); anything else uses the Thinker.
        """
        self._thinker = thinker
        self._json_compiler = json_compiler
        self._intent_classifier = intent_classifier
        self._sys_util = None
        self._intent_parser = intent_parser or IntentParser()
        self._prompts = prompts or {}
        self._event_bus = event_bus
        self._tool_registry = tool_registry
        self._retry_limit = retry_limit
        self._debug = debug
        self._log_length = log_length
        self._json_fix_retry_limit = json_fix_retry_limit
        self._json_fix_use_sys_util = json_fix_use_sys_util
        self._orchestrator = None  # Set separately if needed for classification

    def set_event_bus(self, event_bus: EventBus) -> None:
        """Replace the event bus used by ``_emit_event``."""
        self._event_bus = event_bus

    def set_thinker(self, thinker: AsyncOrchestratorAdapter) -> None:
        """Replace the Thinker model."""
        self._thinker = thinker

    def set_json_compiler(self, json_compiler: AsyncOrchestratorAdapter) -> None:
        """Replace the JSON Compiler model."""
        self._json_compiler = json_compiler

    def set_sys_util(self, sys_util: AsyncOrchestratorAdapter) -> None:
        """Set the model used to repair invalid compiler JSON."""
        self._sys_util = sys_util

    def set_orchestrator(self, orchestrator: AsyncOrchestratorAdapter) -> None:
        """Set the model used for intent classification when configured."""
        self._orchestrator = orchestrator

    def set_tool_registry(self, tool_registry) -> None:
        """Replace the tool registry whose schemas feed the Thinker prompt."""
        self._tool_registry = tool_registry

    async def decide(
        self,
        state: dict[str, Any],
        context: dict[str, Any],
        task_id: str | None = None,
        session_id: str | None = None,
    ) -> ExecutionDirective:
        """Choose the next ExecutionDirective for a task.

        Resolution order:
        1. An explicit ``task_context["requested_tool"]`` short-circuits to a
           ``tool`` directive.
        2. If the Thinker or the JSON Compiler is missing, fall back to
           ``_fallback_directive``.
        3. Otherwise classify intent, ask the Thinker for a plan and compile it
           into JSON, retrying the Thinker up to ``retry_limit`` extra times.

        Args:
            state: Not read here; kept for interface compatibility.
            context: Reads ``task_context``, ``task_summary``, ``memory_context``
                and ``session_history``.
            task_id: Tags emitted events; events are skipped when falsy.
            session_id: Tags emitted events.

        Raises:
            RuntimeError: when no Thinker attempt yields a valid directive.
        """
        # ``or {}`` also covers an explicit ``task_context=None``.
        task_context = context.get("task_context") or {}
        requested_tool = task_context.get("requested_tool")
        task_summary = str(context.get("task_summary", ""))

        if requested_tool:
            self._emit_event(
                ORCHESTRATOR_RESULT,
                {"reason": "explicit_tool_request", "tool": requested_tool},
                task_id,
                session_id,
            )
            return ExecutionDirective(
                type="tool",
                payload={
                    "tool": requested_tool,
                    "args": task_context.get("tool_args", {}),
                },
                requires_permission=self._is_dangerous_tool(requested_tool),
                confidence=0.9,
                reason="Task context explicitly requested a tool execution.",
            )

        if self._thinker is None:
            return self._fallback_with_event(
                "thinker_unavailable", task_summary, task_id, session_id
            )
        if self._json_compiler is None:
            return self._fallback_with_event(
                "json_compiler_unavailable", task_summary, task_id, session_id
            )

        mode_hint = await self._classify_intent(task_summary)
        thinker_prompt = self._build_thinker_prompt(task_summary, context, mode_hint)
        # Bound before the loop so the exhaustion event below is safe even when
        # the loop body never runs (retry_limit < 0).
        last_thinker_error = ""

        for thinker_attempt in range(self._retry_limit + 1):
            if thinker_attempt > 0:
                self._emit_event(
                    ORCHESTRATOR_RETRY,
                    {"attempt": thinker_attempt, "prompt": thinker_prompt},
                    task_id,
                    session_id,
                )
                # Feedback accumulates: each retry appends to the previous prompt.
                thinker_prompt = self._add_thinker_feedback(
                    thinker_prompt, last_thinker_error, thinker_attempt
                )

            self._emit_event(
                THINKER_CALLED,
                {"attempt": thinker_attempt, "mode": mode_hint},
                task_id,
                session_id,
            )

            try:
                thinker_result = await self._thinker.generate(thinker_prompt)
            except Exception as e:
                logger.warning("Thinker generate failed: %s", e)
                last_thinker_error = str(e)
                continue

            logger.info(
                "Thinker result (attempt %d): %s",
                thinker_attempt + 1,
                self._clip(thinker_result),
            )
            self._emit_event(
                THINKER_RESULT,
                {"result": thinker_result, "attempt": thinker_attempt},
                task_id,
                session_id,
            )

            # In conversation mode the Thinker may only answer; if it drafted an
            # execution plan instead, salvage the answer text and respond with it.
            if mode_hint == "conversation" and self._violates_conversation_mode(thinker_result):
                respond_text = self._extract_conversation_response(thinker_result)
                self._emit_event(
                    ORCHESTRATOR_RESULT,
                    {
                        "directive": {"type": "respond", "payload": {"text": respond_text}},
                        "mode_violation": True,
                    },
                    task_id,
                    session_id,
                )
                return ExecutionDirective(
                    type="respond",
                    payload={"text": respond_text},
                    requires_permission=False,
                    reason="Mode violation: conversation only",
                )

            directive = await self._compile_plan(
                thinker_result, mode_hint, thinker_attempt, task_id, session_id
            )
            if directive is not None:
                return directive

            last_thinker_error = (
                f"JSON Compiler failed after {self._json_fix_retry_limit + 1} attempts"
            )

        self._emit_event(
            ORCHESTRATOR_UNAVAILABLE,
            {"reason": "retry_exhausted", "last_error": last_thinker_error},
            task_id,
            session_id,
        )
        raise RuntimeError(
            f"Thinker/Compiler pipeline failed after {self._retry_limit + 1} attempts"
        )

    async def _compile_plan(
        self,
        thinker_result: str,
        mode_hint: str,
        thinker_attempt: int,
        task_id: str | None,
        session_id: str | None,
    ) -> ExecutionDirective | None:
        """Compile one Thinker plan into a directive, or return None if every attempt fails.

        Each of the ``json_fix_retry_limit + 1`` attempts calls the JSON Compiler.
        When ``json_fix_use_sys_util`` is enabled and the output does not
        validate, the attempt also tries a sys_util repair.
        """
        json_compiler_prompt = self._build_json_compiler_prompt(thinker_result)

        for compiler_attempt in range(self._json_fix_retry_limit + 1):
            self._emit_event(
                JSON_COMPILER_CALLED,
                {"attempt": compiler_attempt, "plan": thinker_result},
                task_id,
                session_id,
            )

            try:
                compiler_result = await self._json_compiler.generate(json_compiler_prompt)
            except Exception as e:
                logger.warning("JSON Compiler generate failed: %s", e)
                compiler_result = None

            if not compiler_result:
                continue

            logger.info(
                "JSON Compiler result (attempt %d): %s",
                compiler_attempt + 1,
                self._clip(compiler_result),
            )
            self._emit_event(
                JSON_COMPILER_RESULT,
                {"result": compiler_result, "attempt": compiler_attempt},
                task_id,
                session_id,
            )

            directive = self._validate_directive(compiler_result, mode_hint)
            if directive is not None:
                directive = self._guard_rail_check(directive)
                self._emit_event(
                    ORCHESTRATOR_RESULT,
                    {
                        "directive": directive.model_dump(mode="json"),
                        "thinker_attempt": thinker_attempt,
                        "compiler_attempt": compiler_attempt,
                    },
                    task_id,
                    session_id,
                )
                return directive

            if not self._json_fix_use_sys_util:
                continue

            logger.warning(
                "JSON Compiler validation failed, attempting fix (attempt %d)",
                compiler_attempt + 1,
            )
            fix_result = await self._fix_invalid_json(
                compiler_result, compiler_attempt, task_id, session_id
            )
            fixed_directive = self._validate_directive(fix_result, mode_hint) if fix_result else None
            if fixed_directive is not None:
                fixed_directive = self._guard_rail_check(fixed_directive)
                self._emit_event(
                    ORCHESTRATOR_RESULT,
                    {"directive": fixed_directive.model_dump(mode="json"), "fixed": True},
                    task_id,
                    session_id,
                )
                return fixed_directive

        return None

    def _fallback_with_event(
        self,
        reason: str,
        task_summary: str,
        task_id: str | None,
        session_id: str | None,
    ) -> ExecutionDirective:
        """Build the fallback directive, then emit ORCHESTRATOR_FALLBACK_USED tagged with *reason*."""
        fallback = self._fallback_directive(task_summary)
        self._emit_event(
            ORCHESTRATOR_FALLBACK_USED,
            {"reason": reason, "directive": fallback.model_dump(mode="json")},
            task_id,
            session_id,
        )
        return fallback

    def _fallback_directive(self, task_summary: str) -> ExecutionDirective:
        """Return the IntentParser result when truthy, else a low-confidence canned reply."""
        parsed = self._intent_parser.parse(task_summary)
        if parsed:
            return parsed

        return ExecutionDirective(
            type="respond",
            payload={"text": f"Runtime accepted task: {task_summary}"},
            requires_permission=False,
            confidence=0.4,
            reason="Fallback response because local orchestration models are not loaded.",
        )

    def _is_simple_response(self, thinker_result: str) -> bool:
        """True when the Thinker result starts with an answer marker or contains "не нужно"."""
        result_lower = thinker_result.lower().strip()
        return result_lower.startswith(self._RESPONSE_PREFIXES) or "не нужно" in result_lower

    def _violates_conversation_mode(self, thinker_result: str) -> bool:
        """True when a non-simple Thinker result mentions execution-plan markers."""
        if self._is_simple_response(thinker_result):
            return False
        lowered = thinker_result.lower()
        return any(marker in lowered for marker in self._EXECUTION_MARKERS)

    @classmethod
    def _split_response_prefix(cls, text: str) -> str | None:
        """Return *text* without a leading answer marker, or None if it has none.

        The comparison is case-insensitive; only the marker itself is removed.
        """
        for prefix in cls._RESPONSE_PREFIXES:
            if text[: len(prefix)].lower() == prefix:
                return text[len(prefix):]
        return None

    def _extract_conversation_response(self, thinker_result: str) -> str:
        """Extract text response from thinker result for conversation mode.

        Captures lines from the first answer-marker line ("ответ:" /
        "response:", any case) until a line starting with "план" / "step".
        Marker prefixes are removed and blank lines are skipped. Without a
        marker, or if the captured text is empty, returns the first three
        '.'-separated fragments.
        """
        response_lines: list[str] = []
        capture = False

        for line in thinker_result.split("\n"):
            stripped = line.strip()
            remainder = self._split_response_prefix(stripped)
            if remainder is not None:
                capture = True
                response_lines.append(remainder)
            elif capture and stripped:
                # A new plan section or step ends the answer block.
                if stripped.lower().startswith(("план", "step")):
                    break
                response_lines.append(line)

        text = "\n".join(response_lines).strip()
        if text:
            return text

        # Fallback: return first few sentences
        sentences = thinker_result.split(".")[:3]
        return ". ".join(sentences).strip()

    def _build_thinker_prompt(
        self, task_summary: str, context: dict[str, Any], mode_hint: str
    ) -> str:
        """Assemble the Thinker prompt.

        The prompt contains the base prompt, the task, the mode hint, up to five
        memory items, up to three session-history items and the tool schemas
        serialised as JSON.
        """
        base_prompt = self._prompts.get("thinker", "")
        memory_context = context.get("memory_context", [])

        tools_json = "[]"
        if self._tool_registry:
            schemas = self._tool_registry.list_schemas()
            tools_json = json.dumps(schemas, ensure_ascii=False, indent=2)

        prompt_lines = [
            base_prompt,
            "",
            f"Task: {task_summary}",
            f"Mode hint: {mode_hint}",
        ]

        if memory_context:
            memory_text = "\n".join(f"- {m.get('text', '')}" for m in memory_context[:5])
            prompt_lines.append(f"\nRelevant memory:\n{memory_text}")

        session_history = context.get("session_history", [])
        if session_history:
            history_text = "\n".join(f"- {h.get('text', '')}" for h in session_history[:3])
            prompt_lines.append(f"\nPrevious requests in this session:\n{history_text}")

        prompt_lines.extend(["", "AVAILABLE TOOLS (JSON):", tools_json, ""])

        return "\n".join(prompt_lines)

    def _build_json_compiler_prompt(self, thinker_result: str) -> str:
        """Wrap the Thinker plan with the JSON Compiler base prompt."""
        base_prompt = self._prompts.get("json_compiler", "")
        return "\n".join([base_prompt, "", "Thinker's plan:", thinker_result, ""])

    def _determine_mode_from_context(self, context: dict[str, Any]) -> str:
        """Legacy keyword-based mode detection, kept for compatibility.

        Returns "execution" if the task summary contains an action keyword,
        otherwise "conversation".
        """
        task_summary = str(context.get("task_summary", "")).lower()
        keywords = ["запусти", "выполни", "создай", "напиши", "удали", "run", "execute", "create"]
        if any(kw in task_summary for kw in keywords):
            return "execution"
        return "conversation"

    @classmethod
    def _parse_intent_label(cls, raw: str) -> str | None:
        """Return the allowed intent label at the start of *raw*, or None.

        Only the first word is considered, because the LLM often adds an
        explanation. Surrounding punctuation/markdown such as "Execution." or
        "**conversation**" is tolerated.
        """
        words = raw.strip().lower().split()
        if not words:
            return None
        label = words[0].strip(".,:;!?*\"'`()[]")
        return label if label in cls._INTENT_LABELS else None

    async def _classify_intent(self, task_summary: str) -> str:
        """LLM-based intent classification.

        Returns "execution", "conversation" or "clarification_needed". Defaults
        to "conversation" when no model is available, the model call fails, or
        the answer is not a recognised label.
        """
        if self._intent_classifier == "orchestrator" and self._orchestrator:
            classifier_model = self._orchestrator
        else:
            classifier_model = self._thinker

        if not classifier_model:
            logger.warning("No classifier model available, using default")
            return "conversation"

        classification_prompt = (
            f'Классифицируй запрос пользователя: "{task_summary}"\n'
            "\n"
            "Правила:\n"
            "- execution: пользователь ХОЧЕТ выполнить действие (проверить, запустить, "
            "создать, удалить, найти, прочитать, записать)\n"
            "- conversation: пользователь просто отвечает, задаёт вопрос или хочет информацию\n"
            "- clarification_needed: непонятно что делать\n"
            "\n"
            "Ответь ОДНИМ словом: execution / conversation / clarification_needed"
        )

        try:
            raw = await classifier_model.generate(classification_prompt)
            label = self._parse_intent_label(raw)
        except Exception as e:
            logger.warning("Intent classification failed: %s, defaulting to conversation", e)
            return "conversation"

        if label is None:
            logger.warning(
                "Invalid classification result: %s, defaulting to conversation",
                raw.strip().lower(),
            )
            return "conversation"

        logger.info("Intent classified: %s for task: %s", label, task_summary)
        return label

    def _validate_directive(self, output: str, mode_hint: str) -> ExecutionDirective | None:
        """Parse the outermost ``{...}`` span of *output* into an ExecutionDirective.

        Returns None when there is no JSON object, the JSON is invalid, "type"
        is missing, a step/plan payload is not an object, or the model rejects
        the values. ``mode_hint`` is not used here; it is kept for interface
        compatibility.
        """
        if not output:
            return None

        try:
            json_start = output.find("{")
            json_end = output.rfind("}") + 1
            if json_start < 0 or json_end <= 0:
                return None

            data = json.loads(output[json_start:json_end])
            if not isinstance(data, dict) or "type" not in data:
                return None

            msg_type = data.get("type", "")
            payload = data.get("payload", {})

            # step/plan payloads are read as dicts below; reject other shapes
            # instead of raising AttributeError.
            if msg_type in ("step", "plan") and not isinstance(payload, dict):
                return None

            if msg_type == "step" and "tool" in payload:
                payload = {"tool": payload.get("tool", ""), "args": payload.get("args", {})}

            if msg_type == "plan":
                payload = {"steps": payload.get("steps", [])}

            return ExecutionDirective(
                type=msg_type,
                payload=payload,
                confidence=data.get("confidence", 0.9),
                reason=data.get("reason", ""),
            )
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.warning("Directive JSON validation failed: %s", e)
            return None

    @classmethod
    def _is_dangerous_tool(cls, tool_name: Any) -> bool:
        """True when *tool_name* is a string naming a permission-gated tool."""
        return isinstance(tool_name, str) and tool_name in cls._DANGEROUS_TOOLS

    @classmethod
    def _payload_uses_dangerous_tool(cls, payload: Any) -> bool:
        """True when a directive payload (or any plan step in it) names a dangerous tool."""
        if not isinstance(payload, dict):
            return False
        if cls._is_dangerous_tool(payload.get("tool", "")):
            return True

        steps = payload.get("steps")
        if not isinstance(steps, list):
            return False
        # The plan-step schema is not defined in this file; accept both
        # {"tool": ...} and {"payload": {"tool": ...}} shapes.
        for step in steps:
            if not isinstance(step, dict):
                continue
            nested = step.get("payload")
            if cls._is_dangerous_tool(step.get("tool")) or (
                isinstance(nested, dict) and cls._is_dangerous_tool(nested.get("tool"))
            ):
                return True
        return False

    def _guard_rail_check(self, directive: ExecutionDirective) -> ExecutionDirective:
        """Force ``requires_permission=True`` when the directive uses a dangerous tool."""
        if directive.requires_permission or not self._payload_uses_dangerous_tool(directive.payload):
            return directive
        # model_copy keeps every other field of the directive intact.
        return directive.model_copy(update={"requires_permission": True})

    def _add_thinker_feedback(self, prompt: str, error: str, attempt: int) -> str:
        """Append a failure note for *attempt* to the Thinker prompt."""
        feedback = f"\n[ATTEMPT {attempt + 1} FAILED: {error}]\n"
        feedback += "Provide a valid semantic plan.\n"
        return prompt + feedback

    def _emit_event(
        self,
        event_type: str,
        payload: dict[str, Any],
        task_id: str | None,
        session_id: str | None,
    ) -> None:
        """Publish a RuntimeEvent when an event bus is set and *task_id* is truthy."""
        if self._event_bus and task_id:
            from app.core.contracts import RuntimeEvent

            event = RuntimeEvent(
                task_id=task_id,
                session_id=session_id or "unknown",
                sequence=self._event_bus.next_sequence(task_id),
                type=event_type,
                payload=payload,
            )
            self._event_bus.publish(event)

    def _clip(self, text: Any) -> str:
        """Truncate *text* to ``log_length`` characters for logging (<= 0 means no limit)."""
        text = str(text)
        limit = self._log_length
        if limit > 0 and len(text) > limit:
            return text[:limit] + "..."
        return text

    def _build_json_fix_prompt(self, invalid_json: str, error_msg: str) -> str:
        """Build the sys_util repair prompt containing the broken JSON and the parser error.

        The system prompt comes from the first non-empty source among
        ``prompts["sys_util"]``, ``SYS_UTIL_PROMPT`` and the built-in default.
        """
        sys_util_prompt = (
            (self._prompts or {}).get("sys_util")
            or self.SYS_UTIL_PROMPT
            or self._DEFAULT_SYS_UTIL_PROMPT
        )
        sections = [sys_util_prompt, "", "Invalid JSON:", invalid_json]
        if error_msg:
            sections += ["", f"Parser error: {error_msg}"]
        sections += ["", "Fixed JSON:"]
        return "\n".join(sections)

    async def _fix_invalid_json(
        self,
        invalid_result: str,
        attempt: int,
        task_id: str | None,
        session_id: str | None,
    ) -> str | None:
        """Try to fix invalid JSON using sys_util model.

        Sends the outermost ``{...}`` span of *invalid_result*, together with its
        parse error, to the sys_util model. Returns the outermost ``{...}`` span of
        the model's reply, or None if there is no sys_util model, no braces are
        found, or the call fails. *task_id* and *session_id* are accepted for
        interface symmetry and are not used.
        """
        if not self._sys_util:
            return None

        first_brace = invalid_result.find("{")
        last_brace = invalid_result.rfind("}")
        if first_brace < 0 or last_brace <= first_brace:
            return None

        truncated_json = invalid_result[first_brace:last_brace + 1]

        error_msg = ""
        try:
            json.loads(truncated_json)
        except json.JSONDecodeError as e:
            error_msg = str(e)

        fix_prompt = self._build_json_fix_prompt(truncated_json, error_msg)

        try:
            logger.info("JSON fix using sys_util model (attempt %d)", attempt + 1)
            fixed_result = await self._sys_util.generate(fix_prompt)

            fixed_first = fixed_result.find("{")
            fixed_last = fixed_result.rfind("}")
            if fixed_first >= 0 and fixed_last > fixed_first:
                return fixed_result[fixed_first:fixed_last + 1]
            return None
        except Exception as e:
            logger.warning("JSON fix failed: %s", e)
            return None
|