185 lines
6.7 KiB
Python
185 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Sequence
|
|
from uuid import uuid4
|
|
|
|
from app.core.contracts import MemoryEntry
|
|
|
|
|
|
def utc_now() -> datetime:
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    return datetime.now(timezone.utc)
|
|
|
|
|
|
class MemoryStore:
    """SQLite-backed store for memory entries and their embedding blobs.

    Items live in the ``memory_items`` table; the raw embedding bytes and the
    embedding model/dimension live in ``memory_embeddings``, keyed by the
    item id. A single connection is opened in ``__init__`` (with
    ``check_same_thread=False``) and reused by every method until ``close``.

    The store can also be used as a context manager, in which case the
    connection is closed on exit.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Open (creating if necessary) the database at *db_path*.

        The parent directory of *db_path* is created when missing, and the
        tables and indexes are created if they do not exist yet.
        """
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        # Rows are addressed by column name throughout the class.
        self._conn.row_factory = sqlite3.Row
        self._init_tables()

    def __enter__(self) -> "MemoryStore":
        """Return the store itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the underlying connection when leaving a ``with`` block."""
        self.close()

    def _init_tables(self) -> None:
        """Create the tables and indexes used by the store if they are absent."""
        self._conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS memory_items (
                id TEXT PRIMARY KEY,
                text TEXT NOT NULL,
                kind TEXT NOT NULL,
                source TEXT NOT NULL,
                weight REAL NOT NULL DEFAULT 0.5,
                task_id TEXT,
                session_id TEXT,
                metadata_json TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS memory_embeddings (
                memory_id TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                embedding_model TEXT NOT NULL,
                embedding_dim INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                FOREIGN KEY (memory_id) REFERENCES memory_items(id) ON DELETE CASCADE
            );

            CREATE INDEX IF NOT EXISTS idx_memory_items_task ON memory_items(task_id);
            CREATE INDEX IF NOT EXISTS idx_memory_items_session ON memory_items(session_id);
            CREATE INDEX IF NOT EXISTS idx_memory_items_kind ON memory_items(kind);
            CREATE INDEX IF NOT EXISTS idx_memory_embeddings_model ON memory_embeddings(embedding_model);
            """
        )
        self._conn.commit()

    def insert(self, entry: MemoryEntry, embedding: bytes) -> None:
        """Persist *entry* together with its *embedding* blob.

        The item row and the embedding row are written in one transaction:
        if either statement fails, neither row is kept and the exception
        propagates to the caller. Empty/falsy metadata is stored as NULL.
        """
        metadata_json = json.dumps(entry.metadata) if entry.metadata else None
        now = utc_now().isoformat()
        # ``with conn`` commits on success and rolls back on any exception,
        # so a failed embedding insert cannot leave an orphaned item behind.
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO memory_items (id, text, kind, source, weight, task_id, session_id, metadata_json, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    entry.id,
                    entry.text,
                    entry.kind,
                    entry.source,
                    entry.weight,
                    entry.task_id,
                    entry.session_id,
                    metadata_json,
                    entry.created_at.isoformat(),
                    now,
                ),
            )
            self._conn.execute(
                """
                INSERT INTO memory_embeddings (memory_id, embedding, embedding_model, embedding_dim, created_at)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    entry.id,
                    embedding,
                    entry.embedding_model,
                    entry.embedding_dim,
                    now,
                ),
            )

    def get(self, memory_id: str) -> MemoryEntry | None:
        """Return the entry with id *memory_id*, or ``None`` if there is none."""
        row = self._conn.execute(
            "SELECT * FROM memory_items WHERE id = ?", (memory_id,)
        ).fetchone()
        if row is None:
            return None
        return self._row_to_entry(row)

    def get_embedding(self, memory_id: str) -> bytes | None:
        """Return the stored embedding bytes for *memory_id*, or ``None``."""
        row = self._conn.execute(
            "SELECT embedding FROM memory_embeddings WHERE memory_id = ?",
            (memory_id,),
        ).fetchone()
        return bytes(row["embedding"]) if row is not None else None

    def get_all(self, limit: int = 1000) -> list[MemoryEntry]:
        """Return up to *limit* entries ordered by ``created_at`` descending."""
        return self._fetch_entries(
            "SELECT * FROM memory_items ORDER BY created_at DESC LIMIT ?",
            (limit,),
        )

    def get_by_task(self, task_id: str) -> list[MemoryEntry]:
        """Return every entry for *task_id*, ordered by ``created_at`` descending."""
        return self._fetch_entries(
            "SELECT * FROM memory_items WHERE task_id = ? ORDER BY created_at DESC",
            (task_id,),
        )

    def get_by_session(self, session_id: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries for *session_id*, newest ``created_at`` first."""
        return self._fetch_entries(
            "SELECT * FROM memory_items WHERE session_id = ? ORDER BY created_at DESC LIMIT ?",
            (session_id, limit),
        )

    def get_by_kind(self, kind: str, limit: int = 100) -> list[MemoryEntry]:
        """Return up to *limit* entries of the given *kind*, newest ``created_at`` first."""
        return self._fetch_entries(
            "SELECT * FROM memory_items WHERE kind = ? ORDER BY created_at DESC LIMIT ?",
            (kind, limit),
        )

    def delete(self, memory_id: str) -> bool:
        """Delete the entry *memory_id* and its embedding.

        Returns ``True`` when a row was removed from ``memory_items``.

        The embedding row is removed explicitly: the schema declares
        ``ON DELETE CASCADE``, but this class never enables
        ``PRAGMA foreign_keys`` on its connection, so the cascade is not
        relied upon here.
        """
        with self._conn:
            self._conn.execute(
                "DELETE FROM memory_embeddings WHERE memory_id = ?", (memory_id,)
            )
            removed = self._conn.execute(
                "DELETE FROM memory_items WHERE id = ?", (memory_id,)
            ).rowcount
        return removed > 0

    def update_weight(self, memory_id: str, weight: float) -> bool:
        """Set the weight of *memory_id* and refresh its ``updated_at``.

        Returns ``True`` when a matching row was updated.
        """
        with self._conn:
            updated = self._conn.execute(
                "UPDATE memory_items SET weight = ?, updated_at = ? WHERE id = ?",
                (weight, utc_now().isoformat(), memory_id),
            ).rowcount
        return updated > 0

    def search_text(self, query: str, limit: int = 10) -> list[MemoryEntry]:
        """Return up to *limit* entries whose text contains *query*.

        The query is matched as a literal substring via SQL ``LIKE``:
        ``%``, ``_`` and the escape character itself are escaped so that
        they are not interpreted as wildcards. Results are ordered by
        ``created_at`` descending.
        """
        # Escape the backslash first so the escapes added next are not doubled.
        literal = (
            query.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        return self._fetch_entries(
            "SELECT * FROM memory_items WHERE text LIKE ? ESCAPE '\\' "
            "ORDER BY created_at DESC LIMIT ?",
            (f"%{literal}%", limit),
        )

    def count(self) -> int:
        """Return the number of rows in ``memory_items``."""
        row = self._conn.execute("SELECT COUNT(*) FROM memory_items").fetchone()
        return row[0] if row else 0

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def _fetch_entries(self, sql: str, params: tuple) -> list[MemoryEntry]:
        """Run a ``memory_items`` SELECT and convert each row to a ``MemoryEntry``."""
        rows = self._conn.execute(sql, params).fetchall()
        return [self._row_to_entry(row) for row in rows]

    def _row_to_entry(self, row: sqlite3.Row) -> MemoryEntry:
        """Build a ``MemoryEntry`` from a ``memory_items`` row.

        A NULL/empty ``metadata_json`` column becomes an empty dict.

        NOTE(review): ``embedding_model`` and ``embedding_dim`` are filled
        with ``""`` and ``0`` because the item queries do not join
        ``memory_embeddings``; the values passed to ``insert`` are therefore
        not round-tripped through this method.
        """
        raw_metadata = row["metadata_json"]
        metadata = json.loads(raw_metadata) if raw_metadata else {}
        return MemoryEntry(
            id=row["id"],
            text=row["text"],
            kind=row["kind"],
            source=row["source"],
            weight=row["weight"],
            task_id=row["task_id"],
            session_id=row["session_id"],
            metadata=metadata,
            created_at=datetime.fromisoformat(row["created_at"]),
            embedding_model="",
            embedding_dim=0,
        )