Source code for tropt.loss.losses

from __future__ import annotations
"""
General loss functions.

Important note: The losses arguments must match the fields in ModelOutput and ModelInput
for unified loss resolution to work properly.
"""
import logging
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import SliceKey
from tropt.loss.base import BaseLoss
from tropt.loss.utils import IGNORE_INDEX, masked_mean

logger = logging.getLogger(__name__)

############################
@dataclass
class PrefillBasedLoss(BaseLoss):
    """
    Abstract base for losses evaluated on the prefilled response logits
    (`prefill_response_logits`).

    Concrete subclasses need the target tokens (`target_response_toks`), which are
    commonly derived from `target_response_strs`. Picking a loss from this family
    usually implies that the model prefills its response with these target tokens.
    """

    # Class-level flag: this loss family needs the target response to be prefilled.
    require_target_prefill: ClassVar[bool] = True

    @abstractmethod
    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_toks: Int[Tensor, "response_seq_len"],
    ) -> Float[Tensor, "bsz"]:
        """Return one loss value per batch element (implemented by subclasses)."""
        ...
@dataclass
class PrefillCELoss(PrefillBasedLoss):
    """
    Cross-entropy (negative log-likelihood) of the target response under the model,
    i.e. minimizing it pushes the model towards producing the target output
    (mostly an affirmative response).

    The loss is evaluated on the prefilled response logits (`prefill_response_logits`)
    against the target tokens (`target_response_toks`, automatically derived from
    `target_response_strs`). Using this loss usually implies that the model prefills
    the response with these target tokens.
    """

    temperature: float = 1.0
    """Temperature applied to the prefill logits before softmax."""

    clamp_min_nll: Optional[float] = None
    """Floor on the per-token NLL, applied before averaging. Tokens whose NLL is
    already below the floor contribute zero gradient, which lets the optimizer
    concentrate on the "unsolved" positions. ``None`` (the default) disables
    clamping. FLRT (https://arxiv.org/abs/2407.17447, Eq. 5) uses
    ``-log(0.6) ≈ 0.511``.
    """

    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_toks: Int[Tensor, "response_seq_len"],
    ) -> Float[Tensor, "bsz"]:
        """Return the masked mean of the per-token cross-entropy, one value per batch row."""
        bsz = prefill_response_logits.shape[0]

        # The same target sequence is shared by every batch row -> (bsz, response_seq_len)
        target_response_toks = target_response_toks.unsqueeze(0).expand(bsz, -1)
        prefill_response_logits = prefill_response_logits / self.temperature

        shapes_ok = (
            prefill_response_logits.ndim == 3
            and target_response_toks.ndim == 2
            and prefill_response_logits.shape[1] == target_response_toks.shape[1]
        )
        assert shapes_ok, f"Shape mismatch: prefill_response_logits {prefill_response_logits.shape}, target_response_toks {target_response_toks.shape}"

        # cross_entropy wants the class (= vocab) axis on dim 1: (bsz, vocab, seq_len)
        per_token_nll = torch.nn.functional.cross_entropy(
            prefill_response_logits.permute(0, 2, 1),
            target_response_toks,
            ignore_index=IGNORE_INDEX,
            reduction="none",
        )  # (bsz, seq_len)

        if self.clamp_min_nll is not None:
            per_token_nll = per_token_nll.clamp(min=self.clamp_min_nll)

        valid_positions = (target_response_toks != IGNORE_INDEX).float()
        return masked_mean(per_token_nll, valid_positions)
