ducklm/app/memory/store.py

185 lines
6.7 KiB
Python

from __future__ import annotations
import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Sequence
from uuid import uuid4
from app.core.contracts import MemoryEntry
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
class MemoryStore:
    """SQLite-backed persistence for memory items and their embeddings.

    Item fields live in ``memory_items``; each item's embedding blob plus its
    embedding model/dimension live in ``memory_embeddings``, keyed by item id.
    """

    # Shared prefix for every query that returns MemoryEntry objects. The
    # LEFT JOIN brings back the embedding metadata written by ``insert`` while
    # still returning items that have no embedding row (those columns are
    # NULL). Columns are qualified because both tables define ``created_at``.
    _ENTRY_SELECT = (
        "SELECT mi.*, me.embedding_model AS embedding_model, "
        "me.embedding_dim AS embedding_dim "
        "FROM memory_items AS mi "
        "LEFT JOIN memory_embeddings AS me ON me.memory_id = mi.id"
    )

    def __init__(self, db_path: str | Path) -> None:
        """Open (creating if needed) the SQLite database at *db_path*.

        Parent directories are created and the schema is initialised
        idempotently (``CREATE ... IF NOT EXISTS``).
        """
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        # check_same_thread=False allows use from threads other than the
        # creating one; this class performs no locking of its own.
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        # SQLite ignores FOREIGN KEY clauses (including the ON DELETE CASCADE
        # declared below) unless this per-connection pragma is enabled.
        self._conn.execute("PRAGMA foreign_keys = ON")
        self._init_tables()

    def _init_tables(self) -> None:
        """Create the item/embedding tables and their indexes if missing."""
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS memory_items (
                id TEXT PRIMARY KEY,
                text TEXT NOT NULL,
                kind TEXT NOT NULL,
                source TEXT NOT NULL,
                weight REAL NOT NULL DEFAULT 0.5,
                task_id TEXT,
                session_id TEXT,
                metadata_json TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS memory_embeddings (
                memory_id TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                embedding_model TEXT NOT NULL,
                embedding_dim INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                FOREIGN KEY (memory_id) REFERENCES memory_items(id) ON DELETE CASCADE
            );
            CREATE INDEX IF NOT EXISTS idx_memory_items_task ON memory_items(task_id);
            CREATE INDEX IF NOT EXISTS idx_memory_items_session ON memory_items(session_id);
            CREATE INDEX IF NOT EXISTS idx_memory_items_kind ON memory_items(kind);
            CREATE INDEX IF NOT EXISTS idx_memory_embeddings_model ON memory_embeddings(embedding_model);
        """)
        self._conn.commit()

    def insert(self, entry: MemoryEntry, embedding: bytes) -> None:
        """Persist *entry* and its *embedding* blob in a single transaction.

        If either INSERT fails (e.g. duplicate id, NULL embedding model) the
        transaction is rolled back, so no orphan item row is left behind, and
        the ``sqlite3`` exception propagates to the caller.
        """
        metadata_json = json.dumps(entry.metadata) if entry.metadata else None
        now = utc_now().isoformat()
        # The connection context manager commits on success and rolls back
        # on any exception raised inside the block.
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO memory_items (id, text, kind, source, weight, task_id, session_id, metadata_json, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    entry.id,
                    entry.text,
                    entry.kind,
                    entry.source,
                    entry.weight,
                    entry.task_id,
                    entry.session_id,
                    metadata_json,
                    entry.created_at.isoformat(),
                    now,
                ),
            )
            self._conn.execute(
                """
                INSERT INTO memory_embeddings (memory_id, embedding, embedding_model, embedding_dim, created_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    entry.id,
                    embedding,
                    entry.embedding_model,
                    entry.embedding_dim,
                    now,
                ),
            )

    def get(self, memory_id: str) -> MemoryEntry | None:
        """Return the entry with id *memory_id*, or None if it does not exist."""
        row = self._conn.execute(
            f"{self._ENTRY_SELECT} WHERE mi.id = ?", (memory_id,)
        ).fetchone()
        return self._row_to_entry(row) if row is not None else None

    def get_embedding(self, memory_id: str) -> bytes | None:
        """Return the raw embedding blob for *memory_id*, or None if absent."""
        row = self._conn.execute(
            "SELECT embedding FROM memory_embeddings WHERE memory_id = ?", (memory_id,)
        ).fetchone()
        return bytes(row["embedding"]) if row else None

    def get_all(self, limit: int = 1000) -> list[MemoryEntry]:
        """Return up to *limit* entries, newest ``created_at`` first."""
        return self._fetch_entries("ORDER BY mi.created_at DESC LIMIT ?", (limit,))

    def get_by_task(self, task_id: str) -> list[MemoryEntry]:
        """Return every entry for *task_id*, newest first (no limit)."""
        return self._fetch_entries(
            "WHERE mi.task_id = ? ORDER BY mi.created_at DESC", (task_id,)
        )

    def get_by_session(self, session_id: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries for *session_id*, newest first."""
        return self._fetch_entries(
            "WHERE mi.session_id = ? ORDER BY mi.created_at DESC LIMIT ?",
            (session_id, limit),
        )

    def get_by_kind(self, kind: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries of the given *kind*, newest first."""
        return self._fetch_entries(
            "WHERE mi.kind = ? ORDER BY mi.created_at DESC LIMIT ?", (kind, limit)
        )

    def delete(self, memory_id: str) -> bool:
        """Delete item *memory_id* and its embedding atomically.

        Returns True if an item row was removed, False if none matched.
        """
        with self._conn:
            # Explicit embedding delete keeps this correct even where foreign
            # key enforcement (and thus the cascade) is not active.
            self._conn.execute(
                "DELETE FROM memory_embeddings WHERE memory_id = ?", (memory_id,)
            )
            cursor = self._conn.execute(
                "DELETE FROM memory_items WHERE id = ?", (memory_id,)
            )
        return cursor.rowcount > 0

    def update_weight(self, memory_id: str, weight: float) -> bool:
        """Set the weight of *memory_id* and bump ``updated_at``.

        Returns True if a row was updated, False if the id was not found.
        """
        with self._conn:
            cursor = self._conn.execute(
                "UPDATE memory_items SET weight = ?, updated_at = ? WHERE id = ?",
                (weight, utc_now().isoformat(), memory_id),
            )
        return cursor.rowcount > 0

    def search_text(self, query: str, limit: int = 10) -> list[MemoryEntry]:
        """Return up to *limit* entries whose text contains *query*, newest first.

        *query* is matched literally: backslash, ``%`` and ``_`` are escaped so
        they are not interpreted as LIKE wildcards.
        """
        escaped = query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return self._fetch_entries(
            "WHERE mi.text LIKE ? ESCAPE '\\' ORDER BY mi.created_at DESC LIMIT ?",
            (f"%{escaped}%", limit),
        )

    def count(self) -> int:
        """Return the number of stored memory items."""
        row = self._conn.execute("SELECT COUNT(*) FROM memory_items").fetchone()
        return row[0] if row else 0

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def _fetch_entries(self, clause: str, params: Sequence[Any]) -> list[MemoryEntry]:
        """Run ``_ENTRY_SELECT`` + *clause* (internal SQL only) and map the rows."""
        rows = self._conn.execute(f"{self._ENTRY_SELECT} {clause}", params).fetchall()
        return [self._row_to_entry(row) for row in rows]

    def _row_to_entry(self, row: sqlite3.Row) -> MemoryEntry:
        """Convert a ``memory_items`` row (optionally joined) into a MemoryEntry.

        Embedding model/dimension come from the joined embedding columns when
        present and non-NULL; otherwise they default to ``""`` and ``0``.
        """
        raw_metadata = row["metadata_json"]
        metadata: dict[str, Any] = json.loads(raw_metadata) if raw_metadata else {}
        columns = row.keys()
        model = row["embedding_model"] if "embedding_model" in columns else None
        dim = row["embedding_dim"] if "embedding_dim" in columns else None
        return MemoryEntry(
            id=row["id"],
            text=row["text"],
            kind=row["kind"],
            source=row["source"],
            weight=row["weight"],
            task_id=row["task_id"],
            session_id=row["session_id"],
            metadata=metadata,
            created_at=datetime.fromisoformat(row["created_at"]),
            embedding_model=model if model is not None else "",
            embedding_dim=dim if dim is not None else 0,
        )