Source code for pyseekdb.utils.embedding_functions.onnx_embedding_function

"""
ONNX-based embedding function implementation.

This module provides a generic ONNX embedding function that can run
sentence-transformer style models via onnxruntime.
"""

import contextlib
import logging
import os
from functools import cached_property
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

logger = logging.getLogger(__name__)

Documents = str | list[str]
Embeddings = list[list[float]]


class OnnxEmbeddingFunction:
    """
    Generic ONNX runtime embedding function.

    This class handles model download, tokenizer/model loading, and
    embedding generation using onnxruntime.
    """

    EXTRACTED_FOLDER_NAME = "onnx"
    ARCHIVE_FILENAME = "onnx.tar.gz"
    def __init__(
        self,
        model_name: str,
        hf_model_id: str,
        dimension: int,
        download_path: Path | None = None,
        preferred_providers: list[str] | None = None,
    ):
        """
        Initialize an ONNX embedding function.

        Args:
            model_name: Name of the model (used for cache directory naming).
            hf_model_id: Hugging Face model ID.
            dimension: Output embedding dimension.
            download_path: Optional cache path override.
            preferred_providers: Preferred ONNX runtime providers.
        """
        if not model_name:
            raise ValueError("model_name must be a non-empty string")
        if not hf_model_id:
            raise ValueError("hf_model_id must be a non-empty string")
        if dimension <= 0:
            raise ValueError("dimension must be a positive integer")

        self.model_name = model_name
        self.hf_model_id = hf_model_id
        self._dimension = dimension
        self.download_path = (
            download_path
            if download_path is not None
            else Path.home() / ".cache" / "pyseekdb" / "onnx_models" / model_name
        )

        # Validate preferred_providers
        if preferred_providers and not all(isinstance(i, str) for i in preferred_providers):
            raise ValueError("Preferred providers must be a list of strings")
        if preferred_providers and len(preferred_providers) != len(set(preferred_providers)):
            raise ValueError("Preferred providers must be unique")
        self._preferred_providers = preferred_providers

        # Import required modules lazily to avoid hard dependencies at import time
        import onnxruntime as ort_module
        import tokenizers
        import tqdm

        self.ort = ort_module
        self.tokenizers = tokenizers  # Store the module
        self.tqdm = tqdm.tqdm
    @property
    def dimension(self) -> int:
        """Get the dimension of embeddings produced by this function."""
        return self._dimension

    def _download(self, url: str, fname: str, chunk_size: int = 8192) -> None:
        """
        Download a file from the URL and save it to the file path.

        Args:
            url: The URL to download the file from.
            fname: The path to save the file to.
            chunk_size: The chunk size to use when downloading.
        """
        logger.info(f"Downloading from {url}")
        # Use Client to ensure correct handling of redirects
        import httpx

        with httpx.Client(timeout=600.0, follow_redirects=True) as client, client.stream("GET", url) as resp:
            resp.raise_for_status()
            total = int(resp.headers.get("content-length", 0))
            with (
                open(fname, "wb") as file,
                self.tqdm(
                    desc=os.path.basename(fname),
                    total=total,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar,
            ):
                for data in resp.iter_bytes(chunk_size=chunk_size):
                    size = file.write(data)
                    bar.update(size)

    def _get_hf_endpoint(self) -> str:
        """Get the Hugging Face endpoint URL from HF_ENDPOINT if set, defaulting to the hf-mirror.com mirror."""
        return os.environ.get("HF_ENDPOINT", "https://hf-mirror.com")

    def _download_from_huggingface(self) -> bool:  # noqa: C901
        """
        Download model files from Hugging Face (supports mirror acceleration).

        Returns:
            True if download successful, False otherwise.
        """
        try:
            hf_endpoint = self._get_hf_endpoint()
            # Remove trailing slash
            hf_endpoint = hf_endpoint.rstrip("/")

            # List of files to download.
            # ONNX model files live in the onnx/ subdirectory; the other files are in the root directory.
            files_to_download = {
                "onnx/model.onnx": "model.onnx",  # ONNX file in onnx subdirectory
                "tokenizer.json": "tokenizer.json",
                "config.json": "config.json",
                "special_tokens_map.json": "special_tokens_map.json",
                "tokenizer_config.json": "tokenizer_config.json",
                "vocab.txt": "vocab.txt",
            }

            extracted_folder = os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME)
            os.makedirs(extracted_folder, exist_ok=True)

            logger.info(f"Downloading model from Hugging Face (endpoint: {hf_endpoint})")
            import httpx

            # Download each file
            for hf_filename, local_filename in files_to_download.items():
                local_path = os.path.join(extracted_folder, local_filename)

                # Skip if the file already exists
                if os.path.exists(local_path):
                    continue

                # Construct the Hugging Face download URL
                url = f"{hf_endpoint}/{self.hf_model_id}/resolve/main/{hf_filename}"

                try:
                    # First check whether the file exists (HEAD request)
                    with contextlib.suppress(Exception):
                        head_resp = httpx.head(url, timeout=10.0, follow_redirects=True)
                        if head_resp.status_code == 404:
                            logger.warning(f"File {hf_filename} not found on Hugging Face (404), will try fallback")
                            return False

                    self._download(url, local_path, chunk_size=8192)
                    logger.info(f"Successfully downloaded {local_filename}")
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 404:
                        logger.warning(f"File {hf_filename} not found on Hugging Face (404), will try fallback")
                        return False
                    logger.warning(f"HTTP error downloading {hf_filename} from Hugging Face: {e}")
                    if os.path.exists(local_path):
                        os.remove(local_path)
                    return False
                except Exception as e:
                    logger.warning(f"Failed to download {hf_filename} from Hugging Face: {e}")
                    # If the download fails, try to delete the partially downloaded file
                    if os.path.exists(local_path):
                        os.remove(local_path)
                    return False

            # Verify that the critical files exist
            if not os.path.exists(os.path.join(extracted_folder, "model.onnx")):
                logger.error("model.onnx not found after download")
                return False
            if not os.path.exists(os.path.join(extracted_folder, "tokenizer.json")):
                logger.error("tokenizer.json not found after download")
                return False
download") return False logger.info("Successfully downloaded all model files from Hugging Face") return True # noqa: TRY300 except Exception: logger.exception("Error downloading from Hugging Face") return False def _forward(self, documents: list[str], batch_size: int = 32) -> npt.NDArray[np.float32]: """ Generate embeddings for a list of documents. Args: documents: The documents to generate embeddings for. batch_size: The batch size to use when generating embeddings. Returns: The embeddings for the documents. """ all_embeddings = [] for i in range(0, len(documents), batch_size): batch = documents[i : i + batch_size] # Encode each document separately encoded = [self.tokenizer.encode(d) for d in batch] # Check if any document exceeds the max tokens for doc_tokens in encoded: if len(doc_tokens.ids) > self.max_tokens(): raise ValueError( f"Document length {len(doc_tokens.ids)} is greater than the max tokens {self.max_tokens()}" ) # Create input arrays exactly like the working standalone script # Create input arrays, ensuring int64 type input_ids = np.array([e.ids for e in encoded], dtype=np.int64) attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64) # Ensure 2D arrays (batch_size, seq_length) if input_ids.ndim == 1: input_ids = input_ids.reshape(1, -1) if attention_mask.ndim == 1: attention_mask = attention_mask.reshape(1, -1) # Use zeros_like to create token_type_ids, ensuring exact shape match token_type_ids = np.zeros_like(input_ids, dtype=np.int64) # Ensure all arrays are contiguous, which is important for onnxruntime 1.19.0 input_ids = np.ascontiguousarray(input_ids, dtype=np.int64) attention_mask = np.ascontiguousarray(attention_mask, dtype=np.int64) token_type_ids = np.ascontiguousarray(token_type_ids, dtype=np.int64) onnx_input = { "input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, } model_output = self.model.run(None, onnx_input) last_hidden_state = model_output[0] # Mean pooling (exactly as in the code) # Note: attention_mask needs to be converted to float type for floating point operations attention_mask_float = attention_mask.astype(np.float32) input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask_float, -1), last_hidden_state.shape) embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip( input_mask_expanded.sum(1), a_min=1e-9, a_max=None ) embeddings = embeddings.astype(np.float32) all_embeddings.append(embeddings) return np.concatenate(all_embeddings) @cached_property def tokenizer(self) -> Any: """ Get the tokenizer for the model. Returns: The tokenizer for the model. """ tokenizer = self.tokenizers.Tokenizer.from_file( os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME, "tokenizer.json") ) # max_seq_length = 256, for some reason sentence-transformers uses 256 # even though the HF config has a max length of 128 tokenizer.enable_truncation(max_length=256) tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256) # noqa: S106 return tokenizer @cached_property def model(self) -> Any: """ Get the model. Returns: The model. 
""" if self._preferred_providers is None or len(self._preferred_providers) == 0: if len(self.ort.get_available_providers()) > 0: logger.debug( f"WARNING: No ONNX providers provided, defaulting to available providers: " f"{self.ort.get_available_providers()}" ) self._preferred_providers = self.ort.get_available_providers() elif not set(self._preferred_providers).issubset(set(self.ort.get_available_providers())): raise ValueError( f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}" ) # Create minimal session options to avoid issues so = self.ort.SessionOptions() so.log_severity_level = 3 # Disable all optimizations that might cause issues so.graph_optimization_level = self.ort.GraphOptimizationLevel.ORT_DISABLE_ALL so.execution_mode = self.ort.ExecutionMode.ORT_SEQUENTIAL so.inter_op_num_threads = 1 so.intra_op_num_threads = 1 if self._preferred_providers and "CoreMLExecutionProvider" in self._preferred_providers: # remove CoreMLExecutionProvider from the list, it is not as well optimized as CPU. self._preferred_providers.remove("CoreMLExecutionProvider") return self.ort.InferenceSession( os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME, "model.onnx"), # Force CPU execution provider to avoid provider issues providers=["CPUExecutionProvider"], sess_options=so, ) def _download_model_if_not_exists(self) -> None: """ Download from Hugging Face with image mirror if the model doesn't exist. """ onnx_files = [ "config.json", "model.onnx", "special_tokens_map.json", "tokenizer_config.json", "tokenizer.json", "vocab.txt", ] extracted_folder = os.path.join(self.download_path, self.EXTRACTED_FOLDER_NAME) onnx_files_exist = True for f in onnx_files: if not os.path.exists(os.path.join(extracted_folder, f)): onnx_files_exist = False break # Model is not downloaded yet if not onnx_files_exist: os.makedirs(self.download_path, exist_ok=True) logger.info("Attempting to download model from Hugging Face...") hf_endpoint = self._get_hf_endpoint() if not self._download_from_huggingface(): raise RuntimeError( f"Failed to download model from Hugging Face (endpoint: {hf_endpoint}). " f"Please check your network connection or set HF_ENDPOINT environment variable " f"to use a mirror site (e.g., export HF_ENDPOINT='https://hf-mirror.com'). " f"Model ID: {self.hf_model_id}" ) logger.info("Model downloaded successfully from Hugging Face")
    def max_tokens(self) -> int:
        """Get the maximum number of tokens supported by the model."""
        return 256
    def __call__(self, documents: Documents) -> Embeddings:
        """
        Generate embeddings for the given documents.

        Args:
            documents: Single document (str) or list of documents (list[str])

        Returns:
            List of embedding vectors
        """
        # Handle single string input
        if isinstance(documents, str):
            documents = [documents]

        # Handle empty input
        if not documents:
            return []

        # Only download the model when it is actually used
        self._download_model_if_not_exists()

        # Generate embeddings
        embeddings = self._forward(documents)

        # Convert numpy arrays to lists
        return [embedding.tolist() for embedding in embeddings]

    def __repr__(self) -> str:
        return f"OnnxEmbeddingFunction(model_name='{self.model_name}')"
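
# Minimal usage sketch. The model name, Hugging Face ID, and dimension below
# are illustrative assumptions, not values mandated by this module; any
# sentence-transformers-style model that publishes an ONNX export under
# onnx/model.onnx alongside a tokenizer.json should work the same way:
#
#     ef = OnnxEmbeddingFunction(
#         model_name="all-MiniLM-L6-v2",
#         hf_model_id="sentence-transformers/all-MiniLM-L6-v2",
#         dimension=384,
#     )
#     vectors = ef(["hello world", "embeddings via onnxruntime"])
#     assert len(vectors[0]) == ef.dimension
#
# The model files are downloaded lazily on the first call, not at
# construction time.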