@dataclass
class PrefillDistillationLoss(PrefillBasedLoss):
    """
    Pulls the model's prefill distribution towards a reference distribution at the
    target positions.

    Concretely, it returns the cross-entropy between the victim model's (student)
    prefill logits and a softmax over pre-computed reference (teacher) logits.
    Inspired by FLRT's logit-based distillation (https://arxiv.org/abs/2407.17447),
    where the reference logits come from a jailbroken copy of the victim model.

    The loss is evaluated on the prefilled response logits (`prefill_response_logits`)
    and needs the target logits (`target_response_logits`), commonly obtained from a
    reference model's output on the same input.
    """

    # Temperature dividing the student (prefill) logits before the cross-entropy.
    temperature: float = 1.0

    reference_temperature: float = 1.0
    """Temperature applied to the reference logits before softmax (teacher sharpening)."""

    clamp_min_nll: Optional[float] = None
    """Floor on the cross-entropy result to stop optimizing well-matched tokens.
    Defaults to None (disabled), common value is ``-log(0.6) ~ 0.51``."""

    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_logits: Float[Tensor, "response_seq_len vocab_size"],
    ) -> Float[Tensor, "bsz"]:
        """Return the per-row mean of the token-level student/teacher cross-entropy."""
        # Bring the reference onto the student's device and dtype.
        target_response_logits = target_response_logits.to(prefill_response_logits)

        shapes_ok = (
            prefill_response_logits.ndim == 3
            and target_response_logits.ndim == 2
            and prefill_response_logits.shape[1] == target_response_logits.shape[0]
            and prefill_response_logits.shape[2] == target_response_logits.shape[1]
        )
        assert shapes_ok, (
            f"Shape mismatch: prefill_response_logits {prefill_response_logits.shape}, "
            f"target_response_logits {target_response_logits.shape}"
        )

        bsz = prefill_response_logits.shape[0]

        # Torch's CE takes class probabilities as (soft) targets.
        teacher_probs = torch.nn.functional.softmax(
            target_response_logits / self.reference_temperature, dim=-1
        )  # (seq_len, vocab)
        teacher_probs = teacher_probs[None].expand(bsz, -1, -1)  # (bsz, seq_len, vocab)

        # Cross-entropy between teacher probs and student logits
        # (after temperature scaling); the class axis must be dim 1.
        student_logits = prefill_response_logits / self.temperature
        token_ce = torch.nn.functional.cross_entropy(
            student_logits.permute(0, 2, 1),  # (bsz, vocab, seq_len)
            teacher_probs.permute(0, 2, 1),  # (bsz, vocab, seq_len)
            reduction="none",
        )  # (bsz, seq_len)

        if self.clamp_min_nll is not None:
            token_ce = token_ce.clamp(min=self.clamp_min_nll)

        return token_ce.mean(dim=-1)  # (bsz,)
@dataclass
class PrefillMellowMaxLoss(PrefillBasedLoss):
    """
    Encourages the model to produce the target output by minimizing the mellowmax
    of the negated target logits.
    https://arxiv.org/pdf/1612.05628, http://confirmlabs.org/posts/TDC2023

    Loss computed on prefilled response logits (`prefill_response_logits`).
    Requires target tokens (`target_response_toks`); automatically derived from
    `target_response_strs`. Using this loss usually implies that the model will
    prefill the response with these target tokens.
    """

    # Mellowmax sharpness: scales the values before logsumexp, and the result is
    # divided by it afterwards.
    mellowmax_alpha: float = 1.0
    # Temperature dividing the prefill logits before the target logits are gathered.
    temperature: float = 1.0

    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_toks: Int[Tensor, "response_seq_len"],
    ) -> Float[Tensor, "bsz"]:
        """
        Compute ``(1/alpha) * (logsumexp(alpha * -target_logit) - log(n_valid))``
        over the non-ignored target positions of every batch row.

        Rows without a single valid (non-`IGNORE_INDEX`) target position yield a
        loss of 0 instead of ``-inf`` (and NaN gradients).

        Returns:
            Loss tensor of shape (bsz,).
        """
        target_response_toks = target_response_toks.unsqueeze(0).expand(
            prefill_response_logits.shape[0], -1
        )  # (bsz, response_seq_len)
        prefill_response_logits = prefill_response_logits / self.temperature
        assert prefill_response_logits.shape[:-1] == target_response_toks.shape, "Shape mismatch"

        # 1. Mask of real targets; ignored positions get a dummy index 0 so that
        #    `gather` stays in range (they are masked out below anyway).
        mask = target_response_toks != IGNORE_INDEX
        safe_toks = target_response_toks.masked_fill(~mask, 0)

        # 2. Logit assigned to each target token.
        target_logits = prefill_response_logits.gather(-1, safe_toks.unsqueeze(-1)).squeeze(-1)

        # Mellowmax is a soft maximum, so minimizing the mellowmax of the *negated*
        # target logits pushes the target logits up.
        target_logits = -target_logits

        # 3. Ignored positions must not contribute to the sum: exp(-inf) = 0.
        val_for_lse = self.mellowmax_alpha * target_logits
        val_for_lse = val_for_lse.masked_fill(~mask, float("-inf"))

        # Rows with no valid position would be all -inf, which makes logsumexp
        # return -inf and its backward produce NaN. Feed constant zeros instead;
        # the result for these rows is overwritten with 0 below.
        has_valid = mask.any(dim=-1)
        val_for_lse = val_for_lse.masked_fill(~has_valid.unsqueeze(-1), 0.0)

        # 4. Number of valid tokens per sequence (clamped so log() stays finite).
        n_valid = mask.sum(dim=-1).float().clamp(min=1.0)

        loss = (
            1.0
            / self.mellowmax_alpha
            * (torch.logsumexp(val_for_lse, dim=-1) - torch.log(n_valid))
        )
        return loss.masked_fill(~has_valid, 0.0)  # (bsz,)
