from __future__ import annotations
import logging
from functools import cached_property
from typing import Annotated, List, Optional
import torch
from jaxtyping import Float
from sentence_transformers import SentenceTransformer
from torch import Tensor
from transformers import PreTrainedModel
from tropt.common import (
ModelOutput,
Targets,
TextTemplates,
)
from tropt.model import (
EncoderBaseModel,
GradientTokenAccessMixin,
LossTextAccessMixin,
LossTokenAccessMixin,
)
from tropt.model.huggingface.base import (
HuggingFaceBackendModel,
HuggingFaceTokenInputManager,
)
from tropt.model.model_mixins import GradientEmbedAccessMixin
logger = logging.getLogger(__name__)
# ======================= Model logic =======================
[docs]
# Encoder wrapper around a HuggingFace Sentence Transformer model.
# The order of the base classes below is deliberate: it determines the method
# resolution order (MRO), i.e. which base supplies an attribute when several
# define it.
class EncoderHFModel(
# HF backend first so its `device`/`dtype` win MRO over `BaseModel`'s defaults:
HuggingFaceBackendModel,
EncoderBaseModel,
# token-level access mixins:
LossTokenAccessMixin,
GradientTokenAccessMixin,
GradientEmbedAccessMixin,
# text-level access mixins:
LossTextAccessMixin,
):
def __init__(
    self,
    model_name: Optional[str] = None,
    device: Optional[str] = None,
    dtype: Optional[str | torch.dtype] = None,
    forward_pass_batch_size: int = 512,
    backward_pass_batch_size: int = 28,
    loaded_model: Optional[SentenceTransformer] = None,
    set_model_to_train: bool = False,
    run_additional_checks: bool = True,
    **kwargs,
):
    """
    Wrapper for HuggingFace Sentence Transformer Encoder Model.

    Either wraps an already-constructed ``SentenceTransformer`` (``loaded_model``)
    or loads one by name, then stores its tokenizer and input-embedding layer.

    Args:
        model_name (str): Name of the HuggingFace model. (irrelevant if `loaded_model` is provided)
        device (str): Device to load the model onto; passed to `SentenceTransformer`, which chooses a default when None. (irrelevant if `loaded_model` is provided)
        dtype (str or torch.dtype): Data type for the model. If None, `"auto"` is requested unless `model_kwargs` already specifies a dtype. (irrelevant if `loaded_model` is provided)
        forward_pass_batch_size (int): Batch size for forward passes.
        backward_pass_batch_size (int): Batch size for backward passes.
        loaded_model (SentenceTransformer, optional): Pre-loaded SentenceTransformer model.
        set_model_to_train (bool): Keep the model trainable (train mode + unfrozen weights). Default False (eval + frozen).
        run_additional_checks (bool): Whether to run additional checks on the model to ensure compliance with this module's computations. Can be disabled for faster initialization, but is good for identifying incompatibility issues (mostly with models overriding HF default code).
        **kwargs: Additional arguments for SentenceTransformer. May include a `model_kwargs` dict; it is merged with the dtype setting. (ignored if `loaded_model` is provided)

    Raises:
        AssertionError: If `loaded_model` is given but is not a `SentenceTransformer`.
        Exception: Any error raised while constructing the `SentenceTransformer` is logged and re-raised unchanged.
    """
    # NOTE(review): `forward_pass_batch_size`, `backward_pass_batch_size`,
    # `set_model_to_train` and `run_additional_checks` are not referenced in
    # this method, and no `super().__init__(...)` call is made here -- confirm
    # where (or whether) these settings are consumed.
    if loaded_model is not None:
        assert isinstance(loaded_model, SentenceTransformer), "loaded_model must be a SentenceTransformer instance."
        self._model = loaded_model
    else:
        # Merge any caller-supplied `model_kwargs` instead of passing a second
        # `model_kwargs=` keyword, which would fail with
        # "got multiple values for keyword argument 'model_kwargs'".
        model_kwargs = dict(kwargs.pop("model_kwargs", None) or {})
        if dtype:
            # An explicit `dtype` argument takes precedence.
            model_kwargs["dtype"] = dtype
        elif "dtype" not in model_kwargs and "torch_dtype" not in model_kwargs:
            # Same default as before: let the backend pick the dtype.
            model_kwargs["dtype"] = "auto"
        try:
            self._model = SentenceTransformer(
                model_name,
                device=device,
                model_kwargs=model_kwargs,
                **kwargs,
            )
        except Exception as e:
            logger.error(f"Error loading model `{model_name}`. Please make sure you load the model properly per the HuggingFace model card (e.g., you might need to pass `trust_remote_code=True` to `{self.__class__.__name__}`): {e}")
            # Bare `raise` re-raises the active exception with its traceback intact.
            raise

    # Add tokenizer and embedding layer:
    self._tokenizer = self._model.tokenizer
    self._embedding_layer = self._get_input_embeddings()
    logger.warning("[General Warning:] Common embedding models often require an instruction prefix (e.g., `query: `). For optimal performance, please make sure a suitable one is applied in the textual input templates.")
