Losses

Loss Resolution

The unified entry point for computing any loss: models call this instead of invoking loss functions directly. See Common Types for ModelInput / ModelOutput.

tropt.loss.resolution.resolve_and_compute_loss(model_output, model_input, loss_func)

Universal loss computation via automatic argument matching.

This function calls loss_func with the matching fields of model_output and model_input (the latter includes the targets and slices) and returns the computed loss tensor of shape (bsz,).

Notes

  • Since loss functions in TROPT name their required arguments after the fields of ModelOutput and ModelInput, this function can automatically resolve which data to provide to the loss function by inspecting its __call__ signature.

  • If some required arguments are unavailable (e.g., because the model does not provide the required access), a LossResolutionError is raised with details on what is missing.

  • Parameter Naming Convention:

    Loss functions should name their parameters exactly as they appear in ModelOutput and ModelInput:
      • full_logits: from ModelOutput
      • output_embeddings: from ModelOutput
      • full_attentions: from ModelOutput
      • input_trigger_ids: from ModelInput
      • input_slices: from ModelInput
      • message_targets: from ModelInput
      • etc.

Parameters:
  • model_output (ModelOutput) – Standardized model output containing available data

  • model_input (ModelInput) – Standardized model input containing triggers, slices, message targets

  • loss_func (BaseLoss) – The loss function to compute (must have proper __call__ signature)

Return type:

Float[Tensor, 'bsz']

Returns:

Loss tensor of shape (bsz,) containing per-sample losses

Raises:
  • LossResolutionError – If a required parameter is not found in the model data

  • TypeError – If loss function signature is invalid

Examples

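>>> import torch
>>> from tropt.common import MessageTargets
>>> # Assumption: ModelInput, ModelOutput and SliceKey are importable from tropt.common (see Common Types)
>>> from tropt.common import ModelInput, ModelOutput, SliceKey
>>> from tropt.loss import PrefillCELoss, SimilarityLoss, resolve_and_compute_loss
>>> # Encoder model with SimilarityLoss(output_embeddings, message_targets)
>>> target_vecs = torch.randn(1, 768)  # placeholder target vectors, shape (n_templates, d_model)
>>> output = ModelOutput(output_embeddings=torch.randn(4, 768))
>>> input_data = ModelInput(message_targets=MessageTargets(target_vectors=target_vecs))
>>> loss = resolve_and_compute_loss(output, input_data, SimilarityLoss())
>>> # Language model with PrefillCELoss(prefill_response_logits, message_targets)
>>> target_ids = torch.randint(0, 32000, (2, 10))  # placeholder target token ids; exact format is an assumption
>>> output = ModelOutput(prefill_response_logits=torch.randn(2, 50, 32000))
>>> input_data = ModelInput(
...     input_slices=[{SliceKey.APPENDED: slice(40, 50)}] * 2,
...     message_targets=MessageTargets(target_response_toks=target_ids)
... )
>>> loss = resolve_and_compute_loss(output, input_data, PrefillCELoss())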
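The argument matching described above can be illustrated with the standard-library inspect module. This is a minimal sketch, not TROPT's implementation: FakeModelOutput, FakeModelInput, ToyEmbeddingLoss and resolve_kwargs are hypothetical stand-ins, and treating a None field as "not provided" is an assumption.

```python
# Illustrative sketch; all names are hypothetical, not TROPT's API.
import inspect
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional


class LossResolutionError(Exception):
    """Local stand-in for tropt.loss.resolution.LossResolutionError."""


@dataclass
class FakeModelOutput:  # hypothetical stand-in for ModelOutput
    full_logits: Optional[Any] = None
    output_embeddings: Optional[Any] = None


@dataclass
class FakeModelInput:  # hypothetical stand-in for ModelInput
    input_slices: Optional[Any] = None
    message_targets: Optional[Any] = None


def resolve_kwargs(loss_func: Any, model_output: Any, model_input: Any) -> Dict[str, Any]:
    """Map each parameter of loss_func.__call__ to the same-named, non-None data field."""
    available: Dict[str, Any] = {}
    for source in (model_input, model_output):  # later sources win on name clashes
        for f in fields(source):
            value = getattr(source, f.name)
            if value is not None:  # assumption: None means "not provided by the model"
                available[f.name] = value

    kwargs: Dict[str, Any] = {}
    for name, param in inspect.signature(loss_func.__call__).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs cannot be matched by name
        if name in available:
            kwargs[name] = available[name]
        elif param.default is inspect.Parameter.empty:
            raise LossResolutionError(
                f"Loss function requires parameter '{name}' but it was not "
                "found in model_output or model_input."
            )
    return kwargs


class ToyEmbeddingLoss:  # hypothetical loss: negative dot product per sample
    def __call__(self, output_embeddings: List[List[float]], message_targets: List[float]) -> List[float]:
        return [-sum(e * t for e, t in zip(row, message_targets)) for row in output_embeddings]


loss_func = ToyEmbeddingLoss()
out = FakeModelOutput(output_embeddings=[[1.0, 0.0], [0.5, 0.5]])
inp = FakeModelInput(message_targets=[1.0, 1.0])
print(loss_func(**resolve_kwargs(loss_func, out, inp)))  # [-1.0, -1.0]
```

In this sketch, calling resolve_kwargs(loss_func, FakeModelOutput(), inp) raises LossResolutionError, because output_embeddings has no default and no non-None source.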
class tropt.loss.resolution.LossResolutionError

Bases: Exception

Raised when required data is missing for loss computation.

This exception indicates that a loss function requires specific model outputs or inputs that were not provided. The error message should clearly state what parameter is missing and where it should come from.

Examples

>>> raise LossResolutionError(
...     "Loss function requires parameter 'full_logits' but it was not "
...     "found in model_output or model_input."
... )

Loss Class Interfaces

class tropt.loss.BaseLoss

Bases: ABC

Base class for all loss functions.

contains_loss_type(loss_type)

Returns True if this loss is of the given type. Composite losses (e.g., CombinedLoss) may override this method with different logic.

Return type:

bool

Parameters:

loss_type (type)

get_loss_log_dict()

Returns a loggable dict of the last computed loss value, keyed by loss class name. Useful for verbose loss logging in optimizers.

Return type:

dict

is_differentiable: ClassVar[bool] = True

Whether this loss supports back-propagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.

require_attentions: ClassVar[bool] = False

Whether this loss requires the model to return attention weights.

require_first_token_logprobs: ClassVar[bool] = False

Whether this loss requires first-token log-probabilities from generation.

require_generation: ClassVar[bool] = False

Whether this loss requires autoregressive generation.

require_gradients: ClassVar[bool] = False

Whether the loss-ranking path (otherwise run under torch.no_grad) must keep a live autograd graph for this loss. Set to True by losses whose value is itself a gradient (e.g., gradient matching).

require_hidden_states: ClassVar[bool] = False

Whether this loss requires the model to provide the forward pass’s hidden states.

require_target_prefill: ClassVar[bool] = False

Whether this loss requires the model to prefill the target response tokens (appending them to the input as a response prefix).

class tropt.loss.AttentionBasedLoss

Bases: BaseLoss

Loss computed on model attention weights (full_attentions).

require_attentions: ClassVar[bool] = True

Whether this loss requires the model to return attention weights.

class tropt.loss.EmbeddingBasedLoss

Bases: BaseLoss

Loss computed on model embeddings, compared against the given target vectors.

Requires the target vectors (shape: (n_templates, d_model)) to be provided in the targets dict.

class tropt.loss.TextBasedLoss

Bases: BaseLoss

Marker base for losses that operate on text fields (e.g., input_texts, generated_response_strs).

is_differentiable: ClassVar[bool] = False

Whether this loss supports back-propagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.

class tropt.loss.CombinedLoss(loss_funcs, weights=None)

Bases: BaseLoss

Combines multiple losses with given weights.

Parameters:
  • loss_funcs (List[BaseLoss])

  • weights (Optional[List[float]])

contains_loss_type(loss_type)

Check if the CombinedLoss contains a loss of the specified type.

Return type:

bool

Parameters:

loss_type (type)

get_loss_log_dict()

Returns a loggable dict of the last computed loss values of all the component losses, keyed by loss class name. Useful for verbose loss logging in optimizers.

Return type:

dict

property is_differentiable: bool

Whether the combined loss supports back-propagation; derived from the component losses in loss_funcs.

property require_attentions: bool

Whether the combined loss requires the model to return attention weights; derived from the component losses in loss_funcs.

property require_first_token_logprobs: bool

Whether the combined loss requires first-token log-probabilities from generation; derived from the component losses in loss_funcs.

property require_generation: bool

Whether the combined loss requires autoregressive generation; derived from the component losses in loss_funcs.

property require_gradients: bool

Whether the loss-ranking path must keep a live autograd graph for the combined loss; derived from the component losses in loss_funcs.

property require_hidden_states: bool

Whether the combined loss requires the model to provide the forward pass’s hidden states; derived from the component losses in loss_funcs.

property require_target_prefill: bool

Whether the combined loss requires the model to prefill the target response tokens; derived from the component losses in loss_funcs.
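The CombinedLoss properties are computed from the component losses; the plain-Python sketch below shows one plausible way to do that. The aggregation rules (differentiable only if every component is, a require_* flag set if any component sets it, equal weights when weights is None) are assumptions, and all Toy* classes are hypothetical stand-ins for BaseLoss subclasses.

```python
# Illustrative sketch; Toy* classes are hypothetical, not TROPT's API.
from typing import Dict, List, Optional


class ToyBaseLoss:  # hypothetical stand-in for BaseLoss
    is_differentiable = True
    require_generation = False
    last_value: Optional[float] = None

    def contains_loss_type(self, loss_type: type) -> bool:
        return isinstance(self, loss_type)

    def get_loss_log_dict(self) -> Dict[str, Optional[float]]:
        return {type(self).__name__: self.last_value}


class ToyCELoss(ToyBaseLoss):  # differentiable toy component
    def __call__(self, x: float) -> float:
        self.last_value = 2.0 * x
        return self.last_value


class ToyJudgeLoss(ToyBaseLoss):  # non-differentiable, generation-based toy component
    is_differentiable = False
    require_generation = True

    def __call__(self, x: float) -> float:
        self.last_value = -x
        return self.last_value


class ToyCombinedLoss(ToyBaseLoss):
    def __init__(self, loss_funcs: List[ToyBaseLoss], weights: Optional[List[float]] = None):
        self.loss_funcs = loss_funcs
        # Assumption: weights=None means equal weighting.
        self.weights = weights if weights is not None else [1.0] * len(loss_funcs)

    @property
    def is_differentiable(self) -> bool:  # differentiable only if every component is
        return all(f.is_differentiable for f in self.loss_funcs)

    @property
    def require_generation(self) -> bool:  # required if any component requires it
        return any(f.require_generation for f in self.loss_funcs)

    def contains_loss_type(self, loss_type: type) -> bool:
        return any(f.contains_loss_type(loss_type) for f in self.loss_funcs)

    def get_loss_log_dict(self) -> Dict[str, Optional[float]]:
        log: Dict[str, Optional[float]] = {}
        for f in self.loss_funcs:
            log.update(f.get_loss_log_dict())  # one entry per component class
        return log

    def __call__(self, x: float) -> float:
        return sum(w * f(x) for w, f in zip(self.weights, self.loss_funcs))


combo = ToyCombinedLoss([ToyCELoss(), ToyJudgeLoss()], weights=[1.0, 0.5])
print(combo(3.0))  # 1.0 * 6.0 + 0.5 * (-3.0) = 4.5
print(combo.is_differentiable, combo.require_generation)  # False True
print(combo.contains_loss_type(ToyJudgeLoss))  # True
print(combo.get_loss_log_dict())  # {'ToyCELoss': 6.0, 'ToyJudgeLoss': -3.0}
```

Using all() for differentiability and any() for requirements keeps a combination as strict as its most demanding component.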


Loss Implementations

class tropt.loss.AttentionEnhLoss(targeted_layers=slice(None, None, None), src_slc_name=SliceKey.TRIGGER, dst_slc_name=SliceKey.INPUT_AFTER)

Bases: AttentionBasedLoss

Encourages attention from the trigger tokens to the chat template after the adversarial trigger. The sign of the loss is set so that minimizing the loss maximizes this attention.

Can be configured to instantiate the (different) losses from: https://arxiv.org/abs/2506.12880, https://arxiv.org/abs/2410.09040

Note that it requires loading the model with use_eager_attention=True (for explicit attention computations). Also, some slices are not supported when LM prefix caching is enabled, so the model should be loaded with use_prefix_cache=False.

Parameters:
  • targeted_layers (slice)

  • src_slc_name (SliceKey)

  • dst_slc_name (SliceKey)

dst_slc_name: SliceKey = 'input_after'
src_slc_name: SliceKey = 'trigger'
targeted_layers: slice = slice(None, None, None)
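The sketch below illustrates, in plain Python, the kind of quantity AttentionEnhLoss targets. It assumes attentions laid out as attn[layer][head][query][key], reads "attention from src to dst" as the attention mass that dst (query) positions place on src (key) positions (in a causal model the tokens after the trigger can only attend back to it), and averages over the selected layers, heads and query positions. The layout, that reading, the averaging and the function name are all assumptions.

```python
# Illustrative sketch; attention_enhancement_loss is hypothetical, not TROPT's API.
from typing import List, Sequence

# Toy attention weights indexed as attn[layer][head][query_pos][key_pos] (assumed layout).
Attention = Sequence[Sequence[Sequence[Sequence[float]]]]


def attention_enhancement_loss(attn: Attention, targeted_layers: slice, src: slice, dst: slice) -> float:
    """Negative mean attention mass that dst (query) positions place on src (key) positions."""
    masses: List[float] = []
    for layer in attn[targeted_layers]:
        for head in layer:
            for query_row in head[dst]:
                masses.append(sum(query_row[src]))  # attention mass on the src span
    return -sum(masses) / len(masses)  # negated: minimizing the loss maximizes attention


# One layer, one head, causal 4-token sequence; positions 1-2 play the trigger, 3 the text after it.
attn = [[[
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.2, 0.3, 0.5, 0.0],
    [0.1, 0.2, 0.3, 0.4],
]]]
print(attention_enhancement_loss(attn, slice(None), src=slice(1, 3), dst=slice(3, 4)))
```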
class tropt.loss.BinaryLMJudgeLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')

Bases: TextBasedLoss

Abstract base for Yes/No LLM judge losses.

Subclasses implement _create_prompt and __call__. Implementations of __call__ should use _compute_scores for batched scoring and return -scores, so that minimizing the loss maximizes the judge’s YES score.

Parameters:
  • positive_words (Set[str])

  • negative_words (Set[str])

  • model_name_or_path (str)

  • judge_lm_batch_size (int)

  • device (str)

device: str = 'cpu'
judge_lm_batch_size: int = 256
model_name_or_path: str = 'HuggingFaceTB/SmolLM2-135M-Instruct'
negative_words: Set[str]
positive_words: Set[str]
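One common way to turn a Yes/No judge into a score is to compare the probability mass the judge LM puts on positive versus negative answer words as its next token. This sketch assumes those next-token log-probabilities are already available as a dict; the normalization P(pos) / (P(pos) + P(neg)) and the judge_loss helper are assumptions and need not match _compute_scores. Returning the negated score follows the convention described above.

```python
# Illustrative sketch; judge_loss is hypothetical, not TROPT's _compute_scores.
import math
from typing import Dict, Set


def judge_loss(next_token_logprobs: Dict[str, float], positive_words: Set[str],
               negative_words: Set[str]) -> float:
    """Return -P(YES), with P(YES) the positive share of the yes/no probability mass."""
    p_pos = sum(math.exp(lp) for tok, lp in next_token_logprobs.items() if tok in positive_words)
    p_neg = sum(math.exp(lp) for tok, lp in next_token_logprobs.items() if tok in negative_words)
    if p_pos + p_neg == 0.0:
        return 0.0  # assumption: no answer word observed -> score 0, the worst loss
    return -p_pos / (p_pos + p_neg)  # minimizing the loss maximizes YES


logprobs = {"Yes": math.log(0.6), "No": math.log(0.2), "Maybe": math.log(0.2)}
print(judge_loss(logprobs, positive_words={"Yes", "yes"}, negative_words={"No", "no"}))  # about -0.75
```

Renormalizing over the answer words only makes the score insensitive to probability mass the judge spends on unrelated tokens such as "Maybe".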
class tropt.loss.ClassificationBasedLoss

Bases: BaseLoss

Loss computed on classifier logits (output_class_logits).

class tropt.loss.ExternalTriggerPerplexityLoss(naturalness_prefix='Here is a readable sentence: ', model_name_or_path='google/gemma-2-2b', device='cpu', max_batch_size=256)

Bases: BaseLoss

Perplexity of the trigger under an external LM.

Note: scores the whole sequence.

Parameters:
  • naturalness_prefix (str)

  • model_name_or_path (str)

  • device (str)

  • max_batch_size (int)

device: str = 'cpu'
is_differentiable: ClassVar[bool] = False

Whether this loss supports back-propagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.

max_batch_size: int = 256
model_name_or_path: str = 'google/gemma-2-2b'
naturalness_prefix: str = 'Here is a readable sentence: '
class tropt.loss.FirstTokenNLLLoss(target_token='Sure', missing_logprob_value=-inf)

Bases: TextBasedLoss

Negative log-likelihood of target_token being the model’s first generated token.

From Andriushchenko et al., “Jailbreaking Leading Safety-Aligned LLMs with Simple Adaptive Attacks” (2024). The model is queried with max_tokens=1 and top_logprobs enabled. The loss is the negative log-probability of target_token among the returned logprobs. If the target token is not in the top-k, its logprob is treated as missing_logprob_value (default -inf), making the loss +inf (the worst possible value), matching the paper.

To account for tokenizer quirks (leading space), the lookup tries both target_token and " " + target_token and uses whichever has the higher logprob.

Parameters:
  • target_token (str)

  • missing_logprob_value (float)

is_differentiable: ClassVar[bool] = False

Whether this loss supports back-propagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.

missing_logprob_value: float = -inf

Logprob value substituted when the target token is absent from the top-k logprobs. It is negated into a loss in __call__, so the default -inf yields a +inf loss.

require_first_token_logprobs: ClassVar[bool] = True

Whether this loss requires first-token log-probabilities from generation.

target_token: str = 'Sure'

Target first token whose probability is maximized.
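The lookup described above fits in a few lines of plain Python. The sketch assumes the generation backend returns the first token's top-k log-probabilities as a token-to-logprob dict; that format and the first_token_nll name are assumptions.

```python
# Illustrative sketch; first_token_nll is hypothetical, not TROPT's API.
import math
from typing import Dict


def first_token_nll(top_logprobs: Dict[str, float], target_token: str = "Sure",
                    missing_logprob_value: float = -math.inf) -> float:
    """NLL of target_token given the first generated token's top-k logprobs."""
    # Try the bare token and its leading-space variant; keep the more likely one.
    found = [top_logprobs[t] for t in (target_token, " " + target_token) if t in top_logprobs]
    logprob = max(found) if found else missing_logprob_value
    return -logprob


print(first_token_nll({" Sure": -0.1, "Sure": -3.0, "I": -2.5}))  # 0.1
print(first_token_nll({"I": -0.2}))  # inf
```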

class tropt.loss.GeneratedResponseBasedLoss

Bases: TextBasedLoss

Marker base for losses that operate on generated_response_strs.

require_generation: ClassVar[bool] = True

Whether this loss requires autoregressive generation.

class tropt.loss.HiddenStateBasedLoss

Bases: BaseLoss

Loss computed on model hidden states (full_hidden_states).

require_hidden_states: ClassVar[bool] = True

Whether this loss requires the model to provide the forward pass’s hidden states.

class tropt.loss.InputFluencyLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')

Bases: BinaryLMJudgeLoss

Loss that encourages readable/fluent trigger text (operates on the whole trigger-combined prompt). Minimizing this loss maximizes readability.

https://arxiv.org/abs/2410.02163

Parameters:
  • positive_words (Set[str])

  • negative_words (Set[str])

  • model_name_or_path (str)

  • judge_lm_batch_size (int)

  • device (str)

exception tropt.loss.LossResolutionError

Same exception as tropt.loss.resolution.LossResolutionError, exposed at the package level; see Loss Resolution.

class tropt.loss.MisclassCELoss(targeted=False)

Bases: ClassificationBasedLoss

Encourages misclassification via cross-entropy on classifier logits.

Two modes:
  • Untargeted (targeted=False): minimizes the probability of true_class_idx.

  • Targeted (targeted=True): maximizes the probability of target_class_idx.

The class indices are per-template target data; pass them via Targets(true_class_idx=[…]) or Targets(target_class_idx=[…]) to optimize_trigger. See tropt/common.py.

Parameters:

targeted (bool)

targeted: bool = False
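A per-sample sketch of the two modes, assuming the standard cross-entropy formulation: the untargeted loss is the negated cross-entropy of true_class_idx (minimizing it pushes that class's probability down) and the targeted loss is the cross-entropy of target_class_idx. The exact sign convention and batch reduction in TROPT are assumptions here, as are the helper names.

```python
# Illustrative sketch; log_softmax and misclass_ce_loss are hypothetical helpers.
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def misclass_ce_loss(logits: List[float], class_idx: int, targeted: bool) -> float:
    """Per-sample loss; class_idx is target_class_idx if targeted, else true_class_idx."""
    ce = -log_softmax(logits)[class_idx]  # cross-entropy w.r.t. class_idx
    # Targeted: minimize the target class's CE. Untargeted: maximize the true class's CE.
    return ce if targeted else -ce


logits = [2.0, 0.0]
print(misclass_ce_loss(logits, class_idx=0, targeted=False))  # push p(class 0) down
print(misclass_ce_loss(logits, class_idx=1, targeted=True))   # push p(class 1) up
```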
class tropt.loss.PrefillBasedLoss

Bases: BaseLoss

Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks), commonly derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.

require_target_prefill: ClassVar[bool] = True

Whether this loss requires the model to prefill the target response tokens (appending them to the input as a response prefix).

class tropt.loss.PrefillCELoss(temperature=1.0, clamp_min_nll=None)

Bases: PrefillBasedLoss

Encourages the model to produce the target output (typically an affirmative response) by maximizing its likelihood.

Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks), automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.

Parameters:
  • temperature (float)

  • clamp_min_nll (float | None)

clamp_min_nll: Optional[float] = None

Floor on the per-token NLL before averaging; tokens already below the floor contribute zero gradient, freeing the optimizer to focus on “unsolved” positions. Defaults to None (no clamping); otherwise the NLL is clamped at the given value.

FLRT (https://arxiv.org/abs/2407.17447, Eq. 5) uses -log(0.6) ≈ 0.511.

temperature: float = 1.0

Temperature applied to the prefill logits before softmax.
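A plain-Python sketch of the per-sample computation, assuming the prefill logits are already aligned so that row i scores target token i, and that the per-token NLLs are averaged after clamping. The function name and the list-of-lists inputs are hypothetical.

```python
# Illustrative sketch; prefill_ce_loss is hypothetical, not TROPT's API.
import math
from typing import List, Optional


def prefill_ce_loss(logits: List[List[float]], target_ids: List[int],
                    temperature: float = 1.0, clamp_min_nll: Optional[float] = None) -> float:
    """Mean per-token NLL of target_ids, one logit row per target token."""
    nlls = []
    for row, tok in zip(logits, target_ids):
        scaled = [z / temperature for z in row]
        m = max(scaled)
        log_z = m + math.log(sum(math.exp(z - m) for z in scaled))
        nll = log_z - scaled[tok]  # -log softmax(scaled)[tok]
        if clamp_min_nll is not None:
            nll = max(nll, clamp_min_nll)  # below the floor: constant, so zero gradient
        nlls.append(nll)
    return sum(nlls) / len(nlls)


logits = [[0.0, 0.0], [10.0, 0.0]]  # row 0 undecided, row 1 already confident in token 0
print(prefill_ce_loss(logits, [0, 0]))
print(prefill_ce_loss(logits, [0, 0], clamp_min_nll=-math.log(0.6)))  # FLRT-style floor
```

Because max(nll, floor) is constant in nll once nll drops below the floor, already-solved positions stop contributing gradient, which is the behavior described for clamp_min_nll.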

class tropt.loss.PrefillCWLoss(cw_margin=5.0, first_token_weight=1.0)

Bases: PrefillBasedLoss

Encourages the model to produce the target output (typically an affirmative response) via a CW-inspired hinge loss on the difference between the largest non-target logit and the target logit. https://arxiv.org/abs/2402.09674

Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks), automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.

Parameters:
  • cw_margin (float)

  • first_token_weight (float)

cw_margin: float = 5.0
first_token_weight: float = 1.0
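A plain-Python sketch of a CW-style hinge on aligned prefill logits: per position the loss is max(max_{j != t} z_j - z_t, -cw_margin), so it stops decreasing once the target logit leads every other logit by cw_margin. Treating first_token_weight as the weight of position 0 in a weighted mean is an assumption, as is the function name.

```python
# Illustrative sketch; prefill_cw_loss is hypothetical, not TROPT's API.
from typing import List


def prefill_cw_loss(logits: List[List[float]], target_ids: List[int],
                    cw_margin: float = 5.0, first_token_weight: float = 1.0) -> float:
    losses: List[float] = []
    weights: List[float] = []
    for pos, (row, tok) in enumerate(zip(logits, target_ids)):
        best_other = max(z for j, z in enumerate(row) if j != tok)  # largest non-target logit
        losses.append(max(best_other - row[tok], -cw_margin))  # hinge saturates at -cw_margin
        weights.append(first_token_weight if pos == 0 else 1.0)  # assumed weighting scheme
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


logits = [[1.0, 3.0, 0.0], [9.0, 0.0, 0.0]]  # target 0 loses at position 0, wins big at position 1
print(prefill_cw_loss(logits, [0, 0]))  # (2.0 + (-5.0)) / 2 = -1.5
print(prefill_cw_loss(logits, [0, 0], first_token_weight=2.0))
```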
class tropt.loss.PrefillDistillationLoss(temperature=1.0, reference_temperature=1.0, clamp_min_nll=None)

Bases: PrefillBasedLoss

Encourages the model’s prefill probabilities at the target positions to match those given by reference logits. Concretely, it returns the cross-entropy between the victim model’s probabilities and a softmax over pre-computed reference logits.

Inspired by FLRT’s logit-based distillation (https://arxiv.org/abs/2407.17447), where the reference logits come from a jailbroken copy of the victim model.

Loss computed on prefilled response logits (prefill_response_logits). Requires target logits (target_response_logits), commonly derived from a reference model’s output on the same input.

Parameters:
  • temperature (float)

  • reference_temperature (float)

  • clamp_min_nll (float | None)

clamp_min_nll: Optional[float] = None

Floor on the cross-entropy result, to stop optimizing well-matched tokens. Defaults to None (disabled); a common value is -log(0.6) ≈ 0.51.

reference_temperature: float = 1.0

Temperature applied to the reference logits before softmax (teacher sharpening).

temperature: float = 1.0
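A plain-Python sketch of the distillation objective, assuming aligned rows of victim (student) and reference (teacher) logits, the per-position cross-entropy H(softmax(ref / reference_temperature), softmax(logits / temperature)), clamping as for PrefillCELoss, and a mean over positions. The reduction and the function names are assumptions.

```python
# Illustrative sketch; function names are hypothetical, not TROPT's API.
import math
from typing import List, Optional


def _log_softmax(row: List[float], temperature: float) -> List[float]:
    scaled = [z / temperature for z in row]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_z for z in scaled]


def prefill_distillation_loss(logits: List[List[float]], reference_logits: List[List[float]],
                              temperature: float = 1.0, reference_temperature: float = 1.0,
                              clamp_min_nll: Optional[float] = None) -> float:
    per_pos = []
    for row, ref_row in zip(logits, reference_logits):
        student = _log_softmax(row, temperature)  # victim log-probabilities
        teacher = [math.exp(lp) for lp in _log_softmax(ref_row, reference_temperature)]
        ce = -sum(p * lp for p, lp in zip(teacher, student))  # H(teacher, student)
        if clamp_min_nll is not None:
            ce = max(ce, clamp_min_nll)  # stop optimizing well-matched positions
        per_pos.append(ce)
    return sum(per_pos) / len(per_pos)


print(prefill_distillation_loss([[5.0, 0.0]], [[5.0, 0.0]]))  # matched: equals teacher entropy
print(prefill_distillation_loss([[0.0, 5.0]], [[5.0, 0.0]]))  # mismatched: much larger
```

A reference_temperature below 1 sharpens the teacher toward its top token, which is the "teacher sharpening" described above.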
class tropt.loss.PrefillMellowMaxLoss(mellowmax_alpha=1.0, temperature=1.0)

Bases: PrefillBasedLoss

Encourages the model to produce the target output by maximizing the mellowmax of the target logits. https://arxiv.org/pdf/1612.05628, http://confirmlabs.org/posts/TDC2023

Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks); automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.

Parameters:
  • mellowmax_alpha (float)

  • temperature (float)

mellowmax_alpha: float = 1.0
temperature: float = 1.0
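The mellowmax operator of Asadi and Littman (arXiv 1612.05628) is mm_alpha(x) = log(mean(exp(alpha * x))) / alpha, a smooth interpolation between the mean (alpha -> 0) and the max (alpha -> infinity). The sketch below applies it to the negated, temperature-scaled logits of the target tokens, a formulation used in some open-source GCG implementations, so minimizing the loss raises the target logits with most weight on the lowest ones. That sign convention, the temperature scaling and the function names are assumptions about how this class combines its parameters.

```python
# Illustrative sketch; function names are hypothetical, not TROPT's API.
import math
from typing import List


def mellowmax(values: List[float], alpha: float = 1.0) -> float:
    """mm_alpha(x) = log(mean(exp(alpha * x))) / alpha, computed stably."""
    scaled = [alpha * v for v in values]
    m = max(scaled)
    log_mean_exp = m + math.log(sum(math.exp(s - m) for s in scaled)) - math.log(len(scaled))
    return log_mean_exp / alpha


def prefill_mellowmax_loss(target_logits: List[float], mellowmax_alpha: float = 1.0,
                           temperature: float = 1.0) -> float:
    # target_logits[i]: logit of target token i at its own position.
    # Assumed sign convention: mellowmax of the negated logits (a smooth max of -z).
    return mellowmax([-z / temperature for z in target_logits], mellowmax_alpha)


print(prefill_mellowmax_loss([5.0, 5.0]))  # about -5.0: uniformly high target logits
print(prefill_mellowmax_loss([5.0, 0.0]))  # larger: dominated by the weak second token
```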
class tropt.loss.ResponseHarmfulnessLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')

Bases: BinaryLMJudgeLoss, GeneratedResponseBasedLoss

Loss that encourages harmful model responses (operates on the generated response strings).

Parameters:
  • positive_words (Set[str])

  • negative_words (Set[str])

  • model_name_or_path (str)

  • judge_lm_batch_size (int)

  • device (str)

class tropt.loss.SimilarityLoss

Bases: EmbeddingBasedLoss

Encourages the given representation(s) to align, in cosine similarity, with the given target vectors.
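A plain-Python sketch of cosine-similarity alignment, assuming each output embedding is paired with one target vector and the per-sample loss is the negated cosine similarity (so minimizing it maximizes alignment). Negation rather than 1 - cos, the pairing, the small eps guard against zero-norm vectors, and the function names are assumptions.

```python
# Illustrative sketch; function names are hypothetical, not TROPT's API.
import math
from typing import List


def cosine_similarity(a: List[float], b: List[float], eps: float = 1e-12) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norms, eps)  # eps avoids division by zero for zero vectors


def similarity_loss(embeddings: List[List[float]], target_vectors: List[List[float]]) -> List[float]:
    """Per-sample negative cosine similarity; sample i is paired with target_vectors[i]."""
    return [-cosine_similarity(e, t) for e, t in zip(embeddings, target_vectors)]


print(similarity_loss([[1.0, 0.0], [0.0, 2.0]], [[2.0, 0.0], [1.0, 0.0]]))  # aligned, then orthogonal
```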

class tropt.loss.SteeringActivationLoss(targeted_layers=slice(None, None, None), steer_away=False, slc_name=SliceKey.INPUT_LAST_TOKEN, do_cosine_sim=False, apply_square=False, apply_abs=False)

Bases: HiddenStateBasedLoss

Encourages hidden activations at specific layers/positions to align with a target direction.

  • Each message has a target direction vector (optionally its own unique one).

  • target_directions: (n_templates, d_model)

  • Note that the direction is applied to all targeted positions and layers.

  • The default is steering towards a direction (maximizing alignment).
    • Here, minimizing the loss maximizes alignment (dot product) with the target direction.

    • Set steer_away=True to steer away (e.g., for refusal suppression).

References:
  • Was proposed as ‘refusal direction suppression’ combined with GCG:

Parameters:
  • targeted_layers (slice) – Which layers to apply steering on (default: all layers)

  • steer_away (bool) – Whether to minimize alignment instead of maximizing it (default: False = steer towards)

  • slc_name (SliceKey) – Which token positions to apply steering on (default: “input_last_token”)

  • do_cosine_sim (bool) – Whether to use cosine similarity instead of dot product (default: False)

  • apply_square (bool) – Whether to square the similarity scores (default: False)

  • apply_abs (bool) – Whether to take the absolute value of the similarity scores (default: False)

apply_abs: bool = False
apply_square: bool = False
do_cosine_sim: bool = False
slc_name: SliceKey = 'input_last_token'
steer_away: bool = False
targeted_layers: slice = slice(None, None, None)
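A plain-Python sketch of the scoring described above for one sample, assuming hidden states laid out as hidden[layer][position][d_model] and a mean over the selected layers and positions. The positions argument is a hypothetical stand-in for slc_name (its default, the last position, mimics input_last_token for a prompt-only sequence), and applying apply_abs before apply_square is an assumption.

```python
# Illustrative sketch; steering_activation_loss is hypothetical, not TROPT's API.
import math
from typing import List, Sequence


def steering_activation_loss(
    hidden: Sequence[Sequence[Sequence[float]]],  # hidden[layer][position][d_model], one sample
    direction: Sequence[float],
    targeted_layers: slice = slice(None),
    positions: slice = slice(-1, None),  # hypothetical stand-in for slc_name
    steer_away: bool = False,
    do_cosine_sim: bool = False,
    apply_square: bool = False,
    apply_abs: bool = False,
    eps: float = 1e-12,
) -> float:
    d_norm = math.sqrt(sum(v * v for v in direction))
    scores: List[float] = []
    for layer in hidden[targeted_layers]:
        for h in layer[positions]:
            s = sum(a * b for a, b in zip(h, direction))  # dot product with the direction
            if do_cosine_sim:
                s /= max(math.sqrt(sum(a * a for a in h)) * d_norm, eps)
            if apply_abs:
                s = abs(s)
            if apply_square:
                s = s * s
            scores.append(s)
    mean_score = sum(scores) / len(scores)
    # Default: minimizing the loss maximizes alignment; steer_away flips the sign.
    return mean_score if steer_away else -mean_score


hidden = [[[1.0, 0.0], [2.0, 0.0]], [[0.0, 1.0], [3.0, 4.0]]]  # 2 layers, 2 positions, d_model=2
direction = [1.0, 0.0]
print(steering_activation_loss(hidden, direction))  # -(2.0 + 3.0) / 2 = -2.5
print(steering_activation_loss(hidden, direction, steer_away=True))  # 2.5
```

With steer_away=True and apply_square=True, minimizing the loss drives the activations toward being orthogonal to the direction, one natural reading of refusal suppression.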
class tropt.loss.TriggerLogitBasedLoss

Bases: BaseLoss

Loss computed on full-sequence logits (full_logits) sliced to trigger positions. Useful for optimizing properties of the triggers directly.

class tropt.loss.TriggerPerplexityLoss(temperature=1.0, slc_name=SliceKey.TRIGGER)

Bases: TriggerLogitBasedLoss

Computes the perplexity of the trigger with respect to the target model’s own logits. Useful for penalizing non-fluent triggers.

Parameters:
  • temperature (float)

  • slc_name (SliceKey)

slc_name: SliceKey = 'trigger'
temperature: float = 1.0
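A plain-Python sketch of trigger perplexity from full-sequence logits. The key detail is the one-position shift of causal LMs: the logits at position p - 1 predict the token at position p, so a trigger slice [s, e) is scored with logit rows s - 1 through e - 2. Returning exp(mean NLL) rather than the mean NLL itself, and the function name, are assumptions.

```python
# Illustrative sketch; trigger_perplexity is hypothetical, not TROPT's API.
import math
from typing import List


def trigger_perplexity(full_logits: List[List[float]], input_ids: List[int], trigger: slice,
                       temperature: float = 1.0) -> float:
    """exp(mean NLL) of the trigger tokens under the model's own next-token logits."""
    if trigger.start is None or trigger.start < 1 or trigger.stop is None:
        raise ValueError("trigger needs explicit bounds and at least one token before it")
    nlls = []
    for pos in range(trigger.start, trigger.stop):
        row = [z / temperature for z in full_logits[pos - 1]]  # logits at pos - 1 predict token at pos
        m = max(row)
        log_z = m + math.log(sum(math.exp(z - m) for z in row))
        nlls.append(log_z - row[input_ids[pos]])
    return math.exp(sum(nlls) / len(nlls))


uniform = [[0.0] * 4 for _ in range(3)]  # 3 positions, vocabulary of 4, no preference
print(trigger_perplexity(uniform, input_ids=[0, 1, 2], trigger=slice(1, 3)))  # ≈ 4.0
```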
tropt.loss.resolve_and_compute_loss(model_output, model_input, loss_func)

Same function as tropt.loss.resolution.resolve_and_compute_loss, exposed at the package level; see Loss Resolution.