@dataclass
class PrefillCWLoss(PrefillBasedLoss):
    """
    CW-inspired hinge loss that pushes the model towards the target output (mostly
    an affirmative response): per position it penalizes the gap between the largest
    non-target logit and the target logit.
    https://arxiv.org/abs/2402.09674

    The loss is evaluated on the prefilled response logits (`prefill_response_logits`)
    against the target tokens (`target_response_toks`, automatically derived from
    `target_response_strs`). Using this loss usually implies that the model prefills
    the response with these target tokens.
    """

    # The per-token gap is clamped from below at -cw_margin.
    cw_margin: float = 5.0
    # Multiplier applied to the loss of the first response position.
    first_token_weight: float = 1.0

    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_toks: Int[Tensor, "response_seq_len"],
    ) -> Float[Tensor, "bsz"]:
        """Return the masked mean of the per-token hinge loss, one value per batch row."""
        bsz = prefill_response_logits.shape[0]
        target_response_toks = target_response_toks.unsqueeze(0).expand(bsz, -1)  # (bsz, response_seq_len)
        assert prefill_response_logits.shape[:2] == target_response_toks.shape, (prefill_response_logits.shape, target_response_toks.shape)

        vocab_axis = -1

        # Valid-position mask; ignored positions get index 0 so indexing stays in
        # range (their loss is zeroed further down).
        valid = target_response_toks != IGNORE_INDEX
        tgt_index = target_response_toks.masked_fill(~valid, 0).unsqueeze(-1)  # (bsz, seq_len, 1)

        # Logit of the target token at every position.
        tgt_logits = prefill_response_logits.gather(vocab_axis, tgt_index).squeeze(-1)

        # Knock the target token out (-inf) so it cannot be picked as the maximum,
        # then take the largest remaining logit.
        without_target = prefill_response_logits.scatter(vocab_axis, tgt_index, -torch.inf)
        best_other_logits = without_target.max(vocab_axis).values

        # CW hinge: gap between best competitor and target, floored at -margin.
        hinge = (best_other_logits - tgt_logits).clamp_min(-self.cw_margin)

        # Optional first-token emphasis (e.g. the affirmative "Sure" token).
        if self.first_token_weight != 1.0:
            position_weights = torch.ones_like(hinge)
            position_weights[:, 0] = self.first_token_weight
            hinge = hinge * position_weights

        # Padding positions contribute nothing.
        valid_f = valid.float()
        hinge = hinge * valid_f
        return masked_mean(hinge, valid_f)
