37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any

import numpy as np
from sentence_transformers import SentenceTransformer


class EmbeddingsAdapter:
    """Wrap a ``SentenceTransformer`` model and return embeddings as NumPy arrays.

    The model is loaded from ``model_path`` when that path exists on disk;
    otherwise ``model_name`` is handed to ``SentenceTransformer`` to resolve.
    """

    def __init__(
        self,
        model_path: str | Path | None = None,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embedding_dim: int = 384,
    ) -> None:
        """Load the embedding model.

        Args:
            model_path: Optional local model location. Used only if it exists;
                a non-existent path emits a warning and falls back to
                ``model_name``.
            model_name: Model identifier passed to ``SentenceTransformer`` when
                ``model_path`` is not usable.
            embedding_dim: Fallback vector width, used only when the loaded
                model does not report its own output dimension. The default
                (384) matches the default ``model_name``.
        """
        if model_path and Path(model_path).exists():
            self._model = SentenceTransformer(str(model_path))
        else:
            if model_path:
                # Keep the best-effort fallback, but don't hide that the
                # requested local model was not found.
                warnings.warn(
                    f"model_path {str(model_path)!r} does not exist; "
                    f"loading {model_name!r} instead.",
                    stacklevel=2,
                )
            self._model = SentenceTransformer(model_name)
        # Derive the width from the model actually loaded so the
        # ``embedding_dim`` property always matches ``encode``'s output.
        self._embedding_dim = self._resolve_embedding_dim(self._model, embedding_dim)

    @staticmethod
    def _resolve_embedding_dim(model: Any, configured: int) -> int:
        """Return the model's reported output width, falling back to ``configured``.

        Warns when both values are known and disagree, because the configured
        value would otherwise misdescribe the arrays the model produces.
        """
        getter = getattr(model, "get_sentence_embedding_dimension", None)
        reported = getter() if callable(getter) else None
        if reported is None:
            return configured
        if reported != configured:
            warnings.warn(
                f"embedding_dim={configured} does not match the model's output "
                f"dimension ({reported}); using {reported}.",
                stacklevel=3,
            )
        return int(reported)

    def encode(self, texts: str | list[str], batch_size: int = 32) -> np.ndarray:
        """Embed a single string or a list of strings.

        Args:
            texts: One string, or a list of strings.
            batch_size: Forwarded to the model. 32 is also
                ``SentenceTransformer.encode``'s default.

        Returns:
            A 1-D vector for a single string. Otherwise a 2-D array with one row
            per input; an empty list yields shape ``(0, embedding_dim)``.
        """
        if isinstance(texts, str):
            # The model always receives a batch; unwrap the single row.
            return self.encode_batch([texts], batch_size=batch_size)[0]
        return self.encode_batch(texts, batch_size=batch_size)

    def encode_batch(self, texts: list[str], batch_size: int = 32) -> np.ndarray:
        """Embed a list of strings, processing ``batch_size`` items at a time.

        Returns a 2-D array with one row per input. An empty input returns an
        empty ``(0, embedding_dim)`` float32 array without calling the model.
        """
        if not isinstance(texts, str) and len(texts) == 0:
            # The model may return a 1-D shape-(0,) array for an empty batch;
            # return an explicit 2-D result so callers can rely on the layout.
            return np.empty((0, self._embedding_dim), dtype=np.float32)
        return self._model.encode(texts, batch_size=batch_size, convert_to_numpy=True)

    @property
    def embedding_dim(self) -> int:
        """Width of the vectors produced by ``encode`` / ``encode_batch``."""
        return self._embedding_dim