Source code for tropt.model.model_base

"""
Base definitions, classes, and mixins for targeted text models.
"""

# The docstring must be the first statement for it to become the module's
# ``__doc__``; a ``from __future__`` import is still legal right after it.
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor
from transformers import BatchEncoding

from tropt.common import MessageTargets, ModelOutput
from tropt.model.flop_counter import FlopCounterBase, ManualFlopCounter

# ====================== Model Base Classes =======================


class BaseModel(ABC):
    """Abstract base class for models.

    Bundles the default device lookup, batch-size hints, optional FLOP
    counting, and the usage-statistics accumulators. Subclasses must
    implement :meth:`__call__`.
    """

    def __init__(self, model_name: str):
        """Initialize the base model.

        Args:
            model_name: Model identifier. The base class does not store it;
                :meth:`get_model_name` reads ``_model_name`` / ``model_name``
                when either attribute has been set on the instance.
        """
        pass

    @property
    def device(self):
        """
        Defaults to available accelerator.
        Should be overridden if the model is local, and on a specific device.
        """
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """
        Forward pass through the model, returns model default output (e.g., text response for LMs).
        - This method should also update the usage stats (e.g., token counts, forward call counts, etc.)
        """
        raise NotImplementedError

    # ... prepare inputs methods will be added upon expansion ...
    # ... compute loss/grad/... methods will be added upon expansion ...
    # ... usage stats methods will be added upon expansion ...

    def get_model_name(self) -> str:
        """Returns the model identifier string.

        Lookup order: ``_model_name``, then ``model_name``, then the class name.
        """
        return getattr(self, "_model_name", getattr(self, "model_name", type(self).__name__))

    _model: Optional[Any] = None
    """The underlying model object. Set by subclass ``__init__``.

    For HuggingFace models this is a ``PreTrainedModel``; for black-box
    models it may be ``None``.
    """

    _forward_pass_batch_size: int = 4096
    """Starting batch size for forward-pass compute methods (loss / logits / ... ).

    Used as ``starting_batch_size`` for ``find_executable_batch_size`` in
    ``compute_loss_from_*`` methods. Intentionally set high -- on GPU backends
    it is automatically halved on OOM, and for API backends it simply caps the
    chunk size sent to ``invoke_from_texts`` per step.
    Any subclass may override (at the class level or in ``__init__``).
    """

    _backward_pass_batch_size: int = 512
    """Starting batch size for backward-pass compute methods (gradients).

    Used as ``starting_batch_size`` for ``find_executable_batch_size`` in
    ``compute_grad_from_*`` methods. Intentionally set high -- automatically
    halved on OOM. Any subclass may override.
    """

    _flop_counter: Optional[FlopCounterBase] = None
    """Active counter object (set by :meth:`set_flop_counting`).

    Must implement ``count_forward(n_tokens) -> int`` and
    ``count_forward_backward(n_tokens) -> int``.
    """

    def set_flop_counting(self, mode: Literal["manual", "none"] = "manual"):
        """Enable or disable FLOP counting.

        Args:
            mode: The method to use for counting FLOPs. Options:
                -> "manual": Uses a `ManualFlopCounter` that estimates FLOPs based on
                    token counts and model architecture (follows Kaplan et al. 2020).
                    Requires the model to expose an inner HuggingFace ``PreTrainedModel``
                    (via ``_hf_model`` on HF backends). This option is the default.
                -> "none": Disables FLOP counting.
                Any other value also disables counting (as before), but now
                emits a warning so that typos do not go unnoticed.

        Raises:
            AttributeError: If ``mode == "manual"`` and the model does not
                expose ``_hf_model``.

        Note:
            - FLOP counting will appear in :meth:`get_usage_stats` under ``"usage/total_flops"``.
            - Since the optimizer logs all entries under `model.get_usage_stats()`,
              this makes self.log() automatically include FLOP counts in all optimizer logs,
              without needing to explicitly log it in each optimizer method.
            - FLOPs are counted at the model `invoke_from_tokens` / `invoke_from_texts`
              level only -- the cost of optimizer-internal and loss-internal computation
              (e.g. candidate sampling, sorting, Gumbel draws) is knowingly excluded.
        """
        if mode == "manual":
            try:
                hf_model = self._hf_model
            except AttributeError as err:
                # Same exception type as before, but with an actionable message.
                raise AttributeError(
                    f"{type(self).__name__} does not expose `_hf_model`, which is required "
                    "for mode='manual'; call set_flop_counting('none') for this backend."
                ) from err
            self._flop_counter = ManualFlopCounter(hf_model)
            return

        if mode != "none":
            # Local import keeps this block self-contained.
            import warnings

            warnings.warn(
                f"Unknown FLOP counting mode {mode!r}; expected 'manual' or 'none'. "
                "FLOP counting is disabled.",
                stacklevel=2,
            )
        self._flop_counter = None

    def get_usage_stats(self) -> Dict[str, int]:
        """Returns summary of model usage statistics for logging.

        Counters that were never updated are reported as 0. ``"total_flops"``
        is included only while a FLOP counter is active.
        """
        stats: dict[str, Any | int] = {
            "total_tokens": getattr(self, "_token_used", 0),
            "forward_calls": getattr(self, "_forward_call_count", 0),
            "forward_samples": getattr(self, "_forward_sample_count", 0),
            "grad_calls": getattr(self, "_grad_call_count", 0),
            "grad_samples": getattr(self, "_grad_sample_count", 0),
        }
        if self._flop_counter is not None:
            stats["total_flops"] = getattr(self, "_total_flops", 0)
        return stats

    def _update_invoke_stats(
        self,
        *,
        n_tokens: int,
        n_samples: int,
        count_backward: bool = False,
    ):
        """Record usage stats and FLOPs for a single model invocation.

        Call this inside ``invoke_from_tokens`` / ``invoke_from_texts`` after
        each *raw* model call (!), to avoid double-counting. It handles
        forward/backward counters **and** FLOP computation.

        Args:
            n_tokens: Total tokens processed in this invocation.
            n_samples: Batch size (number of sequences).
            count_backward: Whether this forward pass will also be
                back-propagated through (set by gradient methods).
        """
        flops = 0
        if self._flop_counter is not None:
            if count_backward:
                flops = self._flop_counter.count_forward_backward(n_tokens)
            else:
                flops = self._flop_counter.count_forward(n_tokens)

        self._update_usage_stats(
            tokens=n_tokens,
            forward_calls=1,
            forward_samples=n_samples,
            flops=flops,
            grad_calls=1 if count_backward else 0,
            grad_samples=n_samples if count_backward else 0,
        )

    def _update_usage_stats(
        self,
        tokens: int = 0,
        forward_calls: int = 0,
        forward_samples: int = 0,
        grad_calls: int = 0,
        grad_samples: int = 0,
        flops: int = 0,
    ):
        """Low-level accumulator. Prefer :meth:`_update_invoke_stats`.

        Each counter is read with a default of 0, mirroring
        :meth:`get_usage_stats`, so an instance on which only some counters
        exist is handled without raising ``AttributeError``.
        """
        self._token_used = getattr(self, "_token_used", 0) + tokens
        self._forward_call_count = getattr(self, "_forward_call_count", 0) + forward_calls
        self._forward_sample_count = getattr(self, "_forward_sample_count", 0) + forward_samples
        self._grad_call_count = getattr(self, "_grad_call_count", 0) + grad_calls
        self._grad_sample_count = getattr(self, "_grad_sample_count", 0) + grad_samples
        self._total_flops = getattr(self, "_total_flops", 0) + flops

    def reset_usage_stats(self):
        """Resets the usage statistics."""
        self._token_used = 0
        self._forward_call_count = 0
        self._forward_sample_count = 0
        self._grad_call_count = 0
        self._grad_sample_count = 0
        self._total_flops = 0