############################
@dataclass
class TriggerLogitBasedLoss(BaseLoss):
    """
    Loss computed on full-sequence logits (`full_logits`) sliced to trigger positions.
    Useful for optimizing properties of the triggers directly.
    """

    @abstractmethod
    def __call__(
        self,
        full_logits: Float[Tensor, "bsz seq_len vocab_size"],
        # Batched ids: the concrete implementation in this module compares
        # `input_trigger_ids.shape[:2]` against (bsz, trigger_seq_len).
        input_trigger_ids: Int[Tensor, "bsz trigger_seq_len"],
        input_slices: dict[SliceKey, slice],
    ) -> Float[Tensor, "bsz"]:
        """
        Compute the loss from the logits at the trigger positions.

        Args:
            full_logits: Logits over the whole input sequence.
            input_trigger_ids: Token ids of the trigger, one row per batch element.
            input_slices: Maps slice names to their position ranges in the sequence.

        Returns:
            Loss tensor of shape (bsz,).
        """
        pass
@dataclass
class TriggerPerplexityLoss(TriggerLogitBasedLoss):
    """
    Scores how well the model itself predicts the trigger tokens; useful for
    penalizing non-fluent triggers.

    The returned value is the masked mean next-token cross-entropy over the trigger
    positions (i.e. the log of the perplexity, not the perplexity itself).
    """

    # Temperature dividing the trigger logits before the cross-entropy.
    temperature: float = 1.0
    slc_name: SliceKey = SliceKey.TRIGGER  # Which slice contains the trigger tokens

    def __call__(
        self,
        full_logits: Float[Tensor, "bsz seq_len vocab_size"],
        input_trigger_ids: Int[Tensor, "bsz trigger_seq_len"],
        input_slices: dict[SliceKey, slice],
    ) -> Float[Tensor, "bsz"]:
        """
        Compute the cross-entropy of the trigger tokens under the model.

        The logits that predict the trigger are those located one position *before*
        each trigger token, so the trigger slice is shifted left by one.

        Args:
            full_logits: Logits over the whole input sequence.
            input_trigger_ids: Trigger token ids used as targets; positions equal to
                `IGNORE_INDEX` are excluded.
            input_slices: Must contain `self.slc_name`, whose slice has to start
                after position 0.

        Returns:
            Loss tensor of shape (bsz,).

        Raises:
            AssertionError: If the trigger slice does not start after position 0
                (including an open-ended start), or on a logits/ids shape mismatch.
        """
        slc = input_slices[self.slc_name]
        # `slc.start` may be None for an open-ended slice; report that through the
        # same explanatory assertion rather than a bare TypeError on `None > 0`.
        assert slc.start is not None and slc.start > 0, (
            f"TriggerPerplexityLoss requires the trigger slice to start after position 0 "
            f"(got slc.start={slc.start}). This could happen since loss is incompatible with `use_prefix_cache=True` -- try setting it to `False`."
        )
        # An open-ended stop means "until the end of the sequence".
        stop = full_logits.shape[1] if slc.stop is None else slc.stop

        # Shift by one: logits at position t predict the token at position t + 1.
        trigger_logits = full_logits[:, slc.start - 1 : stop - 1, :]  # (bsz, trigger_seq_len, vocab_size)
        trigger_logits = trigger_logits / self.temperature

        assert (
            trigger_logits.ndim == 3
            and trigger_logits.shape[:2] == input_trigger_ids.shape[:2]
        ), f"Shape mismatch: trigger_logits {trigger_logits.shape}, input_trigger_ids {input_trigger_ids.shape}"

        loss = torch.nn.functional.cross_entropy(
            trigger_logits.transpose(-1, -2),  # (bsz, vocab_size, trigger_seq_len)
            input_trigger_ids,  # (bsz, trigger_seq_len)
            reduction="none",
            ignore_index=IGNORE_INDEX,
        )  # (bsz, trigger_seq_len)

        return masked_mean(loss, (input_trigger_ids != IGNORE_INDEX).float())
#############################
@dataclass
class AttentionBasedLoss(BaseLoss):
    """Abstract base for losses evaluated on the model's attention weights (`full_attentions`)."""

    # Class-level flag: this loss family needs the attention weights.
    require_attentions: ClassVar[bool] = True

    @abstractmethod
    def __call__(
        self,
        full_attentions: Float[Tensor, "bsz n_layers n_heads seq_len[dst] seq_len[src]"],
        input_slices: dict[SliceKey, slice],
    ) -> Float[Tensor, "bsz"]:
        """Return one loss value per batch element (implemented by subclasses)."""
        ...
