from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Used only in annotations, which are never evaluated at runtime because of
    # `from __future__ import annotations`.
    from duck_core.tasks.state import TaskState

logger = logging.getLogger(__name__)

# Approximate tokens per character (rough heuristic: ~4 chars per token)
_CHARS_PER_TOKEN = 4

# Tool observations may use at most this fraction of the budget that is still
# available when they are added.
_OBSERVATION_BUDGET_FRACTION = 0.4

# History is only summarized when more than this many tokens are left.
_MIN_SUMMARY_BUDGET_TOKENS = 100

# Suffix appended to any text that had to be cut short.
_TRUNCATION_MARKER = "\n... (truncated)"


def estimate_tokens(text: str) -> int:
    """Rough token estimate based on character count.

    The result is never below 1, so an empty string still counts as one token.
    """
    return max(len(text) // _CHARS_PER_TOKEN, 1)


def estimate_messages_tokens(messages: list[dict[str, str]]) -> int:
    """Estimate total token count for a list of messages.

    Each message is charged its estimated content tokens plus 4 tokens of
    role/formatting overhead. A missing or ``None`` content counts as empty.
    """
    total = 0
    for msg in messages:
        # `or ""` guards against an explicit `"content": None`.
        total += estimate_tokens(msg.get("content") or "") + 4
    return total


class ContextBuilder:
    """Builds context messages with token budget awareness.

    Messages are assembled in this order:

    1. One system message holding relevant memory (capped at
       ``max_memory_tokens``) and the active skill summary.
    2. Recent tool observations, capped at 40% of the budget remaining at
       that point.
    3. Conversation history: included verbatim when it fits the remaining
       budget, otherwise replaced by a summary when more than 100 tokens
       remain, otherwise dropped.
    4. The current user message, which is always appended last and is not
       charged against the budget.
    """

    def __init__(
        self,
        max_input_tokens: int = 49152,
        max_memory_tokens: int = 8000,
        max_history_tokens: int = 12000,
        summary_role: str = "summary",
        recall_role: str = "recall",
        model_client: Any | None = None,
    ) -> None:
        self.max_input_tokens = max_input_tokens
        self.max_memory_tokens = max_memory_tokens
        # Stored for callers; not read by any method in this class.
        self.max_history_tokens = max_history_tokens
        self.summary_role = summary_role
        self.recall_role = recall_role
        self._model_client = model_client

    async def recall_relevant_memory(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> list[dict[str, str]]:
        """Use recall-role LLM to filter memory records by relevance.

        Returns only the memories that are relevant to the query.
        Falls back to returning all records if the LLM is unavailable or the
        recall call fails for any reason.
        """
        if not memory_records or self._model_client is None:
            return memory_records
        try:
            return await self._llm_recall(query, memory_records)
        except Exception as exc:
            # Deliberate best-effort: a failed recall must not lose memories.
            logger.warning("Recall failed, using all memories: %s", exc)
            return memory_records

    async def _llm_recall(
        self,
        query: str,
        memory_records: list[dict[str, str]],
    ) -> list[dict[str, str]]:
        """Call recall-role LLM to identify relevant memories.

        Records without a ``memory_id`` are identified by their list index.
        Identifiers are compared as strings, so a record whose ``memory_id``
        is not a ``str`` (e.g. an int) still matches the string ids the model
        returns.
        """
        memories_text = "\n".join(
            f"[{m.get('memory_id', i)}] {m.get('text', '')}"
            for i, m in enumerate(memory_records)
        )
        response = await self._model_client.chat(
            self.recall_role,
            [{
                "role": "user",
                "content": (
                    f"User query: {query}\n\n"
                    f"Available memories:\n{memories_text}"
                ),
            }],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "recall_result",
                    "schema": {
                        "type": "object",
                        "required": ["relevant_ids", "reasoning"],
                        "additionalProperties": False,
                        "properties": {
                            "relevant_ids": {
                                "type": "array",
                                "items": {"type": "string"},
                            },
                            "reasoning": {"type": "string"},
                        },
                    },
                    "strict": True,
                },
            },
        )
        data = json.loads(response.content)
        relevant_ids = {str(rid) for rid in data.get("relevant_ids", [])}
        if not relevant_ids:
            return []
        return [
            m
            for i, m in enumerate(memory_records)
            if str(m.get("memory_id", i)) in relevant_ids
        ]

    def _build_leading_messages(
        self,
        memory_records: list[dict[str, str]] | None,
        tool_observations: list[dict[str, Any]] | None,
        skill_summary: str | None,
    ) -> tuple[list[dict[str, str]], int]:
        """Build the memory/skill and observation messages shared by both builders.

        Returns:
            The messages built so far and the token budget still available
            for conversation history (may be negative).
        """
        messages: list[dict[str, str]] = []
        budget_remaining = self.max_input_tokens

        # 1. System-level context (memory + skill)
        system_parts: list[str] = []

        if memory_records:
            memory_text = self._format_memory(memory_records)
            if estimate_tokens(memory_text) > self.max_memory_tokens:
                # Truncate memory to fit budget
                memory_text = self._truncate_text(memory_text, self.max_memory_tokens)
            # Records without any text format to "", which must not become an
            # empty system part.
            if memory_text:
                system_parts.append(memory_text)
                budget_remaining -= estimate_tokens(memory_text)

        if skill_summary:
            skill_text = f"Active skill:\n{skill_summary}"
            system_parts.append(skill_text)
            budget_remaining -= estimate_tokens(skill_text)

        if system_parts:
            messages.append({
                "role": "system",
                "content": "\n\n".join(system_parts),
            })

        # 2. Tool observations (recent, high priority)
        if tool_observations:
            obs_text = "Tool observations:\n" + self._format_observations(tool_observations)
            obs_tokens = estimate_tokens(obs_text)
            if obs_tokens > budget_remaining * _OBSERVATION_BUDGET_FRACTION:
                # Don't let observations consume more than 40% of remaining
                # budget; clamp at zero because the budget may already be
                # negative.
                obs_cap = max(int(budget_remaining * _OBSERVATION_BUDGET_FRACTION), 0)
                obs_text = self._truncate_text(obs_text, obs_cap)
                obs_tokens = estimate_tokens(obs_text)
            messages.append({"role": "user", "content": obs_text})
            budget_remaining -= obs_tokens

        return messages, budget_remaining

    def build_basic_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Build context messages respecting token budget.

        History that does not fit is condensed with the deterministic
        (non-LLM) summarizer.

        Args:
            task: Current task state; only ``task.user_message`` is read.
            history_messages: Previous conversation messages.
            memory_records: Relevant memory records.
            tool_observations: Recent tool call results.
            skill_summary: Selected skill description.

        Returns:
            The message list, ending with the current user message.
        """
        messages, budget_remaining = self._build_leading_messages(
            memory_records, tool_observations, skill_summary
        )

        # 3. Conversation history (lower priority, may be summarized)
        if history_messages:
            hist_tokens = estimate_messages_tokens(history_messages)
            if hist_tokens <= budget_remaining:
                messages.extend(history_messages)
            elif budget_remaining > _MIN_SUMMARY_BUDGET_TOKENS:
                # Summarize old history if we have some budget left
                summarized = self._summarize_history(history_messages, budget_remaining)
                if summarized:
                    messages.append({
                        "role": "system",
                        "content": f"Conversation summary:\n{summarized}",
                    })
            # else: no budget for history at all

        # 4. Current user message (always last, always included)
        messages.append({
            "role": "user",
            "content": task.user_message,
        })
        return messages

    async def build_async_messages(
        self,
        task: TaskState,
        history_messages: list[dict[str, str]] | None = None,
        memory_records: list[dict[str, str]] | None = None,
        tool_observations: list[dict[str, Any]] | None = None,
        skill_summary: str | None = None,
    ) -> list[dict[str, str]]:
        """Async context builder variant that can use LLM summarization.

        Same assembly as ``build_basic_messages``, except that history which
        does not fit is summarized by the summary-role LLM when a model
        client is configured.
        """
        messages, budget_remaining = self._build_leading_messages(
            memory_records, tool_observations, skill_summary
        )

        if history_messages:
            hist_tokens = estimate_messages_tokens(history_messages)
            if hist_tokens <= budget_remaining:
                messages.extend(history_messages)
            elif budget_remaining > _MIN_SUMMARY_BUDGET_TOKENS:
                summarized = await self._summarize_history_async(
                    history_messages, budget_remaining
                )
                if summarized:
                    messages.append({
                        "role": "system",
                        "content": f"Conversation summary:\n{summarized}",
                    })

        messages.append({"role": "user", "content": task.user_message})
        return messages

    def _format_memory(self, records: list[dict[str, str]]) -> str:
        """Render memory records as a bullet list; "" when no record has text."""
        lines = [
            f"- {record.get('scope', 'memory')}: {record.get('text', '')}"
            for record in records
            if record.get("text")
        ]
        if not lines:
            return ""
        return "Relevant memory:\n" + "\n".join(lines)

    def _format_observations(self, observations: list[dict[str, Any]]) -> str:
        """Render tool observations, one bullet per tool call.

        Output and error are each limited to their first 200 characters. A
        missing or ``None`` result is reported with status "error".
        """
        parts: list[str] = []
        for obs in observations:
            tool = obs.get("tool", "unknown")
            # `or {}` also covers an explicit `"result": None`.
            result = obs.get("result") or {}
            status = "ok" if result.get("ok", False) else "error"
            output = result.get("output", "")
            error = result.get("error", "")
            part = f"- {tool} ({status})"
            # str() keeps the slice safe for non-string payloads.
            if output:
                part += f"\n output: {str(output)[:200]}"
            if error:
                part += f"\n error: {str(error)[:200]}"
            parts.append(part)
        return "\n".join(parts)

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate text to roughly ``max_tokens`` tokens.

        Text that already fits is returned unchanged. Otherwise the first
        ``max_tokens * 4`` characters are kept and a truncation marker is
        appended, so the result is slightly longer than the limit. A negative
        ``max_tokens`` is treated as zero.
        """
        # Without the clamp a negative limit would slice from the END of text.
        max_chars = max(max_tokens, 0) * _CHARS_PER_TOKEN
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + _TRUNCATION_MARKER

    def _summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize conversation history to fit budget.

        Synchronous callers use deterministic truncation. Runtime code should
        call build_async_messages() when LLM summarization is desired.

        Walks the history from newest to oldest and stops at the first message
        whose full content does not fit the remaining budget. Each kept
        message contributes ``"<role>: <first 100 chars>"``, in chronological
        order. Returns ``None`` when nothing fits.
        """
        if not history:
            return None
        kept: list[str] = []
        remaining = budget_tokens
        for msg in reversed(history):
            content = msg.get("content") or ""
            # Charged at full length even though only an excerpt is kept.
            tokens = estimate_tokens(content) + 4
            if tokens > remaining:
                break
            kept.append(f"{msg.get('role', 'unknown')}: {content[:100]}")
            remaining -= tokens
        if not kept:
            return None
        kept.reverse()
        return "\n".join(kept)

    async def _summarize_history_async(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Summarize with the LLM when possible, else fall back to truncation."""
        if self._model_client is None:
            return self._summarize_history(history, budget_tokens)
        summarized = await self._llm_summarize_history(history, budget_tokens)
        return summarized or self._summarize_history(history, budget_tokens)

    async def _llm_summarize_history(
        self,
        history: list[dict[str, str]],
        budget_tokens: int,
    ) -> str | None:
        """Use summary-role LLM to compress history.

        Returns ``None`` when the call fails, so the caller can fall back.
        """
        try:
            history_text = "\n".join(
                f"{m.get('role', 'unknown')}: {m.get('content', '')}"
                for m in history
            )
            response = await self._model_client.chat(
                self.summary_role,
                [{
                    "role": "user",
                    "content": (
                        "Summarize this conversation history. Keep decisions, outcomes, "
                        "and key facts. Be concise.\n\n" + history_text
                    ),
                }],
            )
            summary = response.content
            # Ensure summary fits budget
            return self._truncate_text(summary, budget_tokens)
        except Exception as exc:
            # Deliberate best-effort: summarization failure must not abort the build.
            logger.warning("History summarization failed: %s", exc)
            return None