## -------- Base models by model type ------- ##
class LMBaseModel(BaseModel):
    """Language model base class."""

    def __call__(
        self,
        input_texts: List[str],
        return_full_output: bool = False,
        **kwargs
    ) -> List[str] | ModelOutput:
        """
        Produce text completions for a batch of prompts.

        Delegates to :meth:`invoke_from_texts` with ``require_generation=True``
        and forwards any extra keyword arguments.

        Args:
            input_texts (List[str]): Prompt strings to complete.
            return_full_output (bool): When True, the whole ModelOutput is returned;
                otherwise only its ``generated_response_strs``.
        """
        output: ModelOutput = self.invoke_from_texts(
            input_texts=input_texts,
            require_generation=True,
            **kwargs,
        )
        if not return_full_output:
            responses = output.generated_response_strs
            assert responses is not None, "Model did not generate response strings"
            return responses
        return output

    @abstractmethod
    def invoke_from_texts(
        self,
        input_texts: List[str],
        message_targets: Optional[MessageTargets] = None,
        require_target_prefill: bool = False,
        require_generation: bool = False,
        **kwargs
    ) -> ModelOutput:
        """
        Run the language model on the given input texts.

        Args:
            input_texts (List[str]): List of input strings.
            message_targets (Optional[MessageTargets]): Targets for the messages.
            require_target_prefill (bool): Whether to prefill the target response from
                `message_targets`, and return the corresponding logits (e.g., for LMs).
            require_generation (bool): Whether to perform autoregressive generation
                after the forward pass (for LMs).

        Implementations always return a ModelOutput with at least
        `generated_response_strs` populated.

        Implementations also update the usage stats (e.g., token counts, forward
        call counts, etc.) and must call `_update_invoke_stats` after each raw
        model call.
        """
        raise NotImplementedError