@dataclass
class AttentionEnhLoss(AttentionBasedLoss):
    """
    Encourages attention from the trigger tokens to the chat template after the
    adversarial trigger.

    *Note*: the sign of the loss is set such that minimizing the loss maximizes the
    attention.

    Enables instantiating the (different) losses from:
    https://arxiv.org/abs/2506.12880, https://arxiv.org/abs/2410.09040

    Note that it requires setting `use_eager_attention=True` when loading the model
    (for explicit attention computations). Also, some slices are not supported when
    LM prefix caching is enabled, so the model should be loaded with
    `use_prefix_cache=False`.

    Attributes:
        targeted_layers: Layers whose attention is averaged (default: all layers).
        src_slc_name: Slice indexing the last (`src`) axis of `full_attentions`.
        dst_slc_name: Slice indexing the second-to-last (`dst`) axis of `full_attentions`.
    """

    # `default_factory` rather than a plain `slice(None)` default: slices are
    # unhashable on Python 3.11, where `@dataclass` rejects unhashable defaults.
    targeted_layers: slice = field(default_factory=lambda: slice(None))
    src_slc_name: SliceKey = SliceKey.TRIGGER
    dst_slc_name: SliceKey = SliceKey.INPUT_AFTER

    def __call__(
        self,
        full_attentions: Float[Tensor, "bsz n_layers n_heads seq_len[dst] seq_len[src]"],
        input_slices: dict[SliceKey, slice],
    ) -> Float[Tensor, "bsz"]:
        """
        Return the negated mean attention weight inside the (dst, src) window.

        Args:
            full_attentions: Attention weights of all layers and heads.
            input_slices: Maps slice names to position ranges; a missing name falls
                back to the full sequence.

        Returns:
            Loss tensor of shape (bsz,).
        """
        if SliceKey.INPUT_AFTER in (self.src_slc_name, self.dst_slc_name):
            logger.debug("Note: `INPUT_AFTER` slice is currently only correct for LMs and on suffix attacks. If the usage is different, somethings may break, or worse -- be wrong.")

        slc_src = input_slices.get(self.src_slc_name, slice(None))
        slc_dst = input_slices.get(self.dst_slc_name, slice(None))

        # Attention weights restricted to the targeted layers and slices.
        attn_subset = full_attentions[
            :, self.targeted_layers, :, slc_dst, slc_src
        ]  # (bsz, n_targeted_layers, n_heads, dst_len, src_len)

        # Mean over everything except the batch axis: layers, heads, dst, src.
        mean_attention = attn_subset.mean(dim=tuple(range(1, attn_subset.ndim)))  # (bsz,)

        # Negate so that minimizing the loss maximizes the attention.
        return -mean_attention
############################
@dataclass
class EmbeddingBasedLoss(BaseLoss):
    """
    Abstract base for losses that compare model embeddings against given target vectors.

    The target vectors (shape: (n_templates, d_model)) must be provided in the
    targets dict.
    """

    @abstractmethod
    def __call__(
        self,
        output_embeddings: Float[Tensor, "bsz d_model"],
        **kwargs,
    ) -> Float[Tensor, "bsz"]:
        """Return one loss value per batch element (implemented by subclasses)."""
        ...
