from __future__ import annotations import logging import numpy as np import hnswlib from pathlib import Path from typing import Any logger = logging.getLogger(__name__) class VectorIndex: def __init__( self, index_path: str | Path | None = None, embedding_dim: int = 384, max_elements: int = 10000, ) -> None: self._embedding_dim = embedding_dim self._index_path = Path(index_path) if index_path else None self._index: hnswlib.Index | None = None self._max_elements = max_elements self._loading = False # Prevent recursion self._init_index() def _init_index(self) -> None: if self._loading: return self._loading = True try: if self._index_path and self._index_path.exists(): self._load() else: self._index = hnswlib.Index( space="l2", dim=self._embedding_dim, ) self._index.init_index( max_elements=self._max_elements, ef_construction=200, M=16, ) except Exception as e: logger.warning(f"VectorIndex init failed: {e}") self._index = hnswlib.Index( space="l2", dim=self._embedding_dim, ) self._index.init_index( max_elements=self._max_elements, ef_construction=100, M=16, ) finally: self._loading = False def insert(self, memory_id: str, embedding: np.ndarray) -> None: if self._index is None: self._init_index() if self._index is None: return try: vector = self._normalize(embedding) internal_id = self._get_internal_id(memory_id) self._index.add_items(vector, ids=np.array([internal_id])) except Exception as e: logger.warning(f"VectorIndex insert failed: {e}") def search( self, query_embedding: np.ndarray, k: int = 5, ) -> tuple[list[str], list[float]]: if self._index is None: return [], [] try: if self._index.get_current_count() == 0: return [], [] # Set ef to at least k for proper search self._index.set_ef(max(k * 2, 50)) vector = self._normalize(query_embedding) labels, distances = self._index.knn_query(vector, k=k) memory_ids = [self._get_memory_id(int(label)) for label in labels[0]] scores = [1.0 - dist for dist in distances[0]] return memory_ids, scores except Exception as e: logger.warning(f"VectorIndex search failed: {e}") return [], [] def delete(self, memory_id: str) -> bool: return False def get_items(self, memory_ids: list[str]) -> np.ndarray: if self._index is None: raise RuntimeError("Index not initialized") internal_ids = [self._get_internal_id(mid) for mid in memory_ids] return self._index.get_items(np.array(internal_ids)) def save(self) -> None: if self._index and self._index_path: try: self._index_path.parent.mkdir(parents=True, exist_ok=True) self._index.save_index(str(self._index_path)) except Exception as e: logger.warning(f"VectorIndex save failed: {e}") def _load(self) -> None: if self._loading: return self._loading = True try: if self._index_path and self._index_path.exists(): self._index = hnswlib.Index(space="l2", dim=self._embedding_dim) self._index.load_index( str(self._index_path), max_elements=self._max_elements ) except Exception as e: logger.warning(f"VectorIndex load failed: {e}") self._init_index() finally: self._loading = False def _normalize(self, vector: np.ndarray) -> np.ndarray: vec = vector.flatten() norm = np.linalg.norm(vec) if norm > 0: vec = vec / norm return vec.reshape(1, -1) def _get_internal_id(self, memory_id: str) -> int: return hash(memory_id) % (2**31) def _get_memory_id(self, internal_id: int) -> str: return str(internal_id) @property def embedding_dim(self) -> int: return self._embedding_dim @property def element_count(self) -> int: return self._index.get_current_count() if self._index else 0