class EncoderBaseModel(BaseModel):
    """Encoder model base class."""

    def __call__(
        self,
        input_texts: List[str],
        return_full_output: bool = False,
        **kwargs
    ) -> Float[Tensor, "n_texts d_model"] | ModelOutput:
        """
        Embed a batch of input texts with the encoder.

        Delegates to :meth:`invoke_from_texts`, forwarding any extra keyword
        arguments.

        Args:
            input_texts (List[str]): Strings to embed.
            return_full_output (bool): When True, the whole ModelOutput is returned;
                otherwise only its ``output_embeddings``.
        """
        output: ModelOutput = self.invoke_from_texts(input_texts=input_texts, **kwargs)
        if not return_full_output:
            embeddings = output.output_embeddings
            assert embeddings is not None, "Model did not return embeddings"
            return embeddings
        return output

    @abstractmethod
    def invoke_from_texts(
        self,
        input_texts: List[str],
        **kwargs
    ) -> ModelOutput:
        """
        Computes encoder embeddings for the given input texts.

        Implementations always return a ModelOutput with at least
        `output_embeddings` populated.

        Implementations also update the usage stats (e.g., token counts, forward
        call counts, etc.) and must call `_update_invoke_stats` after each raw
        model call.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def d_model(self) -> int:
        """Dimensionality of the output embeddings."""
        raise NotImplementedError
class ClassifierBaseModel(BaseModel):
    """Classifier model base class."""

    def __call__(
        self,
        input_texts: List[str],
        return_full_output: bool = False,
        **kwargs
    ) -> Float[Tensor, "n_texts n_classes"] | ModelOutput:
        """
        Classify a batch of input texts.

        Delegates to :meth:`invoke_from_texts`, forwarding any extra keyword
        arguments.

        Args:
            input_texts (List[str]): Strings to classify.
            return_full_output (bool): When True, the whole ModelOutput is returned;
                otherwise only its ``output_class_logits``.
        """
        output: ModelOutput = self.invoke_from_texts(input_texts=input_texts, **kwargs)
        if not return_full_output:
            class_logits = output.output_class_logits
            assert class_logits is not None, "Model did not return class logits"
            return class_logits
        return output

    @abstractmethod
    def invoke_from_texts(
        self,
        input_texts: List[str],
        **kwargs
    ) -> ModelOutput:
        """
        Compute classification logits for the given input texts.

        Implementations always return a ModelOutput with at least
        `output_class_logits` populated.

        Implementations also update the usage stats and must call
        `_update_invoke_stats` after each raw model call.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def n_classes(self) -> int:
        """Number of output classes."""
        raise NotImplementedError
