from __future__ import annotations
import logging
from typing import Dict, List, Optional
import torch
from jaxtyping import Float, Int
from torch import Tensor
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tropt.common import (
ModelOutput,
Targets,
TextTemplates,
)
from tropt.model import (
ClassifierBaseModel,
GradientTokenAccessMixin,
LossTextAccessMixin,
LossTokenAccessMixin,
)
from tropt.model.huggingface.base import (
HuggingFaceBackendModel,
HuggingFaceTokenInputManager,
)
from tropt.model.model_mixins import GradientEmbedAccessMixin
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
class ClassifierHFModel(
    # HF backend first so its `device`/`dtype` win MRO over `BaseModel`'s defaults:
    HuggingFaceBackendModel,
    ClassifierBaseModel,
    # token-level access mixins:
    LossTokenAccessMixin,
    GradientTokenAccessMixin,
    GradientEmbedAccessMixin,
    # text-level access mixins:
    LossTextAccessMixin,
):
    """HuggingFace sequence classification model wrapper (for models loadable with `AutoModelForSequenceClassification`).

    The wrapper holds three pieces of state, all set in `__init__`:

    * ``_model`` -- the sequence-classification model (either loaded from
      ``model_name`` or supplied by the caller via ``loaded_model``),
    * ``_tokenizer`` -- the matching tokenizer,
    * ``_embedding_layer`` -- the model's input-embedding module.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        dtype: Optional[str | torch.dtype] = None,
        forward_pass_batch_size: int = 512,
        backward_pass_batch_size: int = 28,
        loaded_model=None,
        set_model_to_train: bool = False,
        **kwargs,
    ):
        """Load (or adopt) the classifier and its tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer. May be omitted when ``loaded_model``
                is given and records where it was loaded from.
            device: Device the freshly loaded model is moved to. Defaults to
                ``"cuda"`` when available, otherwise ``"cpu"``. Not applied to
                ``loaded_model``, which is used as passed in.
            dtype: Forwarded to ``from_pretrained``; ``"auto"`` when falsy.
            forward_pass_batch_size: Accepted for interface compatibility.
            backward_pass_batch_size: Accepted for interface compatibility.
            loaded_model: An already-instantiated model to wrap instead of
                loading one from ``model_name``.
            set_model_to_train: When true, the wrapped model is switched to
                training mode via ``.train()``.
            **kwargs: Extra keyword arguments forwarded to
                ``AutoModelForSequenceClassification.from_pretrained``
                (unused when ``loaded_model`` is given).

        Raises:
            ValueError: If neither ``model_name`` nor ``loaded_model`` is
                given, or if no tokenizer identifier can be determined.
        """
        # NOTE(review): `forward_pass_batch_size` / `backward_pass_batch_size`
        # are not consumed in this constructor and no base-class `__init__` is
        # invoked here -- confirm where they are expected to be stored.
        if loaded_model is None and model_name is None:
            raise ValueError(
                "Either `model_name` or `loaded_model` must be provided."
            )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        if loaded_model is not None:
            # Caller-owned model: used as-is (no device/dtype changes).
            self._model = loaded_model
        else:
            self._model = AutoModelForSequenceClassification.from_pretrained(
                model_name,
                dtype=dtype or "auto",
                **kwargs,
            ).to(device)

        # Honour the flag; previously it was accepted but had no effect.
        # The default (False) leaves the model's current mode untouched.
        if set_model_to_train:
            self._model.train()

        # Set tokenizer and embedding layer:
        tokenizer_source = model_name
        if tokenizer_source is None:
            # Only reachable with `loaded_model`: reuse the identifier stored on
            # its config, since `from_pretrained(None)` cannot load a tokenizer.
            model_config = getattr(self._model, "config", None)
            tokenizer_source = getattr(model_config, "name_or_path", None) or None
            if tokenizer_source is None:
                raise ValueError(
                    "Cannot determine which tokenizer to load: pass `model_name` "
                    "alongside `loaded_model`."
                )
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
        self._embedding_layer = self._model.get_input_embeddings()

    @property
    def n_classes(self) -> int:
        """Number of output classes, read from the model config's ``num_labels``."""
        return int(self._model.config.num_labels)

    @property
    def id2label(self) -> Dict[int, str]:
        """Mapping from class index to label name, taken from the model config.

        Raises:
            AssertionError: If the config has no ``id2label`` mapping, or if
                any key is not an ``int`` or any value is not a ``str``.
        """
        assert self._model.config.id2label is not None
        for k, v in self._model.config.id2label.items():
            assert isinstance(k, int) and isinstance(v, str)
        return self._model.config.id2label  # type: ignore[return-value] (verified by assertions)

    # ----------------------- set_inputs_from_tokens -----------------------
    # (no override visible in this class)

    # ----------------------- invoke_from_tokens -----------------------
    def invoke_from_tokens(
        self,
        input_embeds: Float[Tensor, "bsz seq_len d_model"],
        input_attention_mask: Optional[Int[Tensor, "bsz seq_len"]] = None,
        count_backward: bool = False,
        **kwargs,
    ) -> ModelOutput:
        """Run the classifier on input embeddings.

        Args:
            input_embeds: Embeddings fed to the model as ``inputs_embeds``.
            input_attention_mask: Attention mask for ``input_embeds``. When
                omitted, an all-ones ``int64`` mask covering the first two
                dimensions of ``input_embeds`` is created on the same device.
            count_backward: Forwarded to ``_update_invoke_stats``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ModelOutput`` whose ``output_class_logits`` are the model's
            logits.
        """
        assert input_embeds is not None

        if input_attention_mask is None:
            # No mask supplied: attend to every position.
            input_attention_mask = torch.ones(
                input_embeds.shape[:2], device=input_embeds.device, dtype=torch.int64
            )

        outputs = self._model(
            inputs_embeds=input_embeds,
            attention_mask=input_attention_mask,
        )

        # Token count is the number of attended (mask == 1) positions.
        self._update_invoke_stats(
            n_tokens=int(input_attention_mask.sum().item()),
            n_samples=input_embeds.shape[0],
            count_backward=count_backward,
        )

        return ModelOutput(
            output_class_logits=outputs.logits,  # (bsz, n_classes)
        )

    # ----------------------- invoke_from_texts -----------------------
    def invoke_from_texts(
        self,
        input_texts: List[str],
        **kwargs,
    ) -> ModelOutput:
        """Tokenize raw texts and run the classifier on them.

        Args:
            input_texts: List of strings to classify. They are tokenized as
                one padded, truncated batch and moved to ``self.device``.
            **kwargs: Accepted and ignored.

        Returns:
            A ``ModelOutput`` whose ``output_class_logits`` are the model's
            logits.

        Raises:
            AssertionError: If ``input_texts`` is not a list of strings.
        """
        assert isinstance(input_texts, list) and all(isinstance(t, str) for t in input_texts)

        inputs = self._tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        outputs = self._model(**inputs)

        # Padding positions are excluded from the token count via the mask.
        self._update_invoke_stats(
            n_tokens=int(inputs["attention_mask"].sum().item()),
            n_samples=len(input_texts),
        )

        return ModelOutput(
            output_class_logits=outputs.logits,  # (bsz, n_classes)
        )