@cached_property
def _hf_model(self) -> PreTrainedModel:
    """Return the HuggingFace ``PreTrainedModel`` held by the wrapper's first module.

    The model is read from ``self._model._first_module().auto_model`` and the
    result is cached after the first access. It is meant for HF-specific
    introspection (``.config``, ``.dtype``, FLOP counting, the
    ``inputs_embeds`` probe). Any further modules of the
    ``SentenceTransformer`` (e.g., dense pooling) are assumed to be negligible
    and are not part of the returned model.

    Raises:
        ValueError: If an ``AttributeError`` occurs while reaching ``auto_model``.
        AssertionError: If the extracted object is not a ``PreTrainedModel``.
    """
    try:
        # Keep the whole attribute chain inside the `try`: a missing attribute
        # at any step is reported as the same ValueError.
        first_module = self._model._first_module()
        hf_model = first_module.auto_model
    except AttributeError as exc:
        message = (
            "Could not extract the inner HuggingFace model from the SentenceTransformer "
            f"wrapper for `{self._model_name}`. "
            "The first module is expected to be a `sentence_transformers.models.Transformer` "
        )
        raise ValueError(message) from exc

    assert isinstance(hf_model, PreTrainedModel), (
        "Expected the inner ST module to be a transformers.PreTrainedModel, "
        f"got {type(hf_model).__name__}."
    )
    return hf_model
@property
def d_model(self):
    """Sentence-embedding dimension reported by the wrapped model's
    ``get_sentence_embedding_dimension()``."""
    wrapped_model = self._model
    return wrapped_model.get_sentence_embedding_dimension()
def _get_input_embeddings(self) -> torch.nn.Module:
    """Locate the input (token) embedding layer of the inner HuggingFace model.

    Models differ in how they expose this layer, so several strategies are
    tried in order and the first one that yields a layer wins.

    Returns:
        The embedding layer object obtained from ``self._hf_model``.

    Raises:
        ValueError: If no strategy yields a layer. The error of the last
            failing strategy, if any, is attached as ``__cause__`` so the
            root cause is not lost.
    """
    # This is a somewhat hacky way to extract the embedding layer from
    # sentence transformers; as models may differ in implementation, multiple
    # methods are tried. Each helper either returns the layer or raises.
    def _via_hf_accessor():
        # Should work for most HF encoder models.
        return self._hf_model.get_input_embeddings()

    def _via_word_embeddings_attr():
        # Special case of NomicBertModel which lacks get_input_embeddings
        return self._hf_model.embeddings.word_embeddings

    last_error = None
    for extract in (_via_hf_accessor, _via_word_embeddings_attr):
        try:
            layer = extract()
        except Exception as err:
            # Deliberately broad (best-effort across heterogeneous model
            # code), but keep the error so it can be reported below.
            last_error = err
            continue
        if layer is not None:
            return layer
        # A `None` result is not a usable layer; fall through to the next strategy.

    raise ValueError(
        f"Could not extract embedding layer from Sentence Transformer model `{self._model_name}`. This model might need special care. Please report this issue."
    ) from last_error
# ----------------------- set_inputs_from_tokens -----------------------
# ----------------------- invoke_from_tokens -----------------------
[docs]
def invoke_from_tokens(
    self,
    input_embeds: Float[Tensor, "bsz seq_len d_model"],
    input_attention_mask: Optional[Float[Tensor, "bsz seq_len"]] = None,
    count_backward: bool = False,
    **kwargs
) -> ModelOutput:
    """Run a white-box forward pass of the wrapped model on input embeddings.

    Args:
        input_embeds: Input embeddings tensor (bsz, seq_len, d_model).
        input_attention_mask: Attention mask tensor (bsz, seq_len). When
            omitted, an all-ones int64 mask shaped like ``input_embeds``
            without its last dimension is created on the same device.
        count_backward: Whether this forward pass will be back-propagated
            through; forwarded to ``_update_invoke_stats``.
        **kwargs: Accepted but not used by this method.

    Returns:
        ModelOutput: Carries the model's ``"sentence_embedding"`` output as
        ``output_embeddings``.
    """
    assert input_embeds is not None, "input_embeds must be provided in invoke_from_tokens."

    attention_mask = input_attention_mask
    if attention_mask is None:
        # No mask given: attend to every position.
        attention_mask = torch.ones(
            input_embeds.shape[:-1],
            dtype=torch.int64,
            device=input_embeds.device,
        )

    features = {
        "inputs_embeds": input_embeds,  # (bsz, seq_len, embd_dim)
        "attention_mask": attention_mask,  # (bsz, seq_len)
    }
    model_out = self._model(features)

    # Bookkeeping happens only after the forward pass has succeeded.
    token_count = int(attention_mask.sum().item())
    sample_count = input_embeds.shape[0]
    self._update_invoke_stats(
        n_tokens=token_count,
        n_samples=sample_count,
        count_backward=count_backward,
    )

    sentence_emb = model_out["sentence_embedding"]  # (bsz, d_model)
    return ModelOutput(output_embeddings=sentence_emb)
# ----------------------- invoke_from_texts -----------------------
[docs]
def invoke_from_texts(
    self,
    input_texts: Annotated[List[str], "n_texts"],
    **kwargs,
) -> ModelOutput:
    """
    Embed the given texts (n_texts elements) with the wrapped model's ``encode``.

    Note: any prompting/instruction is mostly assumed to have been applied to
    the texts before this function is called.

    Args:
        input_texts: List of texts to embed; must be a ``list``.
        **kwargs: Accepted but not used by this method.

    Returns:
        ModelOutput: Carries the tensor returned by ``encode`` as
        ``output_embeddings``.
    """
    assert isinstance(input_texts, list)

    embeddings = self._model.encode(
        input_texts,
        convert_to_tensor=True,
        show_progress_bar=False,
    )

    # NOTE(review): the tokenizer is called here without truncation arguments,
    # so this count may differ from the number of tokens `encode` actually
    # processed for over-long texts -- confirm this is the intended statistic.
    ids_per_text = self._tokenizer(input_texts)["input_ids"]
    total_tokens = sum(map(len, ids_per_text))
    self._update_invoke_stats(n_tokens=total_tokens, n_samples=len(input_texts))

    return ModelOutput(output_embeddings=embeddings)