376 lines
14 KiB
Python
376 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from duck_core.tasks.state import TaskState
|
|
|
|
logger = logging.getLogger(name=__name__)

# Rough heuristic shared by the token estimators below: about four
# characters of text per model token.
_CHARS_PER_TOKEN = 4
|
|
|
|
|
|
@dataclass()
class RecallDecision:
    """Result of filtering memory records for relevance to a query.

    Attributes:
        records: Memory records kept for the context.
        sufficient_to_answer: True when the recall step judged the kept
            records enough to answer without local tools/actions.
        reasoning: Free-text explanation returned by the recall step.
    """

    # Kept memory records (same dict shape the caller passed in).
    records: list[dict[str, str]]
    # Conservative default: assume more work is needed.
    sufficient_to_answer: bool = False
    # Empty when no explanation is available (e.g. on fallback paths).
    reasoning: str = ""
|
|
|
|
|
|
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*, derived from its length.

    Uses ``_CHARS_PER_TOKEN`` characters per token and never returns less
    than 1, so even an empty string is charged one token.
    """
    approx = len(text) // _CHARS_PER_TOKEN
    if approx < 1:
        return 1
    return approx
|
|
|
|
|
|
def estimate_messages_tokens(messages: list[dict[str, str]]) -> int:
    """Estimate the total token count of a list of chat messages.

    Each message costs ``estimate_tokens`` of its content plus 4 tokens of
    role/formatting overhead. A missing or ``None`` content is treated as
    empty text rather than raising ``TypeError``.
    """
    return sum(
        # `or ""` also covers a present-but-None content value.
        estimate_tokens(msg.get("content") or "") + 4
        for msg in messages
    )
|
|
|
|
|
|
class ContextBuilder:
    """Builds context messages with token budget awareness.

    Messages are assembled in this order. Each stage spends from one shared
    budget of ``max_input_tokens``, estimated with the ~4 chars per token
    heuristic rather than a real tokenizer:

    1. Room for the current user message is reserved first. The message
       itself is always appended last.
    2. One system message holding relevant memory (capped at
       ``max_memory_tokens``) and the active skill summary, if any.
    3. Recent tool observations, capped at 40% of the budget left after the
       system message.
    4. Conversation history. It is included verbatim when it fits the
       remaining budget; otherwise it is condensed into a
       "Conversation summary" system message. ``build_basic_messages`` uses
       deterministic truncation. ``build_async_messages`` uses LLM
       summarization when a model client is configured.

    NOTE(review): ``max_history_tokens`` is stored but not consulted by this
    class; confirm whether history should also be capped by it.
    """

    # Share of the post-system-message budget tool observations may use.
    _OBSERVATION_BUDGET_FRACTION = 0.4
    # With this many tokens or fewer left, history is dropped, not summarized.
    _MIN_SUMMARY_BUDGET_TOKENS = 100
    # Characters kept per message by the deterministic history fallback.
    _SUMMARY_SNIPPET_CHARS = 100
    # Characters kept from each tool output / error text.
    _OBSERVATION_EXCERPT_CHARS = 200
    # Per-message role/formatting overhead, matching estimate_messages_tokens.
    _MESSAGE_OVERHEAD_TOKENS = 4

    def __init__(
        self,
        max_input_tokens: int = 49152,
        max_memory_tokens: int = 8000,
        max_history_tokens: int = 12000,
        summary_role: str = "summary",
        recall_role: str = "recall",
        model_client: Any | None = None,
    ):
        """Configure token budgets and model roles.

        Args:
            max_input_tokens: Total estimated token budget for built messages.
            max_memory_tokens: Cap on the formatted memory section.
            max_history_tokens: Stored on the instance; not enforced here.
            summary_role: Role passed to ``model_client.chat`` for history
                summarization.
            recall_role: Role passed to ``model_client.chat`` for memory
                recall.
            model_client: Optional async client whose ``chat(role, messages,
                ...)`` returns an object with ``.content``. When ``None``,
                LLM recall and LLM summarization are skipped.
        """
        self.max_input_tokens = max_input_tokens
        self.max_memory_tokens = max_memory_tokens
        self.max_history_tokens = max_history_tokens
        self.summary_role = summary_role
        self.recall_role = recall_role
        self._model_client = model_client

    # ------------------------------------------------------------------
    # Memory recall
    # ------------------------------------------------------------------

    async def recall_relevant_memory(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> list[dict[str, str]]:
        """Use the recall-role LLM to filter memory records by relevance.

        Returns only the memories relevant to *query*. Falls back to
        returning all records if the LLM is unavailable or the call fails.
        """
        decision = await self.recall_relevant_memory_decision(query, memory_records)
        return decision.records

    async def recall_relevant_memory_decision(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> RecallDecision:
        """Filter *memory_records* for *query* and report whether they suffice.

        This is best-effort. It returns every record with
        ``sufficient_to_answer=False`` in three cases: there are no records,
        no model client is configured, or the LLM call / JSON parsing raises.
        """
        if not memory_records or self._model_client is None:
            return RecallDecision(records=memory_records, sufficient_to_answer=False)

        try:
            return await self._llm_recall(query, memory_records)
        except Exception as exc:  # deliberate fallback: recall is optional
            logger.warning("Recall failed, using all memories: %s", exc)
            return RecallDecision(records=memory_records, sufficient_to_answer=False)

    @staticmethod
    def _memory_id(record: dict[str, Any], index: int) -> str:
        """Return the string ID a record is shown under in the recall prompt."""
        # Stringify so non-string memory_id values match the string IDs the
        # model echoes back.
        return str(record.get("memory_id", index))

    async def _llm_recall(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> RecallDecision:
        """Call the recall-role LLM to identify relevant memories.

        Each record is listed under its ``memory_id`` (or list index when
        absent). The model's ``relevant_ids`` are matched against those same
        string IDs.

        Raises:
            Whatever ``model_client.chat`` or ``json.loads`` raise; the public
            wrapper catches these and falls back to all records.
        """
        ids = [self._memory_id(m, i) for i, m in enumerate(memory_records)]
        memories_text = "\n".join(
            f"[{mid}] {m.get('text', '')}" for mid, m in zip(ids, memory_records)
        )
        response = await self._model_client.chat(
            self.recall_role,
            [{
                "role": "user",
                "content": (
                    f"User query: {query}\n\n"
                    f"Available memories:\n{memories_text}"
                ),
            }],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "recall_result",
                    "schema": {
                        "type": "object",
                        # Strict structured outputs require every property to
                        # be listed in "required"; leaving out
                        # sufficient_to_answer makes the schema invalid.
                        "required": ["relevant_ids", "sufficient_to_answer", "reasoning"],
                        "additionalProperties": False,
                        "properties": {
                            "relevant_ids": {
                                "type": "array",
                                "items": {"type": "string"},
                            },
                            "sufficient_to_answer": {
                                "type": "boolean",
                                "description": "True when selected memories are enough to answer without local tools/actions.",
                            },
                            "reasoning": {"type": "string"},
                        },
                    },
                    "strict": True,
                },
            },
        )
        data = json.loads(response.content)
        reasoning = str(data.get("reasoning", ""))
        relevant_ids = {str(rid) for rid in data.get("relevant_ids") or []}
        if not relevant_ids:
            return RecallDecision(records=[], sufficient_to_answer=False, reasoning=reasoning)

        records = [m for mid, m in zip(ids, memory_records) if mid in relevant_ids]
        return RecallDecision(
            records=records,
            # Never claim sufficiency when none of the returned IDs matched.
            sufficient_to_answer=bool(data.get("sufficient_to_answer", False)) and bool(records),
            reasoning=reasoning,
        )

    # ------------------------------------------------------------------
    # Message assembly
    # ------------------------------------------------------------------

    def build_basic_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Build context messages respecting the token budget.

        History that does not fit is condensed by deterministic truncation.

        Args:
            task: Current task state; ``task.user_message`` is sent last.
            history_messages: Previous conversation messages.
            memory_records: Relevant memory records.
            tool_observations: Recent tool call results.
            skill_summary: Selected skill description.

        Returns:
            The ordered list of ``{"role", "content"}`` messages.
        """
        messages, budget_remaining = self._build_prefix(
            task, memory_records, tool_observations, skill_summary
        )

        if history_messages:
            if estimate_messages_tokens(history_messages) <= budget_remaining:
                messages.extend(history_messages)
            elif budget_remaining > self._MIN_SUMMARY_BUDGET_TOKENS:
                summary = self._summarize_history(history_messages, budget_remaining)
                self._append_summary(messages, summary)
            # else: no budget for history at all

        messages.append({"role": "user", "content": task.user_message})
        return messages

    async def build_async_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Async variant of ``build_basic_messages`` that can use LLM summarization.

        Budgeting is identical. The difference is how history that does not
        fit is handled: it is summarized by the summary-role model when a
        model client is set. Deterministic truncation is used when no client
        is set or the summary is unavailable.
        """
        messages, budget_remaining = self._build_prefix(
            task, memory_records, tool_observations, skill_summary
        )

        if history_messages:
            if estimate_messages_tokens(history_messages) <= budget_remaining:
                messages.extend(history_messages)
            elif budget_remaining > self._MIN_SUMMARY_BUDGET_TOKENS:
                summary = await self._summarize_history_async(
                    history_messages, budget_remaining
                )
                self._append_summary(messages, summary)

        messages.append({"role": "user", "content": task.user_message})
        return messages

    def _build_prefix(
        self,
        task: TaskState,
        memory_records: list[dict[str, str]] | None,
        tool_observations: list[dict[str, Any]] | None,
        skill_summary: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the system and observation messages shared by both builders.

        Returns the messages so far and the estimated budget left for
        history. Room for ``task.user_message`` is already deducted from
        that budget. The budget may be negative when the fixed parts
        overflow it.
        """
        messages: list[dict[str, str]] = []
        # The user message is always sent, so reserve its cost up front.
        budget_remaining = self.max_input_tokens - self._message_cost(task.user_message)

        system_parts: list[str] = []
        if memory_records:
            memory_text = self._format_memory(memory_records)
            # Records without text format to ""; don't emit an empty section.
            if memory_text:
                system_parts.append(self._truncate_text(memory_text, self.max_memory_tokens))
        if skill_summary:
            system_parts.append(f"Active skill:\n{skill_summary}")
        if system_parts:
            system_content = "\n\n".join(system_parts)
            messages.append({"role": "system", "content": system_content})
            budget_remaining -= self._message_cost(system_content)

        if tool_observations:
            obs_text = "Tool observations:\n" + self._format_observations(tool_observations)
            # Don't let observations consume more than a fixed share of what's left.
            obs_cap = int(budget_remaining * self._OBSERVATION_BUDGET_FRACTION)
            obs_text = self._truncate_text(obs_text, obs_cap)
            messages.append({"role": "user", "content": obs_text})
            budget_remaining -= self._message_cost(obs_text)

        return messages, budget_remaining

    @classmethod
    def _message_cost(cls, content: str | None) -> int:
        """Estimated tokens for one message: content plus per-message overhead."""
        return estimate_tokens(content or "") + cls._MESSAGE_OVERHEAD_TOKENS

    @staticmethod
    def _append_summary(messages: list[dict[str, str]], summary: str | None) -> None:
        """Append a conversation-summary system message when *summary* is non-empty."""
        if summary:
            messages.append({
                "role": "system",
                "content": f"Conversation summary:\n{summary}",
            })

    # ------------------------------------------------------------------
    # Formatting helpers
    # ------------------------------------------------------------------

    def _format_memory(self, records: list[dict[str, str]]) -> str:
        """Render records with text as a bulleted "Relevant memory" section, or ""."""
        lines = [
            f"- {record.get('scope', 'memory')}: {record.get('text', '')}"
            for record in records
            if record.get("text")
        ]
        return "Relevant memory:\n" + "\n".join(lines) if lines else ""

    def _format_observations(self, observations: list[dict[str, Any]]) -> str:
        """Render tool results as one bullet per call, with short output/error excerpts."""
        limit = self._OBSERVATION_EXCERPT_CHARS
        parts = []
        for obs in observations:
            tool = obs.get("tool", "unknown")
            # A missing or None result is reported as an error with no detail.
            result = obs.get("result") or {}
            status = "ok" if result.get("ok", False) else "error"
            output = result.get("output", "")
            error = result.get("error", "")
            part = f"- {tool} ({status})"
            # str() so structured (non-string) outputs render instead of crashing.
            if output:
                part += f"\n output: {str(output)[:limit]}"
            if error:
                part += f"\n error: {str(error)[:limit]}"
            parts.append(part)
        return "\n".join(parts)

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate *text* to about *max_tokens* tokens, appending a marker when cut.

        A non-positive *max_tokens* keeps no characters, leaving only the
        marker. The limit is clamped at zero so a negative budget can never
        become a negative slice index, which would keep nearly the whole text.
        The marker itself is not counted against *max_tokens*.
        """
        max_chars = max(max_tokens, 0) * _CHARS_PER_TOKEN
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + "\n... (truncated)"

    # ------------------------------------------------------------------
    # History summarization
    # ------------------------------------------------------------------

    def _summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize conversation history to fit budget.

        Synchronous callers use deterministic truncation. Runtime code should
        call build_async_messages() when LLM summarization is desired.

        The most recent messages are kept, each as a ``role: snippet`` line
        of at most ``_SUMMARY_SNIPPET_CHARS`` content characters. Collection
        stops at the first message whose line no longer fits. Returns
        ``None`` when nothing fits.
        """
        if not history:
            return None

        lines: list[str] = []
        remaining = budget_tokens
        for msg in reversed(history):
            snippet = str(msg.get("content") or "")[: self._SUMMARY_SNIPPET_CHARS]
            line = f"{msg.get('role', '')}: {snippet}"
            # Charge for the emitted snippet, not the full message; otherwise
            # one long recent message could exclude everything.
            cost = self._message_cost(line)
            if cost > remaining:
                break
            lines.append(line)
            remaining -= cost
        lines.reverse()
        return "\n".join(lines) if lines else None

    async def _summarize_history_async(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize with the LLM when possible, else fall back to truncation."""
        if self._model_client is None:
            return self._summarize_history(history, budget_tokens)
        summarized = await self._llm_summarize_history(history, budget_tokens)
        return summarized or self._summarize_history(history, budget_tokens)

    async def _llm_summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Use summary-role LLM to compress history.

        Returns the summary truncated to *budget_tokens*. Returns ``None``
        (so the caller falls back) when the model returns no content or any
        step raises.
        """
        try:
            history_text = "\n".join(
                f"{m.get('role', '')}: {m.get('content') or ''}" for m in history
            )
            response = await self._model_client.chat(
                self.summary_role,
                [{
                    "role": "user",
                    "content": (
                        "Summarize this conversation history. Keep decisions, outcomes, "
                        "and key facts. Be concise.\n\n"
                        + history_text
                    ),
                }],
            )
            summary = response.content
            if not summary:
                return None
            # Ensure summary fits budget
            return self._truncate_text(summary, budget_tokens)
        except Exception as exc:  # best-effort: caller falls back to truncation
            logger.warning("History summarization failed: %s", exc)
            return None
|