149 lines
4.7 KiB
Python
149 lines
4.7 KiB
Python
from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path
from typing import Any

import hnswlib
import numpy as np
|
|
|
|
# Module-level logger; VectorIndex reports non-fatal failures through it as warnings.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VectorIndex:
    """Nearest-neighbour index over unit-normalised embeddings, backed by hnswlib.

    Vectors are L2-normalised before they are stored or queried, and the
    underlying ``hnswlib.Index`` uses the ``"l2"`` space. Caller-facing
    string ids are mapped to integer hnswlib labels with a process-independent
    hash; the reverse mapping (label -> memory id) is kept in memory and, when
    an ``index_path`` is configured, persisted in a JSON sidecar file next to
    the index so that searches keep returning the original ids after a reload.

    Most operations are best-effort: failures are logged as warnings rather
    than raised (``get_items`` is the exception and raises when there is no
    index).
    """

    # HNSW construction parameters used for every freshly created index.
    _EF_CONSTRUCTION = 200
    _M = 16
    # Labels are folded into [0, 2**31).
    _LABEL_SPACE = 2**31

    def __init__(
        self,
        index_path: str | Path | None = None,
        embedding_dim: int = 384,
        max_elements: int = 10000,
    ) -> None:
        """Create the index, loading it from ``index_path`` when that file exists.

        Args:
            index_path: Where the index is saved/loaded; ``None`` (or an empty
                string) keeps the index purely in memory.
            embedding_dim: Dimensionality of the stored vectors.
            max_elements: Initial capacity of the index. The index grows
                automatically when an insert would exceed it.
        """
        self._embedding_dim = embedding_dim
        self._index_path = Path(index_path) if index_path else None
        self._index: hnswlib.Index | None = None
        self._max_elements = max_elements
        # hnswlib label -> caller-supplied memory id.
        self._id_map: dict[int, str] = {}

        self._init_index()

    def _init_index(self) -> None:
        """(Re)initialise the index: load it from disk if possible, else create it."""
        self._index = None
        self._id_map = {}
        if self._index_path is not None and self._index_path.exists():
            self._load()
        if self._index is None:
            self._index = self._create_index()

    def _create_index(self) -> hnswlib.Index:
        """Return a new, empty index with the configured dimension and capacity."""
        index = hnswlib.Index(space="l2", dim=self._embedding_dim)
        index.init_index(
            max_elements=self._max_elements,
            ef_construction=self._EF_CONSTRUCTION,
            M=self._M,
        )
        return index

    def insert(self, memory_id: str, embedding: np.ndarray) -> None:
        """Add (or overwrite) the vector stored for ``memory_id``.

        The embedding is flattened and L2-normalised first. Failures are
        logged as warnings and the insert is skipped.
        """
        if self._index is None:
            self._init_index()
        if self._index is None:
            return

        try:
            vector = self._normalize(embedding)
            internal_id = self._get_internal_id(memory_id)
            labels = np.array([internal_id])
            try:
                self._index.add_items(vector, ids=labels)
            except RuntimeError:
                # Only treat the failure as "index full" when it really is
                # full; anything else is re-raised to the outer handler.
                capacity = self._index.get_max_elements()
                if self._index.get_current_count() < capacity:
                    raise
                self._index.resize_index(max(capacity * 2, 1))
                self._index.add_items(vector, ids=labels)
        except Exception as e:
            logger.warning("VectorIndex insert failed: %s", e)
            return

        previous = self._id_map.get(internal_id)
        if previous is not None and previous != memory_id:
            # Two different ids hashed to the same label: the newer vector
            # has replaced the older one inside hnswlib.
            logger.warning(
                "VectorIndex label collision: %r replaces %r", memory_id, previous
            )
        self._id_map[internal_id] = memory_id

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
    ) -> tuple[list[str], list[float]]:
        """Return up to ``k`` nearest memory ids and their scores.

        Returns:
            ``(memory_ids, scores)`` ordered as returned by hnswlib. Both
            lists are empty when the index is missing or empty, when ``k`` is
            not positive, or when the query fails (a warning is logged).
        """
        if self._index is None:
            return [], []

        try:
            # hnswlib raises if asked for more neighbours than it holds, so
            # never request more than the number of stored elements.
            k = min(k, self._index.get_current_count())
            if k <= 0:
                return [], []

            # Keep ef comfortably above k so the query can fill all k slots.
            self._index.set_ef(max(k * 2, 50))

            vector = self._normalize(query_embedding)
            labels, distances = self._index.knn_query(vector, k=k)

            memory_ids = [self._get_memory_id(int(label)) for label in labels[0]]
            # NOTE(review): hnswlib documents the "l2" space as *squared* L2
            # distance; for unit vectors that is 2 - 2*cos, so this score is
            # 2*cos - 1 rather than the cosine itself. Left unchanged because
            # callers may rely on the current values.
            scores = [float(1.0 - dist) for dist in distances[0]]
            return memory_ids, scores
        except Exception as e:
            logger.warning("VectorIndex search failed: %s", e)
            return [], []

    def delete(self, memory_id: str) -> bool:
        """Deletion is not implemented: nothing is removed and ``False`` is returned."""
        return False

    def get_items(self, memory_ids: list[str]) -> np.ndarray:
        """Return the stored vectors for ``memory_ids``, one row per id.

        Raises:
            RuntimeError: If the index has not been initialised.
        """
        if self._index is None:
            raise RuntimeError("Index not initialized")
        if not memory_ids:
            return np.empty((0, self._embedding_dim), dtype=np.float32)
        internal_ids = [self._get_internal_id(mid) for mid in memory_ids]
        return np.asarray(self._index.get_items(np.array(internal_ids)))

    def save(self) -> None:
        """Write the index and its id map to disk; a no-op without a path.

        Failures are logged as warnings, not raised.
        """
        if self._index is None or self._index_path is None:
            return
        try:
            self._index_path.parent.mkdir(parents=True, exist_ok=True)
            self._index.save_index(str(self._index_path))
            serialised = {str(label): mid for label, mid in self._id_map.items()}
            self._id_map_path().write_text(json.dumps(serialised), encoding="utf-8")
        except Exception as e:
            logger.warning("VectorIndex save failed: %s", e)

    def _load(self) -> None:
        """Load the index (and id map) from ``index_path`` if that file exists.

        On failure a warning is logged and a fresh, empty index is installed.
        """
        if self._index_path is None or not self._index_path.exists():
            return
        try:
            index = hnswlib.Index(space="l2", dim=self._embedding_dim)
            index.load_index(
                str(self._index_path),
                max_elements=self._max_elements,
            )
        except Exception as e:
            logger.warning("VectorIndex load failed: %s", e)
            self._index = self._create_index()
            self._id_map = {}
            return
        self._index = index
        self._id_map = self._read_id_map()

    def _id_map_path(self) -> Path:
        """Return the sidecar path holding the label -> memory id map."""
        return self._index_path.with_name(self._index_path.name + ".ids.json")

    def _read_id_map(self) -> dict[int, str]:
        """Read the sidecar id map; return an empty map if absent or unreadable."""
        path = self._id_map_path()
        if not path.exists():
            return {}
        try:
            raw = json.loads(path.read_text(encoding="utf-8"))
            return {int(label): str(mid) for label, mid in raw.items()}
        except (OSError, ValueError, TypeError, AttributeError) as e:
            logger.warning("VectorIndex id map load failed: %s", e)
            return {}

    def _normalize(self, vector: np.ndarray) -> np.ndarray:
        """Flatten ``vector``, scale it to unit L2 norm, and return shape ``(1, n)``.

        A zero vector is returned unscaled to avoid dividing by zero.
        """
        vec = np.asarray(vector).flatten()
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        return vec.reshape(1, -1)

    def _get_internal_id(self, memory_id: str) -> int:
        """Map ``memory_id`` to an hnswlib label in ``[0, 2**31)``.

        Uses a hashlib digest rather than built-in ``hash()``: string hashes
        are salted per process, which would make labels stored in a saved
        index unreachable after a restart.
        """
        digest = hashlib.blake2b(memory_id.encode("utf-8"), digest_size=8).digest()
        return int.from_bytes(digest, "big") % self._LABEL_SPACE

    def _get_memory_id(self, internal_id: int) -> str:
        """Map an hnswlib label back to its memory id.

        Labels with no recorded id (e.g. an index saved without a sidecar
        map) fall back to the label rendered as a string.
        """
        return self._id_map.get(internal_id, str(internal_id))

    @property
    def embedding_dim(self) -> int:
        """Dimensionality of the vectors held by this index."""
        return self._embedding_dim

    @property
    def element_count(self) -> int:
        """Number of elements reported by hnswlib, or 0 when there is no index."""
        return self._index.get_current_count() if self._index is not None else 0