ducklm/app/memory/vector_index.py

149 lines
4.7 KiB
Python

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

import hnswlib
import numpy as np

logger = logging.getLogger(__name__)
class VectorIndex:
def __init__(
self,
index_path: str | Path | None = None,
embedding_dim: int = 384,
max_elements: int = 10000,
) -> None:
self._embedding_dim = embedding_dim
self._index_path = Path(index_path) if index_path else None
self._index: hnswlib.Index | None = None
self._max_elements = max_elements
self._loading = False # Prevent recursion
self._init_index()
def _init_index(self) -> None:
if self._loading:
return
self._loading = True
try:
if self._index_path and self._index_path.exists():
self._load()
else:
self._index = hnswlib.Index(
space="l2",
dim=self._embedding_dim,
)
self._index.init_index(
max_elements=self._max_elements,
ef_construction=200,
M=16,
)
except Exception as e:
logger.warning(f"VectorIndex init failed: {e}")
self._index = hnswlib.Index(
space="l2",
dim=self._embedding_dim,
)
self._index.init_index(
max_elements=self._max_elements,
ef_construction=100,
M=16,
)
finally:
self._loading = False
def insert(self, memory_id: str, embedding: np.ndarray) -> None:
if self._index is None:
self._init_index()
if self._index is None:
return
try:
vector = self._normalize(embedding)
internal_id = self._get_internal_id(memory_id)
self._index.add_items(vector, ids=np.array([internal_id]))
except Exception as e:
logger.warning(f"VectorIndex insert failed: {e}")
def search(
self,
query_embedding: np.ndarray,
k: int = 5,
) -> tuple[list[str], list[float]]:
if self._index is None:
return [], []
try:
if self._index.get_current_count() == 0:
return [], []
# Set ef to at least k for proper search
self._index.set_ef(max(k * 2, 50))
vector = self._normalize(query_embedding)
labels, distances = self._index.knn_query(vector, k=k)
memory_ids = [self._get_memory_id(int(label)) for label in labels[0]]
scores = [1.0 - dist for dist in distances[0]]
return memory_ids, scores
except Exception as e:
logger.warning(f"VectorIndex search failed: {e}")
return [], []
def delete(self, memory_id: str) -> bool:
return False
def get_items(self, memory_ids: list[str]) -> np.ndarray:
if self._index is None:
raise RuntimeError("Index not initialized")
internal_ids = [self._get_internal_id(mid) for mid in memory_ids]
return self._index.get_items(np.array(internal_ids))
def save(self) -> None:
if self._index and self._index_path:
try:
self._index_path.parent.mkdir(parents=True, exist_ok=True)
self._index.save_index(str(self._index_path))
except Exception as e:
logger.warning(f"VectorIndex save failed: {e}")
def _load(self) -> None:
if self._loading:
return
self._loading = True
try:
if self._index_path and self._index_path.exists():
self._index = hnswlib.Index(space="l2", dim=self._embedding_dim)
self._index.load_index(
str(self._index_path),
max_elements=self._max_elements
)
except Exception as e:
logger.warning(f"VectorIndex load failed: {e}")
self._init_index()
finally:
self._loading = False
def _normalize(self, vector: np.ndarray) -> np.ndarray:
vec = vector.flatten()
norm = np.linalg.norm(vec)
if norm > 0:
vec = vec / norm
return vec.reshape(1, -1)
def _get_internal_id(self, memory_id: str) -> int:
return hash(memory_id) % (2**31)
def _get_memory_id(self, internal_id: int) -> str:
return str(internal_id)
@property
def embedding_dim(self) -> int:
return self._embedding_dim
@property
def element_count(self) -> int:
return self._index.get_current_count() if self._index else 0