# ====================== Tokenizer base classes ===================
class BaseTokenizer(ABC):
    """
    Abstract base class for tokenizers to ensure a unified interface
    compatible with Hugging Face-style usage.
    """

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        """Returns the size of the vocabulary."""
        pass

    @abstractmethod
    def __call__(
        self,
        text: List[str],
        return_tensors: Literal["list", "pt", "np"] = "list",
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize a batch of strings. ``text`` must be a list (single strings rejected; use :meth:`encode`)."""
        pass

    @abstractmethod
    def decode(self, ids: int | List[int] | torch.Tensor, **kwargs) -> str:
        """Converts token IDs back to a string."""
        pass

    @abstractmethod
    def encode(self, text: str | List[str], **kwargs) -> List[int]:
        """Converts a string to token IDs."""
        pass

    @abstractmethod
    def batch_decode(self, ids: List[int] | List[List[int]] | torch.Tensor, **kwargs) -> List[str]:
        """Converts a batch of token IDs back to a list of strings."""
        pass

    @property
    @abstractmethod
    def name_or_path(self) -> str:
        """Identifier of the tokenizer; used by :meth:`__eq__` and :meth:`__hash__`."""
        pass

    # --- Specific helpers (concrete, build on abstract primitives above) ---

    def encode_trigger(self, trigger_str: str) -> Int[Tensor, "trigger_seq_len"]:
        """Encode a trigger string -> 1-D tensor (no special tokens)."""
        ids = self.encode(trigger_str, add_special_tokens=False)
        return torch.tensor(ids, dtype=torch.int64)

    def decode_trigger(self, trigger_ids: Int[Tensor, "trigger_seq_len"]) -> str:
        """Decode a 1-D trigger ids tensor -> string (special tokens skipped)."""
        return self.decode(trigger_ids, skip_special_tokens=True)

    def decode_triggers(self, trigger_ids: Int[Tensor, "bsz trigger_seq_len"]) -> List[str]:
        """Batch-decode a 2-D trigger ids tensor -> list of strings (special tokens skipped)."""
        return self.batch_decode(trigger_ids, skip_special_tokens=True)

    # --- Equality for tokenizer comparison ---

    def __eq__(self, other: object) -> bool:
        """
        To compare tokenizers; useful for optimizers to check if an identical vocabulary is expected.
        - Default falls back to `name_or_path` (a reasonable proxy for "same tokenizer")
        - Two distinct tokenizers whose name is "unknown" never compare equal.
        - Override in subclass for different logic.
        """
        if not isinstance(other, BaseTokenizer):
            return NotImplemented
        if self is other:
            return True
        my_name = self.name_or_path
        if my_name == "unknown" or other.name_or_path == "unknown":
            return False
        return my_name == other.name_or_path

    def __hash__(self) -> int:
        """Hash on `name_or_path`, consistent with :meth:`__eq__`.

        Defining ``__eq__`` alone makes Python set ``__hash__`` to ``None``,
        which would make tokenizers unusable in sets and as dict keys.
        Tokenizers that compare equal are either the same object or share a
        name, so they always hash identically.
        """
        return hash(self.name_or_path)
class HFTokenizerWrapper(BaseTokenizer):
    """Adapter exposing a HuggingFace tokenizer through the BaseTokenizer interface.

    Attribute lookups that are not resolved on the wrapper itself are
    delegated to the wrapped tokenizer, so code that reaches for tokenizer
    internals (padding_side, apply_chat_template, etc.) keeps working.
    """

    def __init__(self, tokenizer):
        # Keep the wrapped tokenizer under a private name, set via the base
        # object implementation.
        object.__setattr__(self, "_hf_tokenizer", tokenizer)

    def __getattr__(self, name: str):
        # Only invoked when normal lookup fails: delegate to the wrapped tokenizer.
        wrapped = object.__getattribute__(self, "_hf_tokenizer")
        return getattr(wrapped, name)

    # --- BaseTokenizer abstract method implementations ---

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the wrapped tokenizer."""
        return self._hf_tokenizer.vocab_size

    def __call__(self, text, return_tensors=None, **kwargs):
        """Tokenize a list of strings with the wrapped tokenizer.

        ``return_tensors="list"`` is translated to ``None`` before the call;
        every other value is passed through unchanged.
        """
        assert isinstance(text, list), "BaseTokenizer.__call__ requires a list of strings."
        # "list" in our abstraction maps to None (Python objects) in HF
        hf_return_tensors = None if return_tensors == "list" else return_tensors
        return self._hf_tokenizer(text, return_tensors=hf_return_tensors, **kwargs)

    def decode(self, ids, **kwargs) -> str:
        """Delegate to the wrapped tokenizer's ``decode``."""
        return self._hf_tokenizer.decode(ids, **kwargs)

    def encode(self, text, **kwargs) -> List[int]:
        """Delegate to the wrapped tokenizer's ``encode``."""
        return self._hf_tokenizer.encode(text, **kwargs)

    def batch_decode(self, ids, **kwargs) -> List[str]:
        """Delegate to the wrapped tokenizer's ``batch_decode``."""
        return self._hf_tokenizer.batch_decode(ids, **kwargs)

    @property
    def name_or_path(self) -> str:
        """``name_or_path`` of the wrapped tokenizer."""
        return self._hf_tokenizer.name_or_path