from __future__ import annotations

import json
from pathlib import Path
from typing import Any
from uuid import uuid4

import aiosqlite
from pydantic import BaseModel, Field

from duck_core.tasks.store import utc_now


class MemoryRecord(BaseModel):
    """One row of the ``memories`` table, as returned by :class:`MemoryStore`."""

    # Autoincrement primary key; None when the record has no row id yet.
    id: int | None = None
    # Public identifier generated by MemoryStore.add ("mem_" + 12 hex chars).
    memory_id: str
    text: str
    # MemoryStore writes one of "global", "workspace" or "conversation".
    scope: str = "workspace"
    workspace: str
    conversation_id: str | None = None
    memory_type: str = "note"
    # MemoryStore.add clamps this to the range [0.0, 1.0] before storing.
    importance: float = 0.5
    # Persisted as JSON text in the ``metadata_json`` column.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Timestamps are whatever string ``utc_now()`` produced at insert time.
    created_at: str
    updated_at: str


class MemoryStore:
    """Async SQLite-backed store for scoped memory records.

    Every public coroutine calls :meth:`init` first, so the schema is created
    (and migrated) lazily; callers never need to initialise the store by hand.

    NOTE: the ``list`` method below shadows the builtin inside the class body.
    Annotations such as ``list[MemoryRecord]`` still work only because
    ``from __future__ import annotations`` keeps them unevaluated.
    """

    # Scopes accepted by ``_normalize_scope``; anything else becomes "workspace".
    _SCOPES = frozenset({"global", "workspace", "conversation"})

    # Substrings (matched against lowercased text) that mark a memory as global.
    _GLOBAL_MARKERS = (
        "user prefers",
        "пользователь предпочитает",
        "отвечай",
        "always",
        "всегда",
        "rx580",
        "radeon",
        "vulkan",
    )

    # Substrings that mark a memory as belonging to the current conversation.
    _CONVERSATION_MARKERS = ("this chat", "этот чат", "диалог")

    def __init__(self, db_path: str):
        """Remember where the SQLite database lives; no I/O happens here."""
        self.db_path = Path(db_path)

    async def init(self) -> None:
        """Create the database directory, table and index if they are missing.

        Also adds the ``scope`` column to tables that lack it. Every statement
        is guarded (``if not exists`` / column check), so repeated calls are
        harmless.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                """
                create table if not exists memories (
                    id integer primary key autoincrement,
                    memory_id text not null unique,
                    text text not null,
                    scope text not null default 'workspace',
                    workspace text not null,
                    conversation_id text,
                    memory_type text not null,
                    importance real not null,
                    metadata_json text not null,
                    created_at text not null,
                    updated_at text not null
                )
                """
            )
            await db.execute(
                """
                create index if not exists idx_memories_workspace_created
                on memories(workspace, created_at)
                """
            )
            await self._ensure_scope_column(db)
            await db.commit()

    async def _ensure_scope_column(self, db: aiosqlite.Connection) -> None:
        """Add the ``scope`` column when the existing table does not have it."""
        cursor = await db.execute("pragma table_info(memories)")
        # Column name is the second field of each ``table_info`` row.
        columns = {row[1] for row in await cursor.fetchall()}
        if "scope" not in columns:
            await db.execute(
                "alter table memories add column scope text not null default 'workspace'"
            )

    async def add(
        self,
        text: str,
        workspace: str,
        scope: str = "workspace",
        conversation_id: str | None = None,
        memory_type: str = "note",
        importance: float = 0.5,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Insert a new memory and return it as a :class:`MemoryRecord`.

        Args:
            text: Memory text; runs of whitespace are collapsed to one space.
            workspace: Workspace the memory belongs to.
            scope: Requested scope; unknown values fall back to "workspace".
            conversation_id: Optional conversation the memory is tied to.
            memory_type: Free-form type label; an empty value becomes "note".
            importance: Ranking weight, clamped to ``[0.0, 1.0]``.
            metadata: Optional JSON-serialisable mapping; defaults to ``{}``.

        Returns:
            The stored record, including its generated ``memory_id`` and the
            row id assigned by SQLite.
        """
        await self.init()
        now = utc_now()
        memory_id = f"mem_{uuid4().hex[:12]}"

        # Normalise every field exactly once so the stored row and the
        # returned record can never disagree.
        clean_text = " ".join(text.split())
        normalized_scope = self._normalize_scope(scope)
        kind = memory_type or "note"
        weight = self._clamp_importance(importance)
        payload = metadata or {}

        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into memories(
                    memory_id, text, scope, workspace, conversation_id,
                    memory_type, importance, metadata_json, created_at, updated_at
                )
                values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory_id,
                    clean_text,
                    normalized_scope,
                    workspace,
                    conversation_id,
                    kind,
                    weight,
                    json.dumps(payload, ensure_ascii=False),
                    now,
                    now,
                ),
            )
            await db.commit()
            row_id = cursor.lastrowid

        return MemoryRecord(
            id=row_id,
            memory_id=memory_id,
            text=clean_text,
            scope=normalized_scope,
            workspace=workspace,
            conversation_id=conversation_id,
            memory_type=kind,
            importance=weight,
            metadata=payload,
            created_at=now,
            updated_at=now,
        )

    async def list(
        self, workspace: str | None = None, limit: int = 50
    ) -> list[MemoryRecord]:
        """Return memories ordered by importance, then newest first.

        Args:
            workspace: When given, restrict to that workspace plus all
                global-scope memories; otherwise return from every workspace.
            limit: Maximum number of rows, clamped to ``[1, 200]``.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 200)
        if workspace:
            return await self._fetch_records(
                """
                select * from memories
                where workspace = ? or scope = 'global'
                order by importance desc, created_at desc
                limit ?
                """,
                (workspace, bounded_limit),
            )
        return await self._fetch_records(
            """
            select * from memories
            order by importance desc, created_at desc
            limit ?
            """,
            (bounded_limit,),
        )

    async def search(
        self, query: str, workspace: str | None = None, limit: int = 20
    ) -> list[MemoryRecord]:
        """Substring-search memories by text, type or serialised metadata.

        The query is matched literally: ``%`` and ``_`` typed by the user are
        escaped rather than treated as LIKE wildcards. A blank query matches
        every row.

        Args:
            query: Text to look for (surrounding whitespace is ignored).
            workspace: When given, restrict to that workspace plus all
                global-scope memories, mirroring :meth:`list`.
            limit: Maximum number of rows, clamped to ``[1, 100]``.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 100)
        pattern = self._like_pattern(query)
        if workspace:
            return await self._fetch_records(
                """
                select * from memories
                where (workspace = ? or scope = 'global')
                  and (
                    text like ? escape '\\'
                    or memory_type like ? escape '\\'
                    or metadata_json like ? escape '\\'
                  )
                order by importance desc, created_at desc
                limit ?
                """,
                (workspace, pattern, pattern, pattern, bounded_limit),
            )
        return await self._fetch_records(
            """
            select * from memories
            where text like ? escape '\\'
               or memory_type like ? escape '\\'
               or metadata_json like ? escape '\\'
            order by importance desc, created_at desc
            limit ?
            """,
            (pattern, pattern, pattern, bounded_limit),
        )

    async def relevant(
        self,
        workspace: str,
        conversation_id: str | None = None,
        query: str = "",
        limit: int = 8,
    ) -> list[MemoryRecord]:
        """Return the memories visible from a workspace/conversation context.

        Candidates are global memories, workspace-scope memories of
        ``workspace`` and conversation-scope memories of the given
        conversation, ordered global -> conversation -> workspace, then by
        importance and recency. When ``query`` contains words of three or
        more characters, non-global candidates are kept only if their text
        contains at least one of those words; global memories always pass.

        Args:
            workspace: Workspace to read from.
            conversation_id: Conversation whose scoped memories are included;
                with ``None`` no conversation-scope rows match.
            query: Optional free text used for the keyword filter.
            limit: Maximum number of records, clamped to ``[1, 30]``.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 30)
        terms = [term.lower() for term in query.split() if len(term) >= 3]

        # Over-fetch (3x) so the keyword filter below still has enough
        # candidates to fill ``bounded_limit`` slots.
        records = await self._fetch_records(
            """
            select * from memories
            where scope = 'global'
               or (scope = 'workspace' and workspace = ?)
               or (scope = 'conversation' and workspace = ? and conversation_id = ?)
            order by
                case scope when 'global' then 0 when 'conversation' then 1 else 2 end,
                importance desc,
                created_at desc
            limit ?
            """,
            (workspace, workspace, conversation_id, bounded_limit * 3),
        )
        if not terms:
            return records[:bounded_limit]

        matching = [
            record
            for record in records
            if record.scope == "global"
            or any(term in record.text.lower() for term in terms)
        ]
        return matching[:bounded_limit]

    def infer_scope(self, text: str, workspace: str, conversation_id: str | None) -> str:
        """Guess a scope for ``text`` from marker substrings.

        Global markers win over conversation markers. Conversation markers
        only count when a ``conversation_id`` is supplied. Otherwise the
        result is "workspace" when a workspace is given, else "global".
        """
        lowered = text.lower()
        if any(marker in lowered for marker in self._GLOBAL_MARKERS):
            return "global"
        if conversation_id and any(
            marker in lowered for marker in self._CONVERSATION_MARKERS
        ):
            return "conversation"
        return "workspace" if workspace else "global"

    def _normalize_scope(self, scope: str) -> str:
        """Return ``scope`` if it is a known scope, otherwise "workspace"."""
        return scope if scope in self._SCOPES else "workspace"

    @staticmethod
    def _clamp_importance(importance: float) -> float:
        """Coerce ``importance`` to float and clamp it into ``[0.0, 1.0]``."""
        return max(0.0, min(float(importance), 1.0))

    @staticmethod
    def _like_pattern(query: str) -> str:
        """Build a ``%...%`` LIKE pattern that matches ``query`` literally.

        The backslash is escaped first so it can serve as the ESCAPE
        character for ``%`` and ``_``.
        """
        escaped = (
            query.strip()
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        return f"%{escaped}%"

    async def _fetch_records(
        self, sql: str, params: tuple[Any, ...]
    ) -> list[MemoryRecord]:
        """Run a ``select * from memories`` query and map rows to records."""
        async with aiosqlite.connect(self.db_path) as db:
            # Row objects allow column access by name in ``_row_to_record``.
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(sql, params)
            rows = await cursor.fetchall()
        return [self._row_to_record(row) for row in rows]

    def _row_to_record(self, row: aiosqlite.Row) -> MemoryRecord:
        """Convert one ``memories`` row into a :class:`MemoryRecord`."""
        return MemoryRecord(
            id=row["id"],
            memory_id=row["memory_id"],
            text=row["text"],
            scope=row["scope"],
            workspace=row["workspace"],
            conversation_id=row["conversation_id"],
            memory_type=row["memory_type"],
            importance=float(row["importance"]),
            metadata=json.loads(row["metadata_json"]),
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )