# Source code for tropt.model.flop_counter

from __future__ import annotations
"""FLOP counting utilities for model invoke methods.

FLOP counting is handled at the ``invoke_from_tokens`` / ``invoke_from_texts``
level — these are the model-call bottleneck, and the only place where usage
statistics and FLOPs are tracked. Non-invoke computation in ``compute_*``
methods (tensor stacking, loss aggregation, etc.) is negligible compared to
the model forward/backward passes.

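Uses the Kaplan et al. (2020) approximation (https://arxiv.org/abs/2001.08361):
``FLOPs_fwd ≈ 2·N·T`` and ``FLOPs_bwd ≈ 4·N·T``, where ``N`` is the number of
active (per-token) non-embedding parameters and ``T`` is the number of tokens
processed. The estimate is cheap and deterministic. It requires ``_model`` to
be a HuggingFace ``PreTrainedModel``.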
Adapted from https://github.com/romovpa/claudini.

**For model implementers**: invoke methods must call
:meth:`BaseModel._update_invoke_stats` *after* each model call. This handles
both usage statistics and FLOP counting in one place. For backward passes
(gradient methods), pass ``count_backward=True`` to the invoke method, which
propagates to ``_update_invoke_stats``.

Usage::

    model = LMHFModel("meta-llama/Llama-3.2-1B", ...)
    model.set_flop_counting("manual")   # False to disable
    model.reset_usage_stats()
    # ... run optimization ...
    stats = model.get_usage_stats()
    print(stats["usage/total_flops"])
"""

import logging

from transformers import PreTrainedModel

logger = logging.getLogger(__name__)

class FlopCounterBase:
    """Base class for FLOP counting.

    Subclasses must implement :meth:`count_forward` and :meth:`count_backward`.
    :meth:`count_forward_backward` defaults to their sum and may be overridden
    when a subclass has a cheaper or more accurate closed form.
    """

    def count_forward(self, n_tokens: int) -> int:
        """Return forward-pass FLOPs for ``n_tokens`` tokens.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement count_forward()"
        )

    def count_backward(self, n_tokens: int) -> int:
        """Return backward-pass FLOPs for ``n_tokens`` tokens.

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement count_backward()"
        )

    def count_forward_backward(self, n_tokens: int) -> int:
        """Return combined forward+backward FLOPs for ``n_tokens`` tokens.

        Defaults to ``count_forward(n_tokens) + count_backward(n_tokens)``, so it
        raises ``NotImplementedError`` when either of those is not implemented.
        """
        return self.count_forward(n_tokens) + self.count_backward(n_tokens)


class ManualFlopCounter(FlopCounterBase):
    """Track FLOPs using the Kaplan et al. (2020) approximation.

        FLOPs_fwd ≈ 2 · N_params · n_tokens
        FLOPs_bwd ≈ 4 · N_params · n_tokens

    ``N_params`` is the *active* (per-token) non-embedding parameter count.
    For MoE models it is the always-active parameters (non-expert weights plus
    any shared experts) plus the routed-expert parameters scaled by
    ``top_k / num_experts``.

    Note that this code may require adaptation once new models come out
    (e.g. MoE models whose config fields or parameter names differ from the
    ones recognized here).

    Args:
        model: HuggingFace ``PreTrainedModel``; its parameters are counted once,
            at construction time.

    Raises:
        TypeError: If ``model`` is not a ``PreTrainedModel``.
    """

    def __init__(self, model: "PreTrainedModel"):
        from transformers import PreTrainedModel

        if not isinstance(model, PreTrainedModel):
            raise TypeError(
                f"ManualFlopCounter requires a HuggingFace PreTrainedModel, "
                f"got {type(model).__name__}. "
                f"Manual FLOP counting is only supported for HF models."
            )
        # Non-embedding count as reported by HF (may be distorted for quantized
        # weights); within this class it is only used for the log line below.
        self.total_params: int = model.num_parameters(exclude_embeddings=True)
        # Per-token active parameter count used by every count_* method.
        self.n_params: int = self._compute_active_params(model)
        logger.info(
            "ManualFlopCounter: %dM active params (%dM total non-embedding)",
            self.n_params // 10**6,
            self.total_params // 10**6,
        )

    # ---- active-param computation (handles dense & MoE) ----

    @staticmethod
    def _compute_active_params(model: "PreTrainedModel") -> int:
        """Return the active (per-token) parameter count, accounting for MoE sparsity.

        Dense models (no expert count, no top-k, or at most one expert): returns
        ``model.num_parameters(exclude_embeddings=True)``. If a config-based
        estimate exists and the reported count is below half of it (a sign of
        packed quantized weights), the estimate is returned instead.

        MoE models: parameters are bucketed by name. Embeddings and ``lm_head``
        are skipped. Shared experts are always active. Routed experts are scaled
        by ``num_active / num_experts``. Everything else is always active.
        The routed-expert total falls back to a config estimate when the
        reported number is below 10% of it. The non-expert total is replaced by
        the config estimate whenever one is available.
        """
        # Multimodal configs (e.g. Gemma4ForConditionalGeneration) nest the LM fields under `text_config`
        config = getattr(model.config, "text_config", model.config)
        num_experts = getattr(config, "num_local_experts", None) or getattr(
            config, "num_experts", None
        )
        num_active = (
            getattr(config, "num_experts_per_tok", None)
            or getattr(config, "num_selected_experts", None)
            or getattr(config, "top_k_experts", None)
            or getattr(config, "top_k", None)
        )

        if not num_experts or not num_active or num_experts <= 1:
            # Dense model
            reported = model.num_parameters(exclude_embeddings=True)
            config_estimate = ManualFlopCounter._params_from_config(config)
            if config_estimate and reported < config_estimate * 0.5:
                logger.info(
                    "Quantized model detected: reported %dM params but config says %dM. "
                    "Using config estimate.",
                    reported // 10**6,
                    config_estimate // 10**6,
                )
                return config_estimate
            return reported

        # MoE model: bucket parameters by name.
        expert_params = 0  # routed experts (only top-k of them run per token)
        shared_expert_params = 0  # shared experts (run on every token)
        shared_params = 0  # all other non-embedding weights
        for name, param in model.named_parameters():
            n = param.numel()
            if "embed" in name or "lm_head" in name:
                continue
            if "shared_expert" in name:
                # e.g. Qwen2-MoE `shared_expert` / `shared_expert_gate`, DeepSeek
                # `shared_experts`: always active, so they must NOT be scaled by
                # top_k / num_experts like routed experts.
                shared_expert_params += n
            elif "expert" in name:
                expert_params += n
            else:
                shared_params += n

        config_expert_params = ManualFlopCounter._expert_params_from_config(config)
        if config_expert_params and expert_params < config_expert_params * 0.1:
            logger.info(
                "Quantized MoE detected: named_parameters reports %dM expert params "
                "but config says %dM. Using config-based counting.",
                expert_params // 10**6,
                config_expert_params // 10**6,
            )
            expert_params = config_expert_params

        # NOTE(review): this override is applied whenever the config has the
        # required fields, not only when quantization is detected, and the
        # config estimate covers only attention, norm and router weights per
        # layer. Any other non-expert weights (e.g. dense MLP layers in mixed
        # dense/MoE models) are therefore not counted — confirm this is intended.
        # Shared experts are tracked separately above, so they survive it.
        config_shared = ManualFlopCounter._shared_params_from_config(config)
        if config_shared:
            shared_params = config_shared

        always_active = shared_params + shared_expert_params
        active_expert_params = int(expert_params * num_active / num_experts)
        active_params = always_active + active_expert_params
        total_non_emb = always_active + expert_params
        logger.info(
            "MoE: %d experts, top-%d active. "
            "Params: %dM shared + %dM expert (%.0f%% active) = %dM active / %dM total",
            num_experts,
            num_active,
            always_active // 10**6,
            expert_params // 10**6,
            100 * num_active / num_experts,
            active_params // 10**6,
            total_non_emb // 10**6,
        )
        return active_params

    # ---- config-based fallbacks (for quantized models) ----

    @staticmethod
    def _expert_params_from_config(config) -> int | None:
        """Estimate routed-expert parameters from the config, or None if fields are missing.

        Assumes three ``hidden_size × intermediate`` matrices per expert per layer
        (presumably a gated MLP).
        """
        config = getattr(config, "text_config", config)
        h = getattr(config, "hidden_size", None)
        # Some MoE configs (e.g. Gemma4) report the expert MLP width in a dedicated field.
        intermediate = getattr(config, "moe_intermediate_size", None) or getattr(
            config, "intermediate_size", None
        )
        n_layers = getattr(config, "num_hidden_layers", None)
        num_experts = getattr(config, "num_local_experts", None) or getattr(
            config, "num_experts", None
        )
        if h is None or intermediate is None or n_layers is None or num_experts is None:
            return None
        return 3 * h * intermediate * num_experts * n_layers

    @staticmethod
    def _shared_params_from_config(config) -> int | None:
        """Estimate per-layer attention + norm + router parameters, summed over layers.

        Returns None when ``hidden_size``, ``num_hidden_layers`` or
        ``num_attention_heads`` is missing. ``head_dim`` and
        ``num_key_value_heads`` fall back to ``hidden_size // num_attention_heads``
        and ``num_attention_heads`` respectively when absent or None.
        """
        config = getattr(config, "text_config", config)
        h = getattr(config, "hidden_size", None)
        n_layers = getattr(config, "num_hidden_layers", None)
        n_heads = getattr(config, "num_attention_heads", None)
        n_kv_heads = getattr(config, "num_key_value_heads", None)
        head_dim = getattr(config, "head_dim", None)
        num_experts = getattr(config, "num_local_experts", None) or getattr(
            config, "num_experts", None
        )
        if h is None or n_layers is None or n_heads is None:
            return None
        if head_dim is None:
            head_dim = h // n_heads
        if n_kv_heads is None:
            n_kv_heads = n_heads
        # `attention_k_eq_v` means K and V share weights -> one KV projection instead of two.
        kv_factor = 1 if getattr(config, "attention_k_eq_v", False) else 2
        attn = (
            h * (n_heads * head_dim)
            + kv_factor * h * (n_kv_heads * head_dim)
            + (n_heads * head_dim) * h
        )
        ln = 2 * h
        router = h * num_experts if num_experts else 0
        return (attn + ln + router) * n_layers

    @staticmethod
    def _params_from_config(config) -> int | None:
        """Estimate dense non-embedding parameters (attention + 3-matrix MLP + norms) per layer.

        Returns None when ``hidden_size``, ``intermediate_size``,
        ``num_hidden_layers`` or ``num_attention_heads`` is missing.
        """
        config = getattr(config, "text_config", config)
        h = getattr(config, "hidden_size", None)
        intermediate = getattr(config, "intermediate_size", None)
        n_layers = getattr(config, "num_hidden_layers", None)
        n_heads = getattr(config, "num_attention_heads", None)
        if h is None or intermediate is None or n_layers is None or n_heads is None:
            return None
        # Configs may define these fields but set them to None; fall back the same
        # way `_shared_params_from_config` does instead of multiplying by None.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = h // n_heads
        n_kv_heads = getattr(config, "num_key_value_heads", None)
        if n_kv_heads is None:
            n_kv_heads = n_heads
        # `attention_k_eq_v` means K and V share weights -> one KV projection instead of two.
        kv_factor = 1 if getattr(config, "attention_k_eq_v", False) else 2
        attn = (
            h * (n_heads * head_dim)
            + kv_factor * h * (n_kv_heads * head_dim)
            + (n_heads * head_dim) * h
        )
        mlp = 3 * h * intermediate
        ln = 2 * h
        return (attn + mlp + ln) * n_layers

    # ---- counting helpers ----

    def count_forward(self, n_tokens: int) -> int:
        """Forward-pass FLOPs: ``2 · n_params · n_tokens``."""
        return 2 * self.n_params * n_tokens

    def count_backward(self, n_tokens: int) -> int:
        """Backward-pass FLOPs: ``4 · n_params · n_tokens``."""
        return 4 * self.n_params * n_tokens

    def count_forward_backward(self, n_tokens: int) -> int:
        """Combined forward+backward FLOPs: ``6 · n_params · n_tokens``."""
        return 6 * self.n_params * n_tokens
# --------------------------------------------------------------------------- # [DISABLED] Torch automatic FLOP-counting. Disabled due to potential instability and limitations. # --------------------------------------------------------------------------- # def track_flops(includes_backward: bool = False): # """Decorator for ``compute_*`` methods — dispatches FLOP counting by mode. # ``"torch"`` — wraps the method with ``FlopCounterMode``. # ``"manual"`` — measures the delta of ``_token_used`` before/after the # method, then applies the Kaplan approximation. Requires a # ``_kaplan_flop_counter`` (:class:`KaplanFlopCounter`) on the model — # provided automatically by :class:`HuggingFaceBackendModel`. # Args: # includes_backward: The wrapped method includes a backward pass # (``"manual"`` uses ``6·N·T`` instead of ``2·N·T``). # """ # def decorator(method): # @wraps(method) # def wrapper(self, *args, **kwargs): # mode = getattr(self, "count_flops", False) # if not mode: # return method(self, *args, **kwargs) # # --- torch mode --- # if mode == "torch": # from torch.utils.flop_counter import FlopCounterMode as _FlopCounterMode # with _FlopCounterMode(display=False) as flop_counter: # result = method(self, *args, **kwargs) # self._update_usage_stats(flops=flop_counter.get_total_flops()) # return result # # --- manual mode --- # if mode == "manual": # counter = getattr(self, "_kaplan_flop_counter", None) # if counter is None: # logger.warning( # "count_flops='manual' but no KaplanFlopCounter on model; " # "skipping FLOP counting." 
# ) # return method(self, *args, **kwargs) # tokens_before = getattr(self, "_token_used", 0) # result = method(self, *args, **kwargs) # tokens_after = getattr(self, "_token_used", 0) # delta_tokens = tokens_after - tokens_before # if includes_backward: # flops = counter.count_forward_backward(delta_tokens) # else: # flops = counter.count_forward(delta_tokens) # self._update_usage_stats(flops=flops) # return result # raise ValueError( # f"Unknown count_flops mode: {mode!r}. Use False, 'torch', or 'manual'." # ) # return wrapper # return decorator