ducklm/duck_core/context_builder.py

376 lines
14 KiB
Python

from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any
from duck_core.tasks.state import TaskState
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Approximate characters per token (rough heuristic: ~4 chars per token).
# Used to convert between character counts and token estimates.
_CHARS_PER_TOKEN = 4
@dataclass
class RecallDecision:
    """Outcome of a memory-recall pass.

    Bundles the memory records that were kept together with the recall
    step's verdict on whether they are enough to answer the query.
    """

    # Memory records retained by the recall step (required, no default).
    records: list[dict[str, str]]
    # True when the retained records are judged enough to answer without
    # local tools/actions; defaults to the conservative False.
    sufficient_to_answer: bool = False
    # Free-text explanation accompanying the decision; empty when none given.
    reasoning: str = ""
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text* derived from its length.

    The character count is floor-divided by ``_CHARS_PER_TOKEN``. The
    result is never below 1, so even an empty string counts as one token.
    """
    approx = len(text) // _CHARS_PER_TOKEN
    return approx if approx > 1 else 1
def estimate_messages_tokens(messages: list[dict[str, str]]) -> int:
    """Estimate the total token count for a list of chat messages.

    Each message contributes the estimate for its ``content`` plus a flat
    4-token allowance for the role and formatting overhead.

    Args:
        messages: Chat messages; ``content`` may be missing or ``None``.

    Returns:
        The summed estimate (0 for an empty list).
    """
    # ``.get("content", "")`` would still yield None when the key is present
    # with a null value (e.g. assistant tool-call messages), which would make
    # len() raise; ``or ""`` treats missing and null content alike.
    return sum(
        estimate_tokens(msg.get("content") or "") + 4  # role + formatting overhead
        for msg in messages
    )
class ContextBuilder:
    """Builds context messages with token budget awareness.

    Priority order (highest first):
    1. Current user message (always kept)
    2. Active task state
    3. Selected skill summary
    4. Recent tool observations
    5. Relevant memory
    6. Summarized old events / history
    7. Full conversation history (remaining budget)

    Emitted message order: one optional system message (memory + skill), an
    optional tool-observation message, the history (verbatim or as a summary
    system message), and finally the current user message.
    """

    # Marker appended by _truncate_text whenever it cuts text short.
    _TRUNCATION_MARKER = "\n... (truncated)"
    # Flat per-message allowance for role + formatting overhead.
    _MESSAGE_OVERHEAD_TOKENS = 4
    # Tool observations may use at most this share of the remaining budget.
    _OBSERVATION_BUDGET_SHARE = 0.4
    # Below this many tokens no history summary is attempted at all.
    _MIN_SUMMARY_BUDGET = 100
    # Per-field / per-message character caps used by the formatters.
    _OBSERVATION_FIELD_CHARS = 200
    _HISTORY_SNIPPET_CHARS = 100

    def __init__(
        self,
        max_input_tokens: int = 49152,
        max_memory_tokens: int = 8000,
        max_history_tokens: int = 12000,
        summary_role: str = "summary",
        recall_role: str = "recall",
        model_client: Any | None = None,
    ):
        """Store budgets, model roles and the optional model client.

        Args:
            max_input_tokens: Overall estimated-token budget for the context.
            max_memory_tokens: Cap for the formatted memory section.
            max_history_tokens: Cap for conversation history (full or summary).
            summary_role: Role name passed to the client for summarization.
            recall_role: Role name passed to the client for memory recall.
            model_client: Object exposing an awaitable ``chat(role, messages,
                ...)``; when None, all LLM-backed features fall back to
                deterministic behavior.
        """
        self.max_input_tokens = max_input_tokens
        self.max_memory_tokens = max_memory_tokens
        self.max_history_tokens = max_history_tokens
        self.summary_role = summary_role
        self.recall_role = recall_role
        self._model_client = model_client

    async def recall_relevant_memory(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> list[dict[str, str]]:
        """Use recall-role LLM to filter memory records by relevance.

        Returns only the memories that are relevant to the query.
        Falls back to returning all records if LLM is unavailable.
        """
        decision = await self.recall_relevant_memory_decision(query, memory_records)
        return decision.records

    async def recall_relevant_memory_decision(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> RecallDecision:
        """Return the recall decision for *query* over *memory_records*.

        With no records or no model client, every record is returned
        unfiltered. Any failure of the LLM call is logged and likewise falls
        back to returning every record, with ``sufficient_to_answer=False``.
        """
        if not memory_records or self._model_client is None:
            return RecallDecision(records=memory_records, sufficient_to_answer=False)
        try:
            return await self._llm_recall(query, memory_records)
        except Exception as exc:
            # Deliberate best-effort boundary: recall must never break the turn.
            logger.warning("Recall failed, using all memories: %s", exc)
            return RecallDecision(records=memory_records, sufficient_to_answer=False)

    async def _llm_recall(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> RecallDecision:
        """Call recall-role LLM to identify relevant memories.

        Raises whatever the client or ``json.loads`` raises; the public
        wrapper converts failures into the "all memories" fallback.
        """
        # Key every record by the exact string id that is shown to the model,
        # so matching works even when ``memory_id`` is stored as a non-string.
        keyed_records = [
            (str(record.get("memory_id", index)), record)
            for index, record in enumerate(memory_records)
        ]
        memories_text = "\n".join(
            f"[{memory_id}] {record.get('text', '')}"
            for memory_id, record in keyed_records
        )
        response = await self._model_client.chat(
            self.recall_role,
            [{
                "role": "user",
                "content": (
                    f"User query: {query}\n\n"
                    f"Available memories:\n{memories_text}"
                ),
            }],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "recall_result",
                    "schema": {
                        "type": "object",
                        # Strict structured-output modes require every
                        # property to be listed as required.
                        "required": ["relevant_ids", "sufficient_to_answer", "reasoning"],
                        "additionalProperties": False,
                        "properties": {
                            "relevant_ids": {
                                "type": "array",
                                "items": {"type": "string"},
                            },
                            "sufficient_to_answer": {
                                "type": "boolean",
                                "description": "True when selected memories are enough to answer without local tools/actions.",
                            },
                            "reasoning": {"type": "string"},
                        },
                    },
                    "strict": True,
                },
            },
        )
        data = json.loads(response.content)
        reasoning = str(data.get("reasoning", ""))
        relevant_ids = {str(item) for item in data.get("relevant_ids", [])}
        if not relevant_ids:
            return RecallDecision(records=[], sufficient_to_answer=False, reasoning=reasoning)
        records = [
            record for memory_id, record in keyed_records
            if memory_id in relevant_ids
        ]
        return RecallDecision(
            records=records,
            # "Sufficient" is meaningless if none of the returned ids matched.
            sufficient_to_answer=bool(data.get("sufficient_to_answer", False)) and bool(records),
            reasoning=reasoning,
        )

    def _build_leading_messages(
        self,
        task: TaskState,
        memory_records: list[dict[str, str]] | None,
        tool_observations: list[dict[str, Any]] | None,
        skill_summary: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the system and tool-observation messages shared by both builders.

        Returns:
            ``(messages, budget_remaining)`` where the budget already accounts
            for the user message that the caller appends last.
        """
        messages: list[dict[str, str]] = []
        # The user message is always included, so charge it up front.
        user_tokens = estimate_tokens(str(task.user_message or ""))
        budget_remaining = (
            self.max_input_tokens - user_tokens - self._MESSAGE_OVERHEAD_TOKENS
        )

        # 1. System-level context (memory + skill)
        system_parts: list[str] = []
        if memory_records:
            memory_text = self._format_memory(memory_records)
            memory_budget = min(self.max_memory_tokens, budget_remaining)
            if estimate_tokens(memory_text) > memory_budget:
                memory_text = self._truncate_text(memory_text, memory_budget)
            # Skip an empty section (no record had text, or no budget left).
            if memory_text:
                system_parts.append(memory_text)
                budget_remaining -= estimate_tokens(memory_text)
        if skill_summary:
            skill_text = f"Active skill:\n{skill_summary}"
            system_parts.append(skill_text)
            budget_remaining -= estimate_tokens(skill_text)
        if system_parts:
            messages.append({
                "role": "system",
                "content": "\n\n".join(system_parts),
            })

        # 2. Tool observations (recent, high priority)
        if tool_observations:
            obs_text = "Tool observations:\n" + self._format_observations(tool_observations)
            # Never let observations consume more than their share of what is
            # left; clamp at zero so an exhausted budget cannot go negative.
            obs_budget = max(int(budget_remaining * self._OBSERVATION_BUDGET_SHARE), 0)
            if estimate_tokens(obs_text) > obs_budget:
                obs_text = self._truncate_text(obs_text, obs_budget)
            if obs_text:
                messages.append({"role": "user", "content": obs_text})
                budget_remaining -= estimate_tokens(obs_text)
        return messages, budget_remaining

    def build_basic_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Build context messages respecting token budget.

        Args:
            task: Current task state; its ``user_message`` is always emitted last.
            history_messages: Previous conversation messages.
            memory_records: Relevant memory records.
            tool_observations: Recent tool call results.
            skill_summary: Selected skill description.

        Returns:
            The ordered list of chat messages. History that does not fit is
            replaced by a deterministic, truncated summary.
        """
        messages, budget_remaining = self._build_leading_messages(
            task, memory_records, tool_observations, skill_summary
        )
        # 3. Conversation history (lower priority, may be summarized)
        if history_messages:
            history_budget = min(budget_remaining, self.max_history_tokens)
            if estimate_messages_tokens(history_messages) <= history_budget:
                messages.extend(history_messages)
            elif history_budget > self._MIN_SUMMARY_BUDGET:
                summarized = self._summarize_history(history_messages, history_budget)
                if summarized:
                    messages.append({
                        "role": "system",
                        "content": f"Conversation summary:\n{summarized}",
                    })
            # else: no budget for history at all
        # 4. Current user message (always last, always included)
        messages.append({
            "role": "user",
            "content": task.user_message,
        })
        return messages

    async def build_async_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Async context builder variant that can use LLM summarization.

        Takes the same arguments and produces the same layout as
        ``build_basic_messages``; only the over-budget history path differs,
        trying the summary-role LLM before the deterministic fallback.
        """
        messages, budget_remaining = self._build_leading_messages(
            task, memory_records, tool_observations, skill_summary
        )
        if history_messages:
            history_budget = min(budget_remaining, self.max_history_tokens)
            if estimate_messages_tokens(history_messages) <= history_budget:
                messages.extend(history_messages)
            elif history_budget > self._MIN_SUMMARY_BUDGET:
                summarized = await self._summarize_history_async(
                    history_messages, history_budget
                )
                if summarized:
                    messages.append({
                        "role": "system",
                        "content": f"Conversation summary:\n{summarized}",
                    })
        messages.append({"role": "user", "content": task.user_message})
        return messages

    def _format_memory(self, records: list[dict[str, str]]) -> str:
        """Render memory records as a bulleted section; "" when none has text."""
        lines = [
            f"- {record.get('scope', 'memory')}: {record.get('text', '')}"
            for record in records
            if record.get("text")
        ]
        return "Relevant memory:\n" + "\n".join(lines) if lines else ""

    def _format_observations(self, observations: list[dict[str, Any]]) -> str:
        """Render tool observations as one bullet per tool call.

        A missing, null or non-dict ``result`` is treated as an empty result
        (status "error"). ``output`` and ``error`` are stringified before
        being capped, so structured tool payloads do not raise.
        """
        limit = self._OBSERVATION_FIELD_CHARS
        parts = []
        for obs in observations:
            tool = obs.get("tool", "unknown")
            result = obs.get("result")
            if not isinstance(result, dict):
                result = {}
            status = "ok" if result.get("ok", False) else "error"
            part = f"- {tool} ({status})"
            output = result.get("output", "")
            error = result.get("error", "")
            if output:
                part += f"\n output: {str(output)[:limit]}"
            if error:
                part += f"\n error: {str(error)[:limit]}"
            parts.append(part)
        return "\n".join(parts)

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within max_tokens.

        The truncation marker is counted against the budget. A non-positive
        budget yields "" (a negative slice bound would otherwise keep almost
        the entire text).
        """
        if max_tokens <= 0:
            return ""
        max_chars = max_tokens * _CHARS_PER_TOKEN
        if len(text) <= max_chars:
            return text
        marker = self._TRUNCATION_MARKER
        if max_chars <= len(marker):
            # Too small to hold the marker as well; hard cut only.
            return text[:max_chars]
        return text[: max_chars - len(marker)] + marker

    def _summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize conversation history to fit budget.

        Synchronous callers use deterministic truncation. Runtime code should
        call build_async_messages() when LLM summarization is desired.

        Walks the history newest-first, keeping a short snippet of each
        message until the budget is exhausted, and returns the kept snippets
        in chronological order (None when nothing fits).
        """
        if not history:
            return None
        kept: list[str] = []
        remaining = budget_tokens
        for msg in reversed(history):
            role = msg.get("role", "unknown")
            content = str(msg.get("content") or "")
            line = f"{role}: {content[:self._HISTORY_SNIPPET_CHARS]}"
            # Charge what is actually emitted, not the full original message.
            cost = estimate_tokens(line) + self._MESSAGE_OVERHEAD_TOKENS
            if cost > remaining:
                break
            kept.append(line)
            remaining -= cost
        kept.reverse()
        return "\n".join(kept) if kept else None

    async def _summarize_history_async(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize via the LLM when available, else deterministically."""
        if self._model_client is None:
            return self._summarize_history(history, budget_tokens)
        summarized = await self._llm_summarize_history(history, budget_tokens)
        return summarized or self._summarize_history(history, budget_tokens)

    async def _llm_summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Use summary-role LLM to compress history.

        Returns None on any failure or empty response so the caller can fall
        back to deterministic truncation.
        """
        try:
            history_text = "\n".join(
                f"{m.get('role', 'unknown')}: {m.get('content') or ''}"
                for m in history
            )
            response = await self._model_client.chat(
                self.summary_role,
                [{
                    "role": "user",
                    "content": (
                        "Summarize this conversation history. Keep decisions, outcomes, "
                        "and key facts. Be concise.\n\n"
                        + history_text
                    ),
                }],
            )
            summary = response.content
            if not summary:
                return None
            # Ensure summary fits budget
            return self._truncate_text(summary, budget_tokens)
        except Exception as exc:
            # Best-effort: summarization problems must not fail the build.
            logger.warning("History summarization failed: %s", exc)
            return None