@dataclass
class SimilarityLoss(EmbeddingBasedLoss):
    """
    Encourages given representation(s) to align (cos-sim) with the given target vectors.
    """

    def __call__(
        self,
        output_embeddings: Float[Tensor, "bsz d_model"],
        target_vectors: Float[Tensor, "d_model"],
    ) -> Float[Tensor, "bsz"]:
        """
        Return the negated cosine similarity between each embedding and the target.

        Args:
            output_embeddings: One embedding per batch element.
            target_vectors: A single target vector, shared by the whole batch.

        Returns:
            Loss tensor of shape (bsz,); minimizing it maximizes the cosine similarity.
            A zero-norm embedding or target yields a similarity of 0 (not NaN).

        Raises:
            AssertionError: If the embeddings (or the broadcast target) are not 2-D.
        """
        target_vectors = target_vectors.unsqueeze(0).expand(output_embeddings.shape[0], -1)  # (bsz, d_model)
        assert output_embeddings.ndim == target_vectors.ndim == 2, (
            f"Shape mismatch: output_embeddings {output_embeddings.shape}, "
            f"target_vectors {target_vectors.shape}"
        )
        target_vectors = target_vectors.to(output_embeddings.device)

        # Normalize to unit length. The norm is floored at the smallest positive
        # normal number of the dtype so that a zero vector maps to zero, not 0/0 = NaN.
        emb_norm = output_embeddings.norm(dim=-1, keepdim=True)
        output_embeddings = output_embeddings / emb_norm.clamp_min(
            torch.finfo(output_embeddings.dtype).tiny
        )
        tgt_norm = target_vectors.norm(dim=-1, keepdim=True)
        target_vectors = target_vectors / tgt_norm.clamp_min(
            torch.finfo(target_vectors.dtype).tiny
        )

        # Cosine similarity = dot product of the unit vectors.
        cos_sim = (output_embeddings * target_vectors).sum(dim=-1)  # (bsz,)

        # maximize cos-sim <=> minimize (-1 * cos-sim)
        return -1 * cos_sim
############################
@dataclass
class HiddenStateBasedLoss(BaseLoss):
    """Abstract base for losses evaluated on the model's hidden states (`full_hidden_states`)."""

    # Class-level flag: this loss family needs the hidden states.
    require_hidden_states: ClassVar[bool] = True

    @abstractmethod
    def __call__(
        self,
        full_hidden_states: Float[Tensor, "bsz n_layers seq_len d_model"],
        **kwargs,
    ) -> Float[Tensor, "bsz"]:
        """Return one loss value per batch element (implemented by subclasses)."""
        ...
@dataclass
class SteeringActivationLoss(HiddenStateBasedLoss):
    """
    Encourages hidden activations at specific layers/positions to align with a target direction.

    - Each message has a target direction vector (optionally its own unique one).
        - target_directions: (n_templates, d_model); a single (d_model,) vector is
          what reaches `__call__`.
        - Note that the direction will be applied to the whole target positions and layers.
    - Default is steering *towards* a direction (maximizing alignment).
        - Here, minimizing the loss maximizes alignment (dot product) with the target direction.
    - Set steer_away=True to steer *away* (e.g., for refusal suppression).

    References:
    - Was proposed as 'refusal direction suppression' combined with GCG:
        https://aclanthology.org/2025.naacl-long.302/
    - Was proposed for adapting attacks (e.g., GCG) for evading probe-based classifiers.
        https://arxiv.org/abs/2412.09565

    Args:
        targeted_layers: Which layers to apply steering on (default: all layers)
        steer_away: Whether to minimize alignment instead of maximizing (default: False = steer towards)
        slc_name: Which token positions to apply steering on (default: `SliceKey.INPUT_LAST_TOKEN`)
        do_cosine_sim: Whether to use cosine similarity instead of dot product (default: False)
        apply_square: Whether to square the similarity scores (default: False)
        apply_abs: Whether to take the absolute value of the similarity scores,
            applied before the optional squaring (default: False)
    """

    # `default_factory` rather than a plain `slice(None)` default: slices are
    # unhashable on Python 3.11, where `@dataclass` rejects unhashable defaults.
    targeted_layers: slice = field(default_factory=lambda: slice(None))
    steer_away: bool = False
    slc_name: SliceKey = SliceKey.INPUT_LAST_TOKEN
    do_cosine_sim: bool = False
    apply_square: bool = False
    apply_abs: bool = False

    def __call__(
        self,
        full_hidden_states: Float[Tensor, "bsz n_layers seq_len d_model"],
        target_directions: Float[Tensor, "d_model"],
        input_slices: Optional[dict[SliceKey, slice]] = None,
    ) -> Float[Tensor, "bsz"]:
        """
        Compute the steering loss from the similarity between the hidden states and
        the unit-normalized target direction (dot product by default, cosine
        similarity when `do_cosine_sim` is set).

        Args:
            full_hidden_states: Model hidden states from all layers and positions
                (bsz, n_layers, seq_len, d_model)
            target_directions: Direction vector to align with (d_model,)
            input_slices: Position slices reflecting the input tokens (dict mapping
                slice names to slices). If omitted, or if `slc_name` is missing from
                it, all positions are used.

        Returns:
            Loss tensor of shape (bsz,).
        """
        target_directions = target_directions.to(full_hidden_states.device)
        target_directions = target_directions.unsqueeze(0).expand(
            full_hidden_states.shape[0], -1
        )  # (bsz, d_model)

        # Normalize the target direction. The norm is floored at the smallest
        # positive normal number of the dtype so a zero vector gives 0, not NaN.
        direction_norm = target_directions.norm(dim=-1, keepdim=True)
        target_directions = target_directions / direction_norm.clamp_min(
            torch.finfo(target_directions.dtype).tiny
        )

        # Token positions to steer.
        if input_slices is not None:
            slc = input_slices.get(self.slc_name, slice(None))
        else:
            slc = slice(None)  # Use all positions if no slices provided

        # (bsz, n_targeted_layers, slc_seq_len, d_model)
        h = full_hidden_states[:, self.targeted_layers, slc, :]

        if self.do_cosine_sim:
            # Same zero-norm guard as for the target direction.
            h = h / h.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(h.dtype).tiny)

        # (bsz, 1, 1, d_model) -> broadcast dot product -> (bsz, n_targeted_layers, slc_seq_len)
        res = (h * target_directions[:, None, None, :]).sum(dim=-1)

        # Optionally {abs, square, ..} the similarity scores
        if self.apply_abs:
            res = res.abs()
        if self.apply_square:
            res = res.pow(2)

        # Average over layers and positions -> (bsz,)
        loss = res.mean(dim=(-1, -2))

        # Steering towards (default): negate so minimizing the loss maximizes the
        # alignment. Steering away: keep the sign so minimizing the loss minimizes it.
        if not self.steer_away:
            loss = -loss
        return loss
