172 lines
5.3 KiB
Python
172 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from app.core.contracts import TaskCheckpoint, UserTask
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default per-section context budgets, used when the config passed to
# ``ContextBuilder`` has no "context_budgets" entry.
# NOTE(review): values are presumably token counts (the config key that sits
# next to them is "max_context_tokens") -- confirm against the consumers.
DEFAULT_BUDGETS: dict[str, int] = {
    "system": 512,
    "task": 512,
    "memory": 2048,
    "execution": 2048,
    "tools": 1024,
    "safety": 512,
}
|
|
|
|
|
|
class ContextBuilder:
    """Assemble the context dictionary for a task.

    The builder combines the task itself, entries retrieved from an optional
    memory interface, previous entries from the same session, an optional
    checkpoint, and the names/descriptions of registered tools.

    The memory interface, when given, must provide
    ``search(query, top_k=..., session_id=...)`` yielding ``(entry, score)``
    pairs and ``get_by_session(session_id, limit=...)`` yielding entries.
    Entries expose ``id``, ``text``, ``kind``, ``source`` and ``weight``.
    The tool registry, when given, must provide ``list_names()`` and
    ``get(name)``.
    """

    # Entry kinds that are kept when collecting session history.
    _HISTORY_KINDS = ("summary", "tool_result")

    # Flat per-entry estimate used to cap the *number* of memory entries.
    _ESTIMATED_TOKENS_PER_ENTRY = 50

    def __init__(
        self,
        memory_interface: Any = None,
        tool_registry: Any = None,
        config: dict[str, Any] | None = None,
    ) -> None:
        """Store collaborators and read settings from ``config``.

        Args:
            memory_interface: Optional memory backend (see class docstring).
            tool_registry: Optional registry of available tools.
            config: Optional settings. Recognised keys are
                ``max_context_tokens`` (default 8192), ``context_budgets``
                (default ``DEFAULT_BUDGETS``) and
                ``reserve_for_generation_pct`` (default 25). A key that is
                present but set to ``None`` is treated as missing.
        """
        self._memory = memory_interface
        self._tool_registry = tool_registry
        self._config = config or {}
        self._max_tokens = self._config_value("max_context_tokens", 8192)
        self._budgets = self._config_value("context_budgets", DEFAULT_BUDGETS)
        self._reserve_pct = self._config_value("reserve_for_generation_pct", 25)

    def _config_value(self, key: str, default: Any) -> Any:
        """Return ``self._config[key]``, or ``default`` if missing or ``None``."""
        value = self._config.get(key)
        # An explicit None (e.g. an empty config entry) would otherwise crash
        # later in dict()/arithmetic, so it falls back to the default too.
        return default if value is None else value

    def build(
        self,
        task: UserTask,
        checkpoint: TaskCheckpoint | None = None,
        query: str | None = None,
    ) -> dict[str, Any]:
        """Build the context dictionary for ``task``.

        Args:
            task: The task; its ``input``, ``session_id`` and ``context``
                attributes are read.
            checkpoint: Optional checkpoint; its ``model_dump()`` result
                becomes ``execution_context``.
            query: Optional memory search query. Falls back to ``task.input``
                when empty or ``None``.

        Returns:
            A dict with the keys ``system_prompt``, ``task_summary``,
            ``task_context``, ``memory_context``, ``session_history``,
            ``execution_context``, ``tool_context``, ``safety_context`` and
            ``constraints``.
        """
        task_summary = task.input
        session_id = task.session_id

        # _retrieve_memory already returns [] when no memory is configured.
        memory_context = self._retrieve_memory(
            query or task_summary, session_id=session_id
        )

        budgets = self._calculate_budgets()
        # Reported under "constraints"; not otherwise applied in this method.
        reserved = self._reserve_for_generation()

        task_budget = budgets.get("task", 512)
        memory_budget = budgets.get("memory", 2048)
        truncated_memory = self._truncate_memory(memory_context, memory_budget)

        # Previous entries from the same session, for follow-up context.
        session_history = self._get_session_history(session_id)

        return {
            "system_prompt": "",
            # NOTE: the task budget is applied as a character slice here.
            "task_summary": task_summary[:task_budget],
            "task_context": task.context,
            "memory_context": truncated_memory,
            "session_history": session_history,
            "execution_context": checkpoint.model_dump() if checkpoint else {},
            "tool_context": self._get_tool_context(),
            "safety_context": {},
            "constraints": {
                "budgets": budgets,
                "reserved_for_generation": reserved,
                "original_memory_count": len(memory_context),
                "truncated_memory_count": len(truncated_memory),
            },
        }

    def _get_tool_context(self) -> list[dict[str, Any]]:
        """Expose available tools to orchestrator.

        Returns:
            One ``{"name", "description"}`` dict per registered tool name; the
            description is ``""`` when the tool has no ``description``
            attribute. Empty when no registry is configured.
        """
        if not self._tool_registry:
            return []

        return [
            {
                "name": name,
                "description": getattr(
                    self._tool_registry.get(name), "description", ""
                ),
            }
            for name in self._tool_registry.list_names()
        ]

    def _calculate_budgets(self) -> dict[str, int]:
        """Return a copy of the configured budgets."""
        return dict(self._budgets)

    def _reserve_for_generation(self) -> int:
        """Return ``reserve_pct`` percent of ``max_tokens``, truncated to int."""
        return int(self._max_tokens * self._reserve_pct / 100)

    @staticmethod
    def _entry_to_dict(entry: Any) -> dict[str, Any]:
        """Copy the fields of a memory entry into a plain dict."""
        return {
            "id": entry.id,
            "text": entry.text,
            "kind": entry.kind,
            "source": entry.source,
            "weight": entry.weight,
        }

    def _retrieve_memory(
        self,
        query: str,
        session_id: str | None = None,
        top_k: int = 5,
    ) -> list[dict[str, Any]]:
        """Search memory and return the matches as plain dicts.

        Args:
            query: Search text passed to the memory interface.
            session_id: Optional session identifier passed to the search.
            top_k: Number of results requested from the memory interface.

        Returns:
            One dict per ``(entry, score)`` pair with the entry fields plus
            ``score``. Empty when no memory is configured or the lookup fails.
        """
        if not self._memory:
            return []

        try:
            # The conversion stays inside the try: iterating the results or
            # reading entry attributes may raise as well.
            results = self._memory.search(query, top_k=top_k, session_id=session_id)
            return [
                {**self._entry_to_dict(entry), "score": score}
                for entry, score in results
            ]
        except Exception as exc:
            # Deliberately best-effort: a memory failure yields no memory
            # context instead of aborting the build.
            logger.warning("Memory retrieval failed: %s", exc)
            return []

    def _get_session_history(self, session_id: str | None = None) -> list[dict[str, Any]]:
        """Get previous task summaries from the same session for context.

        Requests up to five entries for ``session_id`` and keeps those whose
        kind is ``"summary"`` or ``"tool_result"``.

        Returns:
            The kept entries as plain dicts. Empty when there is no memory, no
            session id, or the lookup fails.
        """
        if not self._memory or not session_id:
            return []

        try:
            entries = self._memory.get_by_session(session_id, limit=5)
            return [
                self._entry_to_dict(entry)
                for entry in entries
                if entry.kind in self._HISTORY_KINDS
            ]
        except Exception as exc:
            # Best-effort, same as memory retrieval.
            logger.warning("Session history retrieval failed: %s", exc)
            return []

    def _truncate_memory(
        self,
        memory_context: list[dict[str, Any]],
        budget: int,
    ) -> list[dict[str, Any]]:
        """Keep the leading memory entries that fit into ``budget``.

        Two limits apply, and the stricter one wins:

        * at most ``budget // 50`` entries (flat per-entry estimate), and
        * the running ``estimate_tokens`` total of the kept entries' ``text``
          must not exceed ``budget``.

        At least one entry is always kept when the input is non-empty, even
        if it alone exceeds the budget.

        Args:
            memory_context: Entry dicts, in priority order.
            budget: Budget for the memory section.

        Returns:
            A new list holding a prefix of ``memory_context``.
        """
        if not memory_context:
            return []

        max_entries = max(budget // self._ESTIMATED_TOKENS_PER_ENTRY, 1)

        kept: list[dict[str, Any]] = []
        used_tokens = 0
        for entry in memory_context[:max_entries]:
            text = entry.get("text")
            # Entries without string text count as zero tokens.
            cost = self.estimate_tokens(text) if isinstance(text, str) else 0
            if kept and used_tokens + cost > budget:
                break
            kept.append(entry)
            used_tokens += cost
        return kept

    def estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Uses four tokens per three whitespace-separated words (integer
        division); empty text yields 0.
        """
        if not text:
            return 0
        return len(text.split()) * 4 // 3