# Source listing: 258 lines, 9.0 KiB, Python
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import aiosqlite
|
|
from pydantic import BaseModel, Field
|
|
|
|
from duck_core.tasks.store import utc_now
|
|
|
|
|
|
class MemoryRecord(BaseModel):
    """A single memory entry, validated by pydantic.

    The field names follow the columns of the ``memories`` table, with one
    exception: the ``metadata_json`` text column is exposed here as the
    decoded ``metadata`` dictionary.
    """

    # Integer primary key of the row; defaults to ``None``.
    id: int | None = Field(default=None)
    # Unique string identifier of the memory.
    memory_id: str
    # The memory text itself.
    text: str
    # Visibility scope; defaults to "workspace".
    scope: str = Field(default="workspace")
    # Workspace the memory belongs to.
    workspace: str
    # Conversation the memory is tied to, if any.
    conversation_id: str | None = Field(default=None)
    # Free-form category label; defaults to "note".
    memory_type: str = Field(default="note")
    # Ranking weight; defaults to 0.5.
    importance: float = Field(default=0.5)
    # Arbitrary extra data; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
    # Creation and last-update timestamps, kept as strings.
    created_at: str
    updated_at: str
|
|
|
|
|
|
class MemoryStore:
    """SQLite-backed store of scoped memory records.

    Every public coroutine calls :meth:`init` first, so the schema is created
    on demand, and each operation opens its own short-lived ``aiosqlite``
    connection to ``db_path``.
    """

    # Scopes the store accepts; any other value is coerced to "workspace".
    _VALID_SCOPES = frozenset({"global", "workspace", "conversation"})

    # Substrings that make infer_scope() classify a text as global.
    _GLOBAL_MARKERS = (
        "user prefers",
        "пользователь предпочитает",
        "отвечай",
        "always",
        "всегда",
        "rx580",
        "radeon",
        "vulkan",
    )

    # Substrings that make infer_scope() classify a text as conversation-local.
    _CONVERSATION_MARKERS = ("this chat", "этот чат", "диалог")

    def __init__(self, db_path: str):
        """Remember the database location; no I/O is performed here."""
        self.db_path = Path(db_path)

    async def init(self) -> None:
        """Create the database directory, table and index if they are missing.

        Also adds the ``scope`` column to a pre-existing ``memories`` table
        that lacks it. Safe to call repeatedly.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        async with aiosqlite.connect(self.db_path) as db:
            await db.execute(
                """
                create table if not exists memories (
                    id integer primary key autoincrement,
                    memory_id text not null unique,
                    text text not null,
                    scope text not null default 'workspace',
                    workspace text not null,
                    conversation_id text,
                    memory_type text not null,
                    importance real not null,
                    metadata_json text not null,
                    created_at text not null,
                    updated_at text not null
                )
                """
            )
            await db.execute(
                """
                create index if not exists idx_memories_workspace_created
                on memories(workspace, created_at)
                """
            )
            await self._ensure_scope_column(db)
            await db.commit()

    async def _ensure_scope_column(self, db: aiosqlite.Connection) -> None:
        """Add the ``scope`` column when the existing table does not have it.

        The caller is responsible for committing.
        """
        cursor = await db.execute("pragma table_info(memories)")
        # Index 1 of each table_info row is the column name.
        columns = {row[1] for row in await cursor.fetchall()}
        if "scope" not in columns:
            await db.execute(
                "alter table memories add column scope text not null default 'workspace'"
            )

    async def add(
        self,
        text: str,
        workspace: str,
        scope: str = "workspace",
        conversation_id: str | None = None,
        memory_type: str = "note",
        importance: float = 0.5,
        metadata: dict[str, Any] | None = None,
    ) -> MemoryRecord:
        """Insert a new memory and return it as a :class:`MemoryRecord`.

        Args:
            text: Memory text; runs of whitespace are collapsed to single
                spaces and leading/trailing whitespace is removed.
            workspace: Workspace the memory belongs to.
            scope: Requested scope; unknown values become ``"workspace"``.
            conversation_id: Optional conversation the memory is tied to.
            memory_type: Category label; a falsy value becomes ``"note"``.
            importance: Ranking weight, clamped into ``[0.0, 1.0]``.
            metadata: Optional extra data, stored as JSON; a falsy value is
                stored as an empty object.

        Returns:
            The record as stored, including the generated ``memory_id`` and
            the row id reported by SQLite.
        """
        await self.init()
        now = utc_now()
        memory_id = f"mem_{uuid4().hex[:12]}"
        clean_text = " ".join(text.split())

        # Normalize once so the inserted row and the returned record can
        # never disagree.
        stored_scope = self._normalize_scope(scope)
        stored_type = memory_type or "note"
        stored_importance = self._clamp_importance(importance)
        stored_metadata = metadata or {}

        async with aiosqlite.connect(self.db_path) as db:
            cursor = await db.execute(
                """
                insert into memories(
                    memory_id, text, scope, workspace, conversation_id, memory_type,
                    importance, metadata_json, created_at, updated_at
                ) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory_id,
                    clean_text,
                    stored_scope,
                    workspace,
                    conversation_id,
                    stored_type,
                    stored_importance,
                    json.dumps(stored_metadata, ensure_ascii=False),
                    now,
                    now,
                ),
            )
            await db.commit()
            row_id = cursor.lastrowid

        return MemoryRecord(
            id=row_id,
            memory_id=memory_id,
            text=clean_text,
            scope=stored_scope,
            workspace=workspace,
            conversation_id=conversation_id,
            memory_type=stored_type,
            importance=stored_importance,
            metadata=stored_metadata,
            created_at=now,
            updated_at=now,
        )

    async def list(
        self, workspace: str | None = None, limit: int = 50
    ) -> list[MemoryRecord]:
        """Return memories ordered by importance, then newest first.

        Args:
            workspace: When given (and non-empty), restrict the result to that
                workspace plus every ``global``-scope memory; otherwise all
                memories are eligible.
            limit: Maximum number of records, clamped into ``[1, 200]``.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 200)
        if workspace:
            sql = """
                select * from memories
                where workspace = ? or scope = 'global'
                order by importance desc, created_at desc
                limit ?
            """
            params: tuple[Any, ...] = (workspace, bounded_limit)
        else:
            sql = """
                select * from memories
                order by importance desc, created_at desc
                limit ?
            """
            params = (bounded_limit,)
        return await self._fetch_records(sql, params)

    async def search(
        self, query: str, workspace: str | None = None, limit: int = 20
    ) -> list[MemoryRecord]:
        """Substring-search memories by text, type or serialized metadata.

        The query is matched literally: ``%``, ``_`` and ``\\`` in ``query``
        are escaped so they are not interpreted as LIKE wildcards.

        Args:
            query: Text to look for; surrounding whitespace is ignored. An
                empty query matches every row.
            workspace: When given (and non-empty), only rows of that workspace
                are searched. Unlike :meth:`list`, ``global``-scope memories
                of other workspaces are not added.
            limit: Maximum number of records, clamped into ``[1, 100]``.

        Returns:
            Matching records ordered by importance, then newest first.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 100)
        pattern = f"%{self._escape_like(query.strip())}%"
        # NOTE: '\\' below is a Python escape; SQLite receives: escape '\'
        if workspace:
            sql = """
                select * from memories
                where workspace = ?
                  and (
                    text like ? escape '\\'
                    or memory_type like ? escape '\\'
                    or metadata_json like ? escape '\\'
                  )
                order by importance desc, created_at desc
                limit ?
            """
            params: tuple[Any, ...] = (
                workspace,
                pattern,
                pattern,
                pattern,
                bounded_limit,
            )
        else:
            sql = """
                select * from memories
                where text like ? escape '\\'
                   or memory_type like ? escape '\\'
                   or metadata_json like ? escape '\\'
                order by importance desc, created_at desc
                limit ?
            """
            params = (pattern, pattern, pattern, bounded_limit)
        return await self._fetch_records(sql, params)

    async def relevant(
        self,
        workspace: str,
        conversation_id: str | None = None,
        query: str = "",
        limit: int = 8,
    ) -> list[MemoryRecord]:
        """Return the memories visible from a workspace/conversation context.

        Candidates are all ``global`` memories, ``workspace`` memories of
        ``workspace``, and ``conversation`` memories matching both
        ``workspace`` and ``conversation_id``. They are ordered global first,
        then conversation, then workspace, and within each group by
        importance and recency.

        Args:
            workspace: Workspace whose memories are eligible.
            conversation_id: Conversation whose memories are eligible; with
                ``None`` no conversation-scope row matches.
            query: Optional free text. Words of at least three characters are
                used as case-insensitive substring filters on the memory text;
                ``global`` memories always pass the filter.
            limit: Maximum number of records, clamped into ``[1, 30]``.

        Returns:
            At most ``limit`` records. Filtering is applied to a candidate set
            of up to three times the clamped limit.
        """
        await self.init()
        bounded_limit = min(max(limit, 1), 30)
        terms = [term.lower() for term in query.split() if len(term) >= 3]
        records = await self._fetch_records(
            """
            select * from memories
            where scope = 'global'
               or (scope = 'workspace' and workspace = ?)
               or (scope = 'conversation' and workspace = ? and conversation_id = ?)
            order by
                case scope
                    when 'global' then 0
                    when 'conversation' then 1
                    else 2
                end,
                importance desc,
                created_at desc
            limit ?
            """,
            # Over-fetch so that term filtering still has candidates to keep.
            (workspace, workspace, conversation_id, bounded_limit * 3),
        )
        if not terms:
            return records[:bounded_limit]
        matching = [
            record
            for record in records
            if record.scope == "global"
            or any(term in record.text.lower() for term in terms)
        ]
        return matching[:bounded_limit]

    def infer_scope(self, text: str, workspace: str, conversation_id: str | None) -> str:
        """Guess a scope for ``text`` from marker phrases.

        Returns ``"global"`` when the lowercased text contains any global
        marker; otherwise ``"conversation"`` when ``conversation_id`` is set
        and a conversation marker is present; otherwise ``"workspace"`` if
        ``workspace`` is non-empty, else ``"global"``.
        """
        lowered = text.lower()
        if any(marker in lowered for marker in self._GLOBAL_MARKERS):
            return "global"
        if conversation_id and any(
            marker in lowered for marker in self._CONVERSATION_MARKERS
        ):
            return "conversation"
        return "workspace" if workspace else "global"

    def _normalize_scope(self, scope: str) -> str:
        """Return ``scope`` if it is a known scope, else ``"workspace"``."""
        return scope if scope in self._VALID_SCOPES else "workspace"

    @staticmethod
    def _clamp_importance(importance: float) -> float:
        """Convert ``importance`` to float and clamp it into ``[0.0, 1.0]``."""
        return max(0.0, min(float(importance), 1.0))

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE metacharacters so ``value`` matches literally.

        Intended for patterns used with ``escape '\\'``. The backslash itself
        is escaped first so the escapes added for ``%`` and ``_`` are not
        doubled.
        """
        return (
            value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )

    async def _fetch_records(
        self, sql: str, params: tuple[Any, ...]
    ) -> list[MemoryRecord]:
        """Run a select on ``memories`` and map every row to a record."""
        async with aiosqlite.connect(self.db_path) as db:
            # Row objects allow access by column name in _row_to_record().
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(sql, params)
            rows = await cursor.fetchall()
        return [self._row_to_record(row) for row in rows]

    def _row_to_record(self, row: aiosqlite.Row) -> MemoryRecord:
        """Build a :class:`MemoryRecord` from one ``memories`` row.

        The ``metadata_json`` column is decoded with ``json.loads``.
        """
        return MemoryRecord(
            id=row["id"],
            memory_id=row["memory_id"],
            text=row["text"],
            scope=row["scope"],
            workspace=row["workspace"],
            conversation_id=row["conversation_id"],
            memory_type=row["memory_type"],
            importance=float(row["importance"]),
            metadata=json.loads(row["metadata_json"]),
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )
|