############################
@dataclass
class ClassificationBasedLoss(BaseLoss):
    """Abstract base for losses evaluated on classifier logits (`output_class_logits`)."""

    @abstractmethod
    def __call__(
        self,
        output_class_logits: Float[Tensor, "bsz n_classes"],
        **kwargs,
    ) -> Float[Tensor, "bsz"]:
        """Return one loss value per batch element (implemented by subclasses)."""
        ...
@dataclass
class MisclassCELoss(ClassificationBasedLoss):
    """Cross-entropy loss on classifier logits that drives misclassification.

    Two modes:
    - Untargeted (targeted=False): minimizes probability of `true_class_idx`.
    - Targeted (targeted=True): maximizes probability of `target_class_idx`.

    The class indices are per-template target data; pass them via
    `Targets(true_class_idx=[...])` or `Targets(target_class_idx=[...])` to
    `optimize_trigger`. See `tropt/common.py`.
    """

    targeted: bool = False

    def __call__(
        self,
        output_class_logits: Float[Tensor, "bsz n_classes"],
        target_class_idx: Optional[int] = None,
        true_class_idx: Optional[int] = None,
    ) -> Float[Tensor, "bsz"]:
        """Return the per-sample CE (targeted) or its negation (untargeted)."""
        # Pick the reference class for the active mode, validating it first.
        if self.targeted:
            assert target_class_idx is not None, (
                "MisclassCELoss(targeted=True) requires `target_class_idx` "
                "in Targets / MessageTargets."
            )
            reference_class = target_class_idx
        else:
            assert true_class_idx is not None, (
                "MisclassCELoss(targeted=False) requires `true_class_idx` "
                "in Targets / MessageTargets."
            )
            reference_class = true_class_idx

        # Every batch row is scored against the same class index.
        labels = torch.full(
            (output_class_logits.shape[0],),
            reference_class,
            dtype=torch.long,
            device=output_class_logits.device,
        )
        per_sample_ce = torch.nn.functional.cross_entropy(
            output_class_logits, labels, reduction="none"
        )

        if self.targeted:
            # Minimizing CE w.r.t. the target class maximizes its probability.
            return per_sample_ce
        # Negated CE w.r.t. the true class: minimizing it lowers the true-class probability.
        return -per_sample_ce