Losses
Loss Resolution
The unified entry point for computing any loss — models call this instead of invoking loss functions directly. See Common Types for ModelInput / ModelOutput.
- tropt.loss.resolution.resolve_and_compute_loss(model_output, model_input, loss_func)
Universal loss computation via automatic argument matching.
This function invokes loss_func with the model input (which includes the targets and slices), and output; it returns the computed loss tensor (size bsz).
Notes
Since loss functions in TROPT declare their required arguments following the naming
convention of ModelOutput and ModelInput fields, this function can automatically resolve which data to provide to the loss function by inspecting its __call__ signature.
In case insufficient arguments are available (e.g., because the model does not provide the required access),
a LossResolutionError is raised with details on what is missing.
- Parameter Naming Convention:
Loss functions should name their parameters exactly as they appear in ModelOutput and ModelInput:
- full_logits: From ModelOutput
- output_embeddings: From ModelOutput
- full_attentions: From ModelOutput
- input_trigger_ids: From ModelInput
- input_slices: From ModelInput
- message_targets: From ModelInput
- etc.
- Parameters:
model_output (ModelOutput) – Standardized model output containing available data
model_input (ModelInput) – Standardized model input containing triggers, slices, message targets
loss_func (BaseLoss) – The loss function to compute (must have proper __call__ signature)
- Return type:
Float[Tensor, 'bsz']
- Returns:
Loss tensor of shape (bsz,) containing per-sample losses
- Raises:
LossResolutionError – If required parameter is not found in model data
TypeError – If loss function signature is invalid
Examples
>>> from tropt.common import MessageTargets
>>> # Encoder model with SimilarityLoss(output_embeddings, target_embeddings)
>>> output = ModelOutput(output_embeddings=torch.randn(4, 768))
>>> input_data = ModelInput(message_targets=MessageTargets(target_vectors=target_vecs))
>>> loss = resolve_and_compute_loss(output, input_data, SimilarityLoss())

>>> # Language model with PrefillCELoss(prefill_response_logits, message_targets)
>>> output = ModelOutput(prefill_response_logits=torch.randn(2, 50, 32000))
>>> input_data = ModelInput(
...     input_slices=[{SliceKey.APPENDED: slice(40, 50)}] * 2,
...     message_targets=MessageTargets(target_response_toks=target_ids)
... )
>>> loss = resolve_and_compute_loss(output, input_data, PrefillCELoss())
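The signature-based matching described in the Notes can be sketched in plain Python. This is a minimal illustration under stated assumptions, not TROPT's implementation: ToyOutput, ToyInput, MissingLossInput, resolve_by_signature and toy_loss are hypothetical stand-ins for ModelOutput, ModelInput, LossResolutionError, resolve_and_compute_loss and a real loss, and the rule that a None field counts as unavailable is also an assumption.

```python
import inspect
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, Optional


class MissingLossInput(Exception):
    """Hypothetical stand-in for LossResolutionError."""


@dataclass
class ToyOutput:
    """Hypothetical stand-in for ModelOutput (only two fields)."""
    full_logits: Optional[Any] = None
    output_embeddings: Optional[Any] = None


@dataclass
class ToyInput:
    """Hypothetical stand-in for ModelInput (only two fields)."""
    input_slices: Optional[Any] = None
    message_targets: Optional[Any] = None


def resolve_by_signature(model_output: ToyOutput, model_input: ToyInput,
                         loss_func: Callable) -> Any:
    """Call loss_func with exactly the fields its signature names."""
    # Collect every field that actually carries data (assumption: None = absent).
    available: Dict[str, Any] = {}
    for container in (model_input, model_output):
        for field in fields(container):
            value = getattr(container, field.name)
            if value is not None:
                available[field.name] = value

    # Match parameter names against the available field names.
    kwargs: Dict[str, Any] = {}
    for name, param in inspect.signature(loss_func).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # *args / **kwargs carry no field name to match
        if name in available:
            kwargs[name] = available[name]
        elif param.default is inspect.Parameter.empty:
            raise MissingLossInput(
                f"Loss function requires parameter {name!r} but it was not "
                "found in model_output or model_input."
            )
    return loss_func(**kwargs)


def toy_loss(output_embeddings, message_targets):
    """Per-sample absolute difference, standing in for a real loss."""
    return [abs(e - t) for e, t in zip(output_embeddings, message_targets)]
```

Because inspect.signature also works on callable instances (it reads their __call__ and drops self), the same matching would apply to loss objects rather than plain functions.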
- class tropt.loss.resolution.LossResolutionError
Bases: Exception
Raised when required data is missing for loss computation.
This exception indicates that a loss function requires specific model outputs or inputs that were not provided. The error message should clearly state what parameter is missing and where it should come from.
Examples
>>> raise LossResolutionError(
...     "Loss function requires parameter 'full_logits' but it was not "
...     "found in model_output or model_input."
... )
Loss Class Interfaces
- class tropt.loss.BaseLoss
Bases: ABC
Base class for all loss functions.
- contains_loss_type(loss_type)
Returns True if this loss is of the given type. Composite losses (e.g., CombinedLoss) may override this method with different logic.
- Return type:
bool
- Parameters:
loss_type (type)
- get_loss_log_dict()
Returns a loggable dict of the last computed loss value, keyed by loss class name. Useful for verbose loss logging in optimizers.
- Return type:
dict
- is_differentiable: ClassVar[bool] = True
Whether this loss supports backpropagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.
- require_attentions: ClassVar[bool] = False
Whether this loss requires the model to return attention weights.
- require_first_token_logprobs: ClassVar[bool] = False
Whether this loss requires first-token log-probabilities from generation.
- require_generation: ClassVar[bool] = False
Whether this loss requires autoregressive generation.
- require_gradients: ClassVar[bool] = False
Whether the loss-ranking path (otherwise run under torch.no_grad) must keep a live autograd graph for this loss. Set True by losses whose value is itself a gradient (e.g. gradient matching).
Whether this loss requires the model to provide the forward pass’s hidden states.
- require_target_prefill: ClassVar[bool] = False
Whether this loss requires the model to prefill the target response tokens (appending them to the input, as a response prefix).
- class tropt.loss.AttentionBasedLoss
Bases: BaseLoss
Loss computed on model attention weights (full_attentions).
- require_attentions: ClassVar[bool] = True
Whether this loss requires the model to return attention weights.
- class tropt.loss.EmbeddingBasedLoss
Bases: BaseLoss
Loss computed on model embeddings, compared against given target vectors.
Requires the target vectors (shape: (n_templates, d_model)) to be provided in the targets dict.
- class tropt.loss.TextBasedLoss
Bases: BaseLoss
Marker base for losses that operate on text fields (e.g. input_texts, generated_response_strs).
- is_differentiable: ClassVar[bool] = False
Whether this loss supports backpropagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.
- class tropt.loss.CombinedLoss(loss_funcs, weights=None)
Bases: BaseLoss
Combines multiple losses with given weights.
- Parameters:
loss_funcs (List[BaseLoss])
weights (Optional[List[float]])
- contains_loss_type(loss_type)
Check if the CombinedLoss contains a loss of the specified type.
- Return type:
bool
- Parameters:
loss_type (type)
- get_loss_log_dict()
Returns a loggable dict of the last computed loss value (of all the component losses), keyed by loss class name. Useful for verbose loss logging in optimizers.
- Return type:
dict
- property is_differentiable: bool
- property require_attentions: bool
- property require_first_token_logprobs: bool
- property require_generation: bool
- property require_gradients: bool
- property require_target_prefill: bool
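One way a combined loss could derive its flags and value from its components is sketched below. Everything here is an assumption made for illustration: the classes ToyLoss, SquareLoss, AbsJudgeLoss and ToyCombinedLoss are hypothetical, the combination is taken to be a per-sample weighted sum, and the flags are aggregated with all() for is_differentiable and any() for the require_* style flag. The source does not state how CombinedLoss actually does this.

```python
from typing import List, Optional, Sequence


class ToyLoss:
    """Hypothetical stand-in for BaseLoss with two of its class-level flags."""
    is_differentiable = True
    require_generation = False

    def __call__(self, values: Sequence[float]) -> List[float]:
        raise NotImplementedError


class SquareLoss(ToyLoss):
    def __call__(self, values):
        return [v * v for v in values]


class AbsJudgeLoss(ToyLoss):
    # Mimics a text-based loss: not differentiable, needs generation.
    is_differentiable = False
    require_generation = True

    def __call__(self, values):
        return [abs(v) for v in values]


class ToyCombinedLoss(ToyLoss):
    def __init__(self, loss_funcs: Sequence[ToyLoss],
                 weights: Optional[Sequence[float]] = None):
        self.loss_funcs = list(loss_funcs)
        # Assumption: missing weights default to 1.0 for every component.
        self.weights = list(weights) if weights is not None else [1.0] * len(self.loss_funcs)
        if len(self.weights) != len(self.loss_funcs):
            raise ValueError("need exactly one weight per loss")

    @property
    def is_differentiable(self) -> bool:
        # Differentiable only if every component is.
        return all(f.is_differentiable for f in self.loss_funcs)

    @property
    def require_generation(self) -> bool:
        # Required as soon as one component requires it.
        return any(f.require_generation for f in self.loss_funcs)

    def contains_loss_type(self, loss_type: type) -> bool:
        return any(isinstance(f, loss_type) for f in self.loss_funcs)

    def __call__(self, values):
        per_loss = [f(values) for f in self.loss_funcs]
        # Per-sample weighted sum over the component losses.
        return [
            sum(w * column[i] for w, column in zip(self.weights, per_loss))
            for i in range(len(values))
        ]
```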
Loss Implementations
- class tropt.loss.AttentionEnhLoss(targeted_layers=slice(None, None, None), src_slc_name=SliceKey.TRIGGER, dst_slc_name=SliceKey.INPUT_AFTER)
Bases: AttentionBasedLoss
Encourages attention from the trigger tokens to the chat template after the adversarial trigger. Note: the sign of the loss is set such that minimizing the loss maximizes the attention.
Can be configured to instantiate the (different) losses from: https://arxiv.org/abs/2506.12880, https://arxiv.org/abs/2410.09040
Note that it requires setting use_eager_attention=True when loading the model (for explicit attention computations). Also, some slices are not supported when LM prefix caching is enabled, so use_prefix_cache=False should be set when loading the model.
- targeted_layers: slice = slice(None, None, None)
- class tropt.loss.BinaryLMJudgeLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')
Bases: TextBasedLoss
Abstract base for Yes/No LLM judge losses.
Subclasses implement _create_prompt and __call__. Implementations of __call__ should use _compute_scores for batched scoring and return -scores, so that minimizing the loss maximizes the YES score.
- Parameters:
positive_words (Set[str])
negative_words (Set[str])
model_name_or_path (str)
judge_lm_batch_size (int)
device (str)
- device: str = 'cpu'
- judge_lm_batch_size: int = 256
- model_name_or_path: str = 'HuggingFaceTB/SmolLM2-135M-Instruct'
- negative_words: Set[str]
- positive_words: Set[str]
- class tropt.loss.ClassificationBasedLoss
Bases: BaseLoss
Loss computed on classifier logits (output_class_logits).
- class tropt.loss.ExternalTriggerPerplexityLoss(naturalness_prefix='Here is a readable sentence: ', model_name_or_path='google/gemma-2-2b', device='cpu', max_batch_size=256)
Bases: BaseLoss
Perplexity of the trigger under an external LM.
Notes:
- Scores the whole sequence.
- Parameters:
naturalness_prefix (str)
model_name_or_path (str)
device (str)
max_batch_size (int)
- device: str = 'cpu'
- is_differentiable: ClassVar[bool] = False
Whether this loss supports backpropagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.
- max_batch_size: int = 256
- model_name_or_path: str = 'google/gemma-2-2b'
- naturalness_prefix: str = 'Here is a readable sentence: '
- class tropt.loss.FirstTokenNLLLoss(target_token='Sure', missing_logprob_value=-inf)
Bases: TextBasedLoss
Negative log-likelihood of a target token in the model’s first generated token.
From Andriushchenko et al., “Jailbreaking Leading Safety-Aligned LLMs with Simple Adaptive Attacks” (2024). The model is queried with max_tokens=1 and top_logprobs enabled. The loss is the negative log-probability of target_token among the returned logprobs. If the target token is not in the top-k, its logprob is treated as missing_logprob_value (default -inf), making the loss +inf (worst possible), matching the paper.
To account for tokenizer quirks (leading space), the lookup tries both target_token and " " + target_token and takes the better one.
- Parameters:
target_token (str)
missing_logprob_value (float)
- is_differentiable: ClassVar[bool] = False
Whether this loss supports backpropagation. Set to False for losses that use external models, text generation, or other non-differentiable operations.
- missing_logprob_value: float = -inf
Logprob value substituted when the target token is absent from the top-k logprobs. Negated to a loss in __call__; default -inf yields a +inf loss.
- require_first_token_logprobs: ClassVar[bool] = True
Whether this loss requires first-token log-probabilities from generation.
- target_token: str = 'Sure'
First target token whose probability we maximise.
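The lookup described for FirstTokenNLLLoss can be sketched for a single sample as follows. This is a minimal illustration, assuming the returned top-k logprobs arrive as a plain dict from token string to log-probability; the function name first_token_nll and that dict format are hypothetical, not part of the documented API.

```python
import math
from typing import Dict


def first_token_nll(top_logprobs: Dict[str, float],
                    target_token: str = "Sure",
                    missing_logprob_value: float = -math.inf) -> float:
    """Negative log-probability of target_token among the top-k logprobs."""
    # Try the bare token and its leading-space variant; keep the better one.
    candidates = [
        top_logprobs[token]
        for token in (target_token, " " + target_token)
        if token in top_logprobs
    ]
    logprob = max(candidates) if candidates else missing_logprob_value
    # Negate to turn the log-probability into a loss.
    return -logprob
```

With the default missing_logprob_value, a target that is absent from the top-k yields a loss of +inf; passing a finite value instead keeps such candidates comparable with each other.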
- class tropt.loss.GeneratedResponseBasedLoss
Bases: TextBasedLoss
Marker base for losses that operate on generated_response_strs.
- require_generation: ClassVar[bool] = True
Whether this loss requires autoregressive generation.
- class tropt.loss.HiddenStateBasedLoss
Bases: BaseLoss
Loss computed on model hidden states (full_hidden_states).
Whether this loss requires the model to provide the forward pass’s hidden states.
- class tropt.loss.InputFluencyLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')
Bases: BinaryLMJudgeLoss
Loss that encourages readable/fluent trigger text (operates on the whole trigger-combined prompt). Minimizing this loss maximizes readability.
https://arxiv.org/abs/2410.02163
- Parameters:
positive_words (Set[str])
negative_words (Set[str])
model_name_or_path (str)
judge_lm_batch_size (int)
device (str)
- exception tropt.loss.LossResolutionError
Bases: Exception
Raised when required data is missing for loss computation. See tropt.loss.resolution.LossResolutionError under Loss Resolution above for the full description and example.
- class tropt.loss.MisclassCELoss(targeted=False)
Bases: ClassificationBasedLoss
Encourages misclassification via cross-entropy on classifier logits.
Two modes:
- Untargeted (targeted=False): minimizes probability of true_class_idx.
- Targeted (targeted=True): maximizes probability of target_class_idx.
The class indices are per-template target data; pass them via Targets(true_class_idx=[…]) or Targets(target_class_idx=[…]) to optimize_trigger. See tropt/common.py.
- Parameters:
targeted (bool)
- targeted: bool = False
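The two modes can be illustrated for a single sample with a pure-Python sketch. It assumes the untargeted loss is the log-probability of the true class (so minimizing it pushes that probability down) and the targeted loss is the usual cross-entropy on the target class; the exact form used by MisclassCELoss is not given in the source, and log_softmax and misclass_ce are hypothetical helper names.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one logit vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def misclass_ce(class_logits: List[float], class_idx: int,
                targeted: bool = False) -> float:
    """Loss whose minimization drives misclassification (assumed form)."""
    logprob = log_softmax(class_logits)[class_idx]
    if targeted:
        # class_idx is the target class: minimize its cross-entropy.
        return -logprob
    # class_idx is the true class: minimize its log-probability.
    return logprob
```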
- class tropt.loss.PrefillBasedLoss
Bases: BaseLoss
Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks); commonly derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.
- require_target_prefill: ClassVar[bool] = True
Whether this loss requires the model to prefill the target response tokens (appending them to the input, as a response prefix).
- class tropt.loss.PrefillCELoss(temperature=1.0, clamp_min_nll=None)
Bases: PrefillBasedLoss
Encourages the model to produce the target output (mostly an affirmative response) by maximizing its likelihood.
Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks); automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.
- Parameters:
temperature (float)
clamp_min_nll (float | None)
- clamp_min_nll: Optional[float] = None
Floor on per-token NLL before averaging; tokens already below the floor contribute zero gradient, freeing the optimizer to focus on “unsolved” positions. Defaults to None (no clamping), otherwise clamped at the given value.
FLRT (https://arxiv.org/abs/2407.17447, Eq. 5) uses -log(0.6) ≈ 0.511.
- temperature: float = 1.0
Temperature applied to the prefill logits before softmax.
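The effect of clamp_min_nll can be sketched on plain numbers. This is a simplified illustration that takes the per-token target probabilities as given (so temperature and the softmax are left out) and uses a hypothetical helper name, clamped_mean_nll; it is not the library's tensor implementation.

```python
import math
from typing import List, Optional


def clamped_mean_nll(target_probs: List[float],
                     clamp_min_nll: Optional[float] = None) -> float:
    """Mean per-token NLL, with an optional floor applied to each token."""
    nlls = [-math.log(p) for p in target_probs]
    if clamp_min_nll is not None:
        # Tokens whose NLL is already below the floor are pinned to it, so
        # they stop changing the loss (and would receive zero gradient).
        nlls = [max(nll, clamp_min_nll) for nll in nlls]
    return sum(nlls) / len(nlls)


floor = -math.log(0.6)  # the value quoted from FLRT, about 0.511
loss = clamped_mean_nll([0.9, 0.1], clamp_min_nll=floor)
```

In this example the first token (probability 0.9) is already above the 0.6 threshold, so its NLL is replaced by the floor, while the second token (probability 0.1) keeps its full NLL.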
- class tropt.loss.PrefillCWLoss(cw_margin=5.0, first_token_weight=1.0)
Bases: PrefillBasedLoss
Encourages the model to produce the target output (mostly an affirmative response) by maximizing its likelihood. CW-inspired hinge loss on the difference between the largest and the target logits. https://arxiv.org/abs/2402.09674
Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks); automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.
- Parameters:
cw_margin (float)
first_token_weight (float)
- cw_margin: float = 5.0
- first_token_weight: float = 1.0
- class tropt.loss.PrefillDistillationLoss(temperature=1.0, reference_temperature=1.0, clamp_min_nll=None)
Bases: PrefillBasedLoss
Encourages the model’s prefill probabilities at the target positions to match those given by the reference logits. Concretely, it returns the cross-entropy between the victim model’s probabilities and a softmax over pre-computed reference logits.
Inspired by FLRT for logit-based distillation (https://arxiv.org/abs/2407.17447), where the reference logits come from a jailbroken copy of the victim model.
Loss computed on prefilled response logits (prefill_response_logits). Requires target logits (target_response_logits); commonly derived from a reference model’s output on the same input.
- Parameters:
temperature (float)
reference_temperature (float)
clamp_min_nll (float | None)
- clamp_min_nll: Optional[float] = None
Floor on the cross-entropy result to stop optimizing well-matched tokens. Defaults to None (disabled); a common value is -log(0.6) ≈ 0.51.
- reference_temperature: float = 1.0
Temperature applied to the reference logits before softmax (teacher sharpening).
- temperature: float = 1.0
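For one token position, the cross-entropy between a softmax over reference logits and the model's own distribution can be sketched as below. This is a pure-Python illustration under assumptions: the helper names log_softmax and distillation_ce are hypothetical, temperatures are applied by dividing the logits, and clamping and averaging over positions are left out.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def distillation_ce(model_logits: List[float], reference_logits: List[float],
                    temperature: float = 1.0,
                    reference_temperature: float = 1.0) -> float:
    """Cross-entropy H(q_ref, p_model) at a single token position."""
    log_p = log_softmax(model_logits, temperature)
    q = [math.exp(x) for x in log_softmax(reference_logits, reference_temperature)]
    # -sum_i q_i * log p_i: smallest when p matches q.
    return -sum(q_i * log_p_i for q_i, log_p_i in zip(q, log_p))
```

The value is minimized when the model distribution equals the reference distribution, at which point it equals the entropy of the reference.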
- class tropt.loss.PrefillMellowMaxLoss(mellowmax_alpha=1.0, temperature=1.0)
Bases: PrefillBasedLoss
Encourages the model to produce the target output by maximizing the mellowmax of the target logits. https://arxiv.org/pdf/1612.05628, http://confirmlabs.org/posts/TDC2023
Loss computed on prefilled response logits (prefill_response_logits). Requires target tokens (target_response_toks); automatically derived from target_response_strs. Using this loss usually implies that the model will prefill the response with these target tokens.
- Parameters:
mellowmax_alpha (float)
temperature (float)
- mellowmax_alpha: float = 1.0
- temperature: float = 1.0
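The mellowmax operator from the first cited paper is (1/alpha) * log(mean(exp(alpha * x_i))). A minimal pure-Python sketch of the operator alone follows; how PrefillMellowMaxLoss applies it to the prefill logits, and with which sign, is not specified in the source, and the function name mellowmax is chosen here for illustration.

```python
import math
from typing import List


def mellowmax(values: List[float], alpha: float = 1.0) -> float:
    """Mellowmax of values: (1/alpha) * log(mean(exp(alpha * x))).

    alpha must be nonzero; large positive alpha approaches max(values),
    and alpha near zero approaches the arithmetic mean.
    """
    # Subtract the largest exponent first so exp() cannot overflow.
    m = max(alpha * v for v in values)
    total = sum(math.exp(alpha * v - m) for v in values)
    return (m + math.log(total / len(values))) / alpha
```

For positive alpha the result always lies between the mean and the maximum of the inputs, which is what makes it a smooth alternative to either.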
- class tropt.loss.ResponseHarmfulnessLoss(positive_words=<factory>, negative_words=<factory>, model_name_or_path='HuggingFaceTB/SmolLM2-135M-Instruct', judge_lm_batch_size=256, device='cpu')
Bases: BinaryLMJudgeLoss, GeneratedResponseBasedLoss
Loss that encourages harmful model responses (operates on the generated response string).
- Parameters:
positive_words (Set[str])
negative_words (Set[str])
model_name_or_path (str)
judge_lm_batch_size (int)
device (str)
- class tropt.loss.SimilarityLoss
Bases: EmbeddingBasedLoss
Encourages given representation(s) to align (by cosine similarity) with the given target vectors.
- class tropt.loss.SteeringActivationLoss(targeted_layers=slice(None, None, None), steer_away=False, slc_name=SliceKey.INPUT_LAST_TOKEN, do_cosine_sim=False, apply_square=False, apply_abs=False)
Bases: HiddenStateBasedLoss
Encourages hidden activations at specific layers/positions to align with a target direction.
- Each message has a target direction vector (optionally its own unique one), target_directions: (n_templates, d_model). Note that the direction is applied to all target positions and layers.
- Default is steering towards a direction (maximizing alignment). Here, minimizing the loss maximizes alignment (dot product) with the target direction. Set steer_away=True to steer away (e.g., for refusal suppression).
References:
- Was proposed as ‘refusal direction suppression’ combined with GCG:
- Was proposed for adapting attacks (e.g., GCG) for evading probe-based classifiers.
- Parameters:
targeted_layers (slice) – Which layers to apply steering on (default: all layers)
steer_away (bool) – Whether to minimize alignment instead of maximizing (default: False = steer towards)
slc_name (SliceKey) – Which token positions to apply steering on (default: “last_input_token”)
do_cosine_sim (bool) – Whether to use cosine similarity instead of dot product (default: False)
apply_square (bool) – Whether to square the similarity scores (default: False)
apply_abs (bool)
- apply_abs: bool = False
- apply_square: bool = False
- do_cosine_sim: bool = False
- steer_away: bool = False
- targeted_layers: slice = slice(None, None, None)
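The sign convention and the dot-product versus cosine choice can be sketched for a single hidden vector. This is a hedged illustration with a hypothetical helper, steering_loss, operating on plain lists; layer and position selection, averaging, and the apply_square / apply_abs options are omitted because the source does not say how they are ordered or combined.

```python
import math
from typing import List


def steering_loss(hidden: List[float], direction: List[float],
                  steer_away: bool = False,
                  do_cosine_sim: bool = False) -> float:
    """Alignment-based loss for one hidden vector and one target direction."""
    score = sum(h * d for h, d in zip(hidden, direction))  # dot product
    if do_cosine_sim:
        norm_h = math.sqrt(sum(h * h for h in hidden))
        norm_d = math.sqrt(sum(d * d for d in direction))
        score /= norm_h * norm_d
    # Steering towards: minimizing -score maximizes alignment.
    # Steering away: minimizing +score reduces alignment.
    return score if steer_away else -score
```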
- class tropt.loss.TriggerLogitBasedLoss
Bases: BaseLoss
Loss computed on full-sequence logits (full_logits) sliced to trigger positions. Useful for optimizing properties of the triggers directly.
- class tropt.loss.TriggerPerplexityLoss(temperature=1.0, slc_name=SliceKey.TRIGGER)
Bases: TriggerLogitBasedLoss
Calculates the trigger’s perplexity with respect to the target model’s own logits. Useful for penalizing non-fluent triggers.
- Parameters:
temperature (float)
slc_name (SliceKey)
- temperature: float = 1.0
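Perplexity over a slice of positions is the exponential of the mean next-token NLL. The sketch below shows that computation for one sequence in pure Python, assuming the usual causal-LM alignment where the logits at position i-1 score the token at position i; the helper names log_softmax and trigger_perplexity and the list-based inputs are hypothetical, and the library's batching and slice handling are not reproduced.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def trigger_perplexity(full_logits: List[List[float]], token_ids: List[int],
                       trigger_slice: slice, temperature: float = 1.0) -> float:
    """exp(mean NLL) of the tokens located at trigger_slice."""
    nlls = []
    for pos in range(*trigger_slice.indices(len(token_ids))):
        if pos == 0:
            continue  # the first token has no preceding logits to score it
        # Assumed alignment: logits at pos - 1 predict the token at pos.
        log_probs = log_softmax(full_logits[pos - 1], temperature)
        nlls.append(-log_probs[token_ids[pos]])
    return math.exp(sum(nlls) / len(nlls))
```

As a sanity check, uniform logits over a vocabulary of size V give a perplexity of V regardless of which tokens sit in the slice.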
- tropt.loss.resolve_and_compute_loss(model_output, model_input, loss_func)
Universal loss computation via automatic argument matching. See tropt.loss.resolution.resolve_and_compute_loss under Loss Resolution above for the full description, parameters, and examples.