Source code for tropt.model.voyage.encoder

from __future__ import annotations
from typing import List, Optional

import torch
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_random_exponential,
)

from tropt.common import ModelOutput
from tropt.model import EncoderBaseModel, LossTextAccessMixin


def _is_transient_voyage_error(e: BaseException) -> bool:
    """Decide whether an exception raised during a Voyage call is worth retrying.

    An exception is treated as transient when either:

    * its class name (exact match, no inheritance walk) is one of the known
      connection / timeout / rate-limit error names listed below, or
    * it carries an integer status in ``http_status``, ``status_code`` or
      ``code`` (the first truthy one wins) equal to 429 or in the 5xx range.

    Everything else (e.g. other 4xx statuses) is considered non-transient.

    Args:
        e: The exception to classify.

    Returns:
        True if the call should be retried, False otherwise.
    """
    retryable_names = frozenset(
        (
            "APIConnectionError",
            "Timeout",
            "ServiceUnavailableError",
            "RateLimitError",
            "ConnectionError",
            "ConnectionResetError",
            "ConnectionAbortedError",
            "ProtocolError",
            "ReadTimeout",
            "WriteTimeout",
            "ConnectTimeout",
            "ChunkedEncodingError",
        )
    )
    if type(e).__name__ in retryable_names:
        return True

    # Probe the status-carrying attributes in priority order and stop at the
    # first truthy value; if none is truthy, the last probed value is kept.
    status = None
    for attr in ("http_status", "status_code", "code"):
        status = getattr(e, attr, None)
        if status:
            break

    # Non-integer codes (e.g. string error codes) are never retried.
    if not isinstance(status, int):
        return False
    return status == 429 or 500 <= status < 600


def _log_voyage_retry(rs):
    """Print a one-line notice before a retry sleep.

    Used as the ``before_sleep`` hook of the retry decorator. The message
    shows the exception class name, the first 80 characters of its text,
    the upcoming sleep duration, and the attempt number.

    Args:
        rs: The retry state object; this function reads
            ``rs.outcome.exception()``, ``rs.next_action.sleep`` and
            ``rs.attempt_number``.
    """
    err = rs.outcome.exception()
    err_kind = type(err).__name__
    err_text = str(err)[:80]  # keep the log line short
    delay = rs.next_action.sleep
    line = (
        f"[voyage retry] {err_kind}: {err_text} "
        f"-> sleeping {delay:.1f}s (attempt {rs.attempt_number})"
    )
    # Flush immediately so the notice is visible while the process sleeps.
    print(line, flush=True)


class EncoderVoyageModel(EncoderBaseModel, LossTextAccessMixin):
    """
    Voyage AI Encoder model wrapper, with text-query access.

    https://docs.voyageai.com/docs/embeddings
    """

    def __init__(
        self,
        model_name: str = "voyage-4",
        d_model: int = 1024,
        **kwargs,
    ):
        """
        Initializes the Voyage Encoder Model wrapper.

        Args:
            model_name: The name of the Voyage embedding model to use.
            d_model: The dimensionality of the embeddings.
                For voyage-4 / voyage-3-large: supports 256, 512,
                1024 (default), 2048.
            **kwargs: Accepted for signature compatibility; not used here.

        Note: Requires `os.environ["VOYAGE_API_KEY"]` to be set externally.
        """
        # Import voyageai only when instantiating (optional dependency)
        import voyageai

        # NOTE(review): the base-class __init__ is not invoked here; confirm
        # that EncoderBaseModel / LossTextAccessMixin do not require it.
        self._client = voyageai.Client()
        self.model_name = model_name
        self._d_model = d_model
        # Maps this wrapper's text_type values to the `input_type` value
        # passed to the Voyage client.
        self._text_to_input_type = {
            "document": "document",
            "query": "query",
        }

    @property
    def d_model(self) -> int:
        """The embedding dimensionality requested from the API."""
        return self._d_model

    def invoke_from_texts(
        self,
        input_texts: List[str],
        text_type: Optional[str] = None,
        **kwargs,
    ) -> ModelOutput:
        """
        Generates embeddings for the given texts using the Voyage API.

        The texts are sent in chunks of at most 128 per request; each request
        is retried on transient errors (up to 6 attempts, randomized
        exponential back-off capped at 60s), and the last error is re-raised
        if all attempts fail.

        Args:
            input_texts: A list of strings to embed.
            text_type: The type of text (e.g., "document" or "query") to guide
                the embedding generation. None sends no input type.

        Returns:
            A ModelOutput containing the generated embeddings as a float32
            tensor with one row per input text. For an empty `input_texts`
            the tensor has shape (0, d_model) and no API call is made.

        Raises:
            AssertionError: If `text_type` is unsupported, or a response
                carries no embeddings.
        """
        assert text_type in (
            None,
            "document",
            "query",
        ), f"Unsupported text_type {text_type}"

        input_type = self._text_to_input_type.get(text_type) if text_type else None

        # Voyage's /embeddings caps at 128 texts per call on the newer models; chunk.
        MAX_BATCH = 128
        all_embeddings: list = []
        total_tokens = 0

        @retry(
            retry=retry_if_exception(_is_transient_voyage_error),
            wait=wait_random_exponential(multiplier=1.5, max=60),
            stop=stop_after_attempt(6),
            before_sleep=_log_voyage_retry,
            reraise=True,
        )
        def _embed_chunk(chunk):
            return self._client.embed(
                texts=chunk,
                model=self.model_name,
                input_type=input_type,
                output_dimension=self._d_model,
            )

        for start in range(0, len(input_texts), MAX_BATCH):
            chunk = input_texts[start : start + MAX_BATCH]
            response = _embed_chunk(chunk)
            assert response is not None and response.embeddings is not None, (
                "voyage embed returned no embeddings"
            )
            all_embeddings.extend(response.embeddings)
            # Treat a missing or None token count as zero so the running
            # total never fails on `int + None`.
            total_tokens += getattr(response, "total_tokens", 0) or 0

        if all_embeddings:
            result = torch.tensor(all_embeddings, dtype=torch.float32)
        else:
            # torch.tensor([]) would be 1-D with shape (0,); keep the result
            # 2-D so callers can rely on a (n_texts, d_model) layout.
            result = torch.empty((0, self._d_model), dtype=torch.float32)

        self._update_invoke_stats(
            n_tokens=total_tokens,
            n_samples=len(input_texts),
        )

        return ModelOutput(
            output_embeddings=result,
        )