112 lines
3.9 KiB
Python
112 lines
3.9 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from app.core.contracts import ToolResult, UserTask
|
|
from app.tools.base import BaseTool
|
|
|
|
# Module-level logger; the memory tool below reports backend failures through it.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Tool(BaseTool):
    """Agent tool exposing memory operations: ``insert``, ``search`` and ``list``.

    Every action delegates to the ``memory_interface`` passed to the
    constructor. Argument validation problems and exceptions raised by the
    backend are reported as a ``ToolResult`` with ``ok=False`` (and logged,
    for backend failures) instead of being raised to the caller.
    """

    name = "memory"
    description = "Memory operations: insert, search, list"

    def __init__(self, memory_interface: Any = None) -> None:
        """Store the memory backend.

        Args:
            memory_interface: Object providing ``insert``, ``search`` and
                ``get_recent``. When ``None``, every action fails with
                "Memory not available".
        """
        self._memory = memory_interface

    def execute(self, task: UserTask, args: dict[str, Any]) -> ToolResult:
        """Run the action named by ``args["action"]`` (default ``"search"``).

        Args:
            task: Current task; its ``task_id`` and ``session_id`` are attached
                to inserted entries.
            args: Tool arguments. ``action`` selects the handler; the other
                keys are action-specific (see ``_insert``, ``_search`` and
                ``_list``).

        Returns:
            The handler's result, or a failed result for an unknown action.
        """
        action = args.get("action", "search")
        if action == "insert":
            return self._insert(task, args)
        if action == "search":
            return self._search(task, args)
        if action == "list":
            return self._list(task, args)
        return self._error(f"Unknown action: {action}")

    def _error(self, message: str) -> ToolResult:
        """Build the failed ``ToolResult`` shared by every error path."""
        return ToolResult(tool=self.name, ok=False, output="", error=message)

    @staticmethod
    def _positive_int(args: dict[str, Any], key: str, default: int) -> int | None:
        """Read ``args[key]`` as a positive int, or return ``None`` if invalid.

        A missing key or an explicit ``None`` yields *default*. Values are
        passed through ``int()``, so numeric strings such as ``"5"`` are
        accepted in case arguments arrive as text (e.g. from JSON tool calls).
        """
        raw = args.get(key)
        if raw is None:
            return default
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")).
            return None
        return value if value > 0 else None

    def _insert(self, task: UserTask, args: dict[str, Any]) -> ToolResult:
        """Store ``args["text"]`` in memory.

        Optional args: ``kind`` (default ``"fact"``), ``source`` (default
        ``"user"``) and ``weight`` (default ``0.5``, converted with
        ``float()``). The entry is tagged with the task's ``task_id`` and
        ``session_id``; on success the output reports the new entry's id.
        """
        text = args.get("text", "")
        if not text:
            return self._error("text is required")
        # Compare against None explicitly: a backend that defines __len__
        # would otherwise look "unavailable" while it is still empty.
        if self._memory is None:
            return self._error("Memory not available")

        raw_weight = args.get("weight")
        try:
            weight = 0.5 if raw_weight is None else float(raw_weight)
        except (TypeError, ValueError):
            return self._error("weight must be a number")

        try:
            entry = self._memory.insert(
                text=text,
                kind=args.get("kind", "fact"),
                source=args.get("source", "user"),
                task_id=task.task_id,
                session_id=task.session_id,
                weight=weight,
            )
            # Kept inside the try so a malformed backend return (e.g. no
            # ``id`` attribute) is reported rather than raised.
            return ToolResult(
                tool=self.name,
                ok=True,
                output=f"Stored: {entry.id}",
                metadata={"entry_id": entry.id},
            )
        except Exception as exc:  # backend boundary: report, don't propagate
            logger.warning("Memory insert failed: %s", exc, exc_info=True)
            return self._error(str(exc))

    def _search(self, task: UserTask, args: dict[str, Any]) -> ToolResult:
        """Search memory for ``args["query"]``, returning up to ``top_k`` hits.

        ``top_k`` defaults to 5 and must be a positive integer. The backend is
        expected to yield ``(entry, score)`` pairs; each output line is
        ``[score] text`` with the score to two decimals and the text cut to
        100 characters. ``task`` is unused but kept for a uniform handler
        signature.
        """
        query = args.get("query", "")
        if not query:
            return self._error("query is required")
        if self._memory is None:
            return self._error("Memory not available")
        top_k = self._positive_int(args, "top_k", 5)
        if top_k is None:
            return self._error("top_k must be a positive integer")

        try:
            results = self._memory.search(query, top_k=top_k)
            if not results:
                return ToolResult(
                    tool=self.name,
                    ok=True,
                    output="No results found",
                    metadata={"count": 0},
                )
            lines = [f"[{score:.2f}] {entry.text[:100]}" for entry, score in results]
            return ToolResult(
                tool=self.name,
                ok=True,
                output="\n".join(lines),
                metadata={"count": len(lines)},
            )
        except Exception as exc:  # backend boundary: report, don't propagate
            logger.warning("Memory search failed: %s", exc, exc_info=True)
            return self._error(str(exc))

    def _list(self, task: UserTask, args: dict[str, Any]) -> ToolResult:
        """List up to ``limit`` recent memories (default 10).

        Each output line is ``kind: text`` with the text cut to 80
        characters. ``task`` is unused but kept for a uniform handler
        signature.
        """
        if self._memory is None:
            return self._error("Memory not available")
        limit = self._positive_int(args, "limit", 10)
        if limit is None:
            return self._error("limit must be a positive integer")

        try:
            entries = self._memory.get_recent(limit=limit)
            if not entries:
                return ToolResult(
                    tool=self.name,
                    ok=True,
                    output="No memories",
                    metadata={"count": 0},
                )
            lines = [f"{entry.kind}: {entry.text[:80]}" for entry in entries]
            return ToolResult(
                tool=self.name,
                ok=True,
                output="\n".join(lines),
                metadata={"count": len(lines)},
            )
        except Exception as exc:  # backend boundary: report, don't propagate
            logger.warning("Memory list failed: %s", exc, exc_info=True)
            return self._error(str(exc))