Building a New Loss
This guide goes one level of abstraction below picking an off-the-shelf loss when composing a recipe: it shows how to build your own custom loss.
A TROPT loss defines the optimization objective: given the model’s output on a triggered input, it computes one scalar per sample in the batch for the optimizer to minimize.
Design in a nutshell. The loss is agnostic to both the model and the optimizer: it does not call the model — it operates on the model’s artifacts — and is not wired into the optimization loop directly.
Instead, the model component invokes the loss it was given as part of its forward pass, and TROPT’s resolver (resolve_and_compute_loss()) automatically fills in the matching model data by inspecting the loss’s __call__ signature.
This keeps losses small, swappable, and trivial to compose into recipes (as we’ve seen in Composing a Recipe).
This guide effectively explains how every loss in tropt.loss is implemented; browsing the existing losses there can provide helpful concrete examples.
If you would like to contribute a loss back to TROPT, see CONTRIBUTING.md. This guide focuses on building losses for your own use.
A Minimal Loss
Implementing a loss in TROPT requires following a single rule: the names of your __call__ parameters must match field names on ModelOutput, ModelInput, or MessageTargets. Get the names right and everything connects automatically. The most common fields are listed in the reference table at the end of this guide; the canonical definitions live in tropt.common.
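To make the rule concrete, here is a minimal sketch of name-based dispatch using only the standard library's inspect module. It is not TROPT's resolver: the plain dicts stand in for ModelOutput and MessageTargets, and both resolve_kwargs and ToyLoss are hypothetical names used only for illustration.

```python
import inspect
from typing import Any, Callable, Dict


def resolve_kwargs(loss: Callable[..., Any], *sources: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical resolver: map each __call__ parameter to a field with the same name."""
    available: Dict[str, Any] = {}
    for source in sources:
        available.update(source)  # later sources win on name clashes

    kwargs: Dict[str, Any] = {}
    # For a callable instance, the signature is that of __call__ without `self`.
    for name in inspect.signature(loss).parameters:
        if name not in available:
            raise KeyError(f"no field named {name!r} for {type(loss).__name__}")
        kwargs[name] = available[name]
    return kwargs


class ToyLoss:
    # The parameter names are the only "wiring": they must match field names.
    def __call__(self, output_embeddings, target_vectors):
        # Negated dot product (equals negated cosine similarity for unit vectors).
        return [-sum(a * b for a, b in zip(e, target_vectors)) for e in output_embeddings]


model_output = {"output_embeddings": [[1.0, 0.0], [0.5, 0.5]]}  # stand-in for ModelOutput
message_targets = {"target_vectors": [1.0, 0.0]}  # stand-in for MessageTargets

loss = ToyLoss()
kwargs = resolve_kwargs(loss, model_output, message_targets)
print(loss(**kwargs))  # [-1.0, -0.5]
```

Renaming output_embeddings to anything that is not a field name would make the lookup fail, which is exactly why the names matter.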
To demonstrate loss implementation, we start with a simple loss useful against embedding models: maximizing cosine similarity between an encoder’s output embedding and a target embedding vector — e.g., aligning a trigger-augmented text with a target vector.
```python
from dataclasses import dataclass

import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from tropt.loss import BaseLoss


@dataclass
class MyLoss(BaseLoss):
    """Encourages output embeddings to align with target vectors."""

    def __call__(
        self,
        output_embeddings: Float[Tensor, "bsz d_model"],
        target_vectors: Float[Tensor, "d_model"],
    ) -> Float[Tensor, "bsz"]:
        sim = F.cosine_similarity(
            output_embeddings, target_vectors.unsqueeze(0), dim=-1
        )
        return -sim  # losses are minimized; negate to maximize similarity
```
Let’s break the implementation down and see what every loss must satisfy:
- Inherits from BaseLoss. Any subclass of BaseLoss works regardless of where it lives: no registration step, no entry point.
- Parameter names do all the wiring. output_embeddings is a field on ModelOutput (populated by encoder models); target_vectors is a field on MessageTargets. The resolver matches by name; nothing else is needed.
- Return shape (bsz,). One loss value per sample, with no batch reduction: the optimizer handles aggregation. There is also no need to handle batching inside the loss; that is the caller's job.
- Sign convention. Losses are minimized. Negate inside the loss if you want to maximize the underlying quantity.
- @dataclass decoration. Not strictly required, but it lets you add hyperparameters cleanly later (next section). All built-in losses use it.
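Returning one value per candidate, rather than a batch mean, is what lets a search-style optimizer rank candidates directly. The framework-free sketch below illustrates that step under the sign convention above; pick_best_candidate is a hypothetical helper, not a TROPT function.

```python
from typing import List, Sequence, Tuple


def pick_best_candidate(
    triggers: Sequence[str], per_sample_loss: Sequence[float]
) -> Tuple[str, float]:
    """Hypothetical search step: keep the candidate with the lowest per-sample loss."""
    if len(triggers) != len(per_sample_loss):
        raise ValueError("expected one loss value per candidate trigger")
    best = min(range(len(triggers)), key=lambda i: per_sample_loss[i])
    return triggers[best], per_sample_loss[best]


similarities: List[float] = [0.2, 0.9, 0.5]  # e.g. cosine similarity per candidate
losses = [-s for s in similarities]  # sign convention: minimize the negation
trigger, value = pick_best_candidate(["aa", "bb", "cc"], losses)
print(trigger, value)  # bb -0.9
```

A batch mean would collapse these three values into one number and hide the ranking, which is why reduction is left to the optimizer.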
To use this loss, just drop it into any optimizer that supports embedding-based objectives:
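```python
from tropt.common import Targets
from tropt.optimizer import GASLITEOptimizer

# Assumes `encoder_model` (an encoder model loaded through TROPT) and
# `target_vec` (a (d_model,) target tensor) are already defined.
loss = MyLoss()
optimizer = GASLITEOptimizer(model=encoder_model, loss=loss, num_steps=50)
result = optimizer.optimize_trigger(
    templates=["This passage is great. {{OPTIMIZED_TRIGGER}}"],
    targets=Targets(target_vectors=target_vec.unsqueeze(0)),  # (1, d_model): one row per template
)
```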
Enhancing the Loss
The version below adds (i) hyperparameters via dataclass fields, (ii) a require_* flag asking the model for extra output, and (iii) inheritance from an existing loss-category base class.
For this we rewrite the loss as cross-entropy on a prefilled target response — the canonical GCG-style jailbreak loss, which encourages the model to produce an affirmative target string (e.g. "Sure, here's how:"). This is what PrefillCELoss does in TROPT.
```python
from dataclasses import dataclass
from typing import ClassVar

import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor

from tropt.loss import PrefillBasedLoss


@dataclass
class MyLoss(PrefillBasedLoss):
    """Maximizes the likelihood of a target response."""

    # Hyperparameters as dataclass fields
    temperature: float = 1.0

    # Ask the model to prefill the target response so we get logits over it.
    # (PrefillBasedLoss already declares this; shown here for emphasis.)
    require_target_prefill: ClassVar[bool] = True

    def __call__(
        self,
        prefill_response_logits: Float[Tensor, "bsz response_seq_len vocab_size"],
        target_response_toks: Int[Tensor, "response_seq_len"],
    ) -> Float[Tensor, "bsz"]:
        logits = prefill_response_logits / self.temperature
        # CE expects (bsz, vocab, seq); broadcast the target across the batch.
        targets = target_response_toks.unsqueeze(0).expand(logits.shape[0], -1)
        nll = F.cross_entropy(logits.transpose(-1, -2), targets, reduction="none")
        return nll.mean(dim=-1)  # (bsz,)
```
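Written out, the loss above computes the following for sample i, where z_{i,t} is the row of prefill_response_logits at response position t, y_t is the matching entry of target_response_toks, T is response_seq_len, and τ is temperature:

$$
\mathcal{L}_i = -\frac{1}{T} \sum_{t=1}^{T} \log \operatorname{softmax}\!\left(\frac{z_{i,t}}{\tau}\right)_{y_t}
$$

With τ = 1 this is the standard mean token-level cross-entropy; a larger τ flattens the predicted distribution before it is scored.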
Breaking down the additions:
require_* flags. Some losses need the model to do extra work before returning its output: append the target response and return logits over it, return attention weights, run real generation, and so on. Declare this as a class attribute on the loss; the model backend inspects these flags before its forward pass and turns on the corresponding outputs. The full set:
| Attribute | What the model does when it’s True |
|---|---|
| require_target_prefill | Appends the target tokens and returns prefill_response_logits |
| … | Returns … |
| … | Returns … |
| … | Performs autoregressive generation, returns … |
| … | Returns … |
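As an illustration of the mechanism (not TROPT's actual backend code), a backend might read the flags off the loss before its forward pass. The sketch uses only the two flag names that appear in this guide; plan_forward and the toy loss classes are hypothetical.

```python
from dataclasses import dataclass
from typing import ClassVar, Dict

# Only the flags named in this guide; the real set is larger.
KNOWN_FLAGS = ("require_target_prefill", "require_hidden_states")


def plan_forward(loss: object) -> Dict[str, bool]:
    """Hypothetical: decide which extra outputs to compute for this loss."""
    # A flag the loss does not declare is treated as off.
    return {flag: bool(getattr(loss, flag, False)) for flag in KNOWN_FLAGS}


@dataclass
class ToyPrefillLoss:
    require_target_prefill: ClassVar[bool] = True


@dataclass
class ToySteeringLoss:
    require_hidden_states: ClassVar[bool] = True


print(plan_forward(ToyPrefillLoss()))
print(plan_forward(ToySteeringLoss()))
```

Because the flags are class attributes, the backend can read them without the loss ever touching the model, which keeps the loss model-agnostic.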
Convention: Hyperparameters as Fields. Anything you want callers to tweak at construction time (temperatures, margins, layer ranges, mode flags) belongs in a dataclass field. Crucially, per-template data (target tokens, target vectors, target classes) does not belong there; it belongs on Targets, and the loss pulls it in by parameter name (target_response_toks above). This keeps the loss instance stateless with respect to the attack: the optimizer can resample or subsample templates without telling the loss.
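One way to check the split between constructor hyperparameters and class-level flags is dataclasses.fields(), which ignores ClassVar annotations. The toy class below is hypothetical and mirrors the temperature example above; dataclasses.replace() then produces a re-configured copy without mutating the original, which fits the stateless-loss convention.

```python
from dataclasses import dataclass, fields, replace
from typing import ClassVar


@dataclass
class ToyTemperatureLoss:
    # Hyperparameter: a dataclass field, so it becomes a constructor argument.
    temperature: float = 1.0
    # Flag: a ClassVar, so it is not a field and not a constructor argument.
    require_target_prefill: ClassVar[bool] = True


loss = ToyTemperatureLoss(temperature=0.7)
print([f.name for f in fields(loss)])  # ['temperature']
flatter = replace(loss, temperature=2.0)
print(loss.temperature, flatter.temperature)  # 0.7 2.0
```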
Convention: Inheriting from a category base class. tropt.loss defines abstract bases like PrefillBasedLoss, EmbeddingBasedLoss, HiddenStateBasedLoss, AttentionBasedLoss, ClassificationBasedLoss, and TextBasedLoss. Each sets the right require_* flag and pins a typed __call__ signature for its category — inheriting from the right base saves boilerplate and signals intent. The categorization is a convention for readability, not a hard requirement: the resolver dispatches by parameter names, not by base class. If your loss doesn’t fit any existing category, inheriting from BaseLoss directly and setting the flags yourself is equally valid.
Practical Example: Activation Steering
Up to here we asked the model for logits over a target response. The same pattern works just as well for any model artifact — swap the base class, swap the require_* flag, and swap the __call__ parameter names, and you have a loss over a completely different signal.
To demonstrate, we rewrite the loss as activation steering (Arditi et al., 2024): encouraging the model’s hidden activations at chosen layers/positions to align with (or away from) a target direction in activation space. This has been used both to suppress refusal in jailbreaks and to probe internal representations. It’s what SteeringActivationLoss does in TROPT.
```python
from dataclasses import dataclass, field
from typing import ClassVar

import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor

from tropt.loss import HiddenStateBasedLoss


@dataclass
class MyLoss(HiddenStateBasedLoss):
    """Steers hidden activations along (or away from) a target direction."""

    # Hyperparameters
    # slice(None) selects all layers. It is wrapped in default_factory because
    # slices are unhashable before Python 3.12, and dataclasses (3.11+) reject
    # unhashable defaults as mutable.
    targeted_layers: slice = field(default_factory=lambda: slice(None))
    steer_away: bool = False  # if True, push activations *away* from the direction

    # Ask the model to return all hidden states.
    # (HiddenStateBasedLoss already declares this; shown here for emphasis.)
    require_hidden_states: ClassVar[bool] = True

    def __call__(
        self,
        full_hidden_states: Float[Tensor, "bsz n_layers seq_len d_model"],
        target_directions: Float[Tensor, "d_model"],
    ) -> Float[Tensor, "bsz"]:
        # Score the last token of each sample, averaged over the selected layers.
        h = full_hidden_states[:, self.targeted_layers, -1, :]  # (bsz, n_selected, d_model)
        align = F.cosine_similarity(h, target_directions[None, None, :], dim=-1)
        align = align.mean(dim=-1)  # (bsz,)
        return align if self.steer_away else -align
```
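Written out, for sample i, with S the set of layers selected by targeted_layers, h_{i,l} the last-token hidden state at layer l, and d the target direction, the loss above returns:

$$
\mathcal{L}_i = \mp\,\frac{1}{|S|}\sum_{l \in S} \cos\!\left(h_{i,l},\, d\right),
\qquad \cos(a, b) = \frac{a^\top b}{\lVert a\rVert\,\lVert b\rVert}
$$

The minus sign applies by default (steering toward d), and the plus sign applies when steer_away=True, so minimizing the loss lowers the alignment instead.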
Structurally nothing is new — same BaseLoss lineage, same dataclass-fields-as-hyperparameters, same require_* declaration, same per-sample return shape. The differences are entirely in which names appear:
- Base class & flag: HiddenStateBasedLoss (sets require_hidden_states=True) replaces PrefillBasedLoss / require_target_prefill.
- Model output: full_hidden_states replaces prefill_response_logits.
- Per-template target: target_directions (provided via Targets(target_directions=...)) replaces target_response_toks.
The same swap also produces attention-based losses (inherit from AttentionBasedLoss, name full_attentions) and classifier losses (inherit from ClassificationBasedLoss, name output_class_logits); see the existing implementations of AttentionEnhLoss and MisclassCELoss for those variants.
Going Non-Differentiable: Scoring Triggered Text
So far our losses have operated on tensors produced by the target model — embeddings, logits, hidden states. But some losses need to score the triggered text itself (the trigger alone, or the full input with the trigger inserted) via an external system: another LM, a classifier, a heuristic. These losses are typically non-differentiable, and so only black-box optimizers will accept them — but they unlock objectives the target model can’t express on its own.
The version below penalizes non-fluent triggered inputs by measuring their perplexity under an external base LM. It mirrors what InputFluencyLoss and ExternalTriggerPerplexityLoss do in TROPT, and has been used to discourage gibberish suffixes that would be easy to filter at inference time.
```python
from dataclasses import dataclass, field
from typing import Annotated, Any, ClassVar, List

import torch
from jaxtyping import Float
from transformers import AutoModelForCausalLM, AutoTokenizer

from tropt.loss import TextBasedLoss


@dataclass
class MyLoss(TextBasedLoss):
    """Penalizes non-fluent triggered inputs via an external base LM's perplexity."""

    # Loss runs an external model, so flag it explicitly as non-differentiable.
    is_differentiable: ClassVar[bool] = False

    judge_model_name: str = "google/gemma-2-2b"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Internal state, not constructor args
    _model: Any = field(default=None, init=False, repr=False)
    _tokenizer: Any = field(default=None, init=False, repr=False)

    def __post_init__(self):
        super().__post_init__()
        self._model = AutoModelForCausalLM.from_pretrained(
            self.judge_model_name, dtype=torch.bfloat16
        ).eval().to(self.device)
        self._tokenizer = AutoTokenizer.from_pretrained(self.judge_model_name)
        # Right-pad so every sequence starts at position 0; padded positions
        # are then excluded via the attention mask in __call__.
        self._tokenizer.padding_side = "right"

    def __call__(
        self,
        input_texts: Annotated[List[str], "bsz"],
    ) -> Float[torch.Tensor, "bsz"]:
        enc = self._tokenizer(
            input_texts, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            logits = self._model(**enc).logits.float()  # upcast for a stable NLL
        # Each token is predicted from the positions before it: shift by one.
        nll = torch.nn.functional.cross_entropy(
            logits[:, :-1].transpose(1, 2),  # (bsz, vocab, seq-1)
            enc.input_ids[:, 1:],  # (bsz, seq-1)
            reduction="none",
        )
        mask = enc.attention_mask[:, 1:].float()
        return (nll * mask).sum(-1) / mask.sum(-1)  # mean NLL per sample
```
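The last two lines of __call__ average the token-level NLL over real (non-padding) positions only. The pure-Python sketch below reproduces that masking on made-up numbers and shows the link to perplexity: perplexity is exp of the mean NLL, so minimizing either one ranks candidates identically. masked_mean_nll is a hypothetical helper.

```python
import math
from typing import Sequence


def masked_mean_nll(nll: Sequence[float], mask: Sequence[int]) -> float:
    """Mean of per-token NLL over positions where mask == 1 (padding excluded)."""
    total = sum(x * m for x, m in zip(nll, mask))
    count = sum(mask)
    return total / count


# Two real target tokens plus one padding position whose NLL must not count.
mean_nll = masked_mean_nll([2.0, 4.0, 7.5], [1, 1, 0])
perplexity = math.exp(mean_nll)
print(mean_nll)  # 3.0
```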
Breaking down the changes:
Naming input_texts (or input_trigger_strs). Both fields live on ModelInput. input_texts is the entire triggered input string with the optimized trigger inserted (one string per candidate in the batch); input_trigger_strs is just the trigger substring. Which one you name as a parameter scopes whether your loss scores the trigger in isolation or in context — scoring input_texts penalizes gibberish triggers that also disrupt the surrounding instruction’s fluency, whereas scoring input_trigger_strs only checks the trigger’s local quality.
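To see what each field holds, here is a hypothetical sketch of how one template and a batch of candidate triggers could map onto the two fields. It assumes the {{OPTIMIZED_TRIGGER}} placeholder from the earlier usage example is filled by plain string substitution; assemble_text_fields is illustrative, not TROPT's assembly code.

```python
from typing import Dict, List

PLACEHOLDER = "{{OPTIMIZED_TRIGGER}}"


def assemble_text_fields(template: str, candidates: List[str]) -> Dict[str, List[str]]:
    """Hypothetical: build both text fields for one template, one entry per candidate."""
    return {
        # Full triggered input: a loss naming this scores the trigger in context.
        "input_texts": [template.replace(PLACEHOLDER, c) for c in candidates],
        # Trigger substring only: a loss naming this scores the trigger in isolation.
        "input_trigger_strs": list(candidates),
    }


text_fields = assemble_text_fields(
    "This passage is great. {{OPTIMIZED_TRIGGER}}", ["Read on!", "xq#z!!"]
)
print(text_fields["input_texts"][0])  # This passage is great. Read on!
```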
is_differentiable = False. Once your loss touches an external model, runs generation, or otherwise breaks the autograd graph, mark it explicitly. Differentiable optimizers (e.g. GCGOptimizer, which back-props through the target model) check this flag and will refuse to run with a non-differentiable loss. Black-box optimizers (e.g. RandomSearchOptimizer, BeamSearchOptimizer, PALOptimizer) accept either. TextBasedLoss already sets this to False for you — shown above for visibility.
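A gradient-based optimizer's check can be as simple as reading the class attribute before the first step. The sketch below is hypothetical (the function and exception names are not TROPT's), and it assumes that a loss which never declares is_differentiable should be treated as differentiable.

```python
from dataclasses import dataclass
from typing import ClassVar


class NonDifferentiableLossError(TypeError):
    """Raised when a gradient-based optimizer is given a non-differentiable loss."""


def check_loss_for_gradients(loss: object) -> None:
    # Assumption: an undeclared flag defaults to differentiable.
    if not getattr(loss, "is_differentiable", True):
        raise NonDifferentiableLossError(
            f"{type(loss).__name__} is not differentiable; use a black-box optimizer"
        )


@dataclass
class ToyFluencyLoss:
    is_differentiable: ClassVar[bool] = False


@dataclass
class ToyEmbeddingLoss:
    pass


check_loss_for_gradients(ToyEmbeddingLoss())  # passes silently
try:
    check_loss_for_gradients(ToyFluencyLoss())
except NonDifferentiableLossError as err:
    print(err)
```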
Loading external resources in __post_init__. __post_init__() exists for exactly this; subclasses override it to load auxiliary models, tokenizers, classifiers, or anything else heavy. Use field(default=None, init=False, repr=False) for these so they don’t appear as constructor args or in the loss’s repr. Note that even though the loss now carries an external model on the instance, it still carries no per-template state — everything the loss needs flows in via the resolver each call.
Composing Multiple Losses
For multi-objective optimization, wrap your losses in CombinedLoss — it resolves each component independently (via the same resolver) and returns a weighted sum:
```python
from tropt.loss import CombinedLoss

# LossA and LossB stand in for any two BaseLoss subclasses.
loss = CombinedLoss(
    loss_funcs=[LossA(), LossB()],
    weights=[0.8, 0.2],
)
```
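Per sample, the combination is a plain weighted sum of the component losses. The framework-free sketch below (combine_weighted is hypothetical) assumes each component has already been resolved and has returned one value per sample; the weights are chosen to be exactly representable so the printed result is exact.

```python
from typing import List, Sequence


def combine_weighted(
    per_loss_values: Sequence[Sequence[float]], weights: Sequence[float]
) -> List[float]:
    """Hypothetical: combined[i] = sum over k of weights[k] * per_loss_values[k][i]."""
    if len(per_loss_values) != len(weights):
        raise ValueError("need exactly one weight per component loss")
    bsz = len(per_loss_values[0])
    return [
        sum(w * values[i] for w, values in zip(weights, per_loss_values))
        for i in range(bsz)
    ]


loss_a = [1.0, 2.0]  # per-sample values from a first component
loss_b = [5.0, 0.0]  # per-sample values from a second component
print(combine_weighted([loss_a, loss_b], [0.75, 0.25]))  # [2.0, 1.5]
```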
Two constraints to know about:
- CombinedLoss cannot be nested inside another CombinedLoss.
- The combined require_* flags are the OR over components; the combined is_differentiable is the AND. So adding one non-differentiable component makes the whole combined loss non-differentiable.
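The flag rules amount to any() over each require_* flag and all() over is_differentiable. The sketch below (combined_flags and the toy classes are hypothetical, and only the two flags named in this guide are shown) makes the consequence concrete.

```python
from dataclasses import dataclass
from typing import ClassVar, Dict, Sequence

FLAGS = ("require_target_prefill", "require_hidden_states")  # subset named in this guide


def combined_flags(components: Sequence[object]) -> Dict[str, bool]:
    """Hypothetical: OR the require_* flags, AND the differentiability flags."""
    combined = {f: any(getattr(c, f, False) for c in components) for f in FLAGS}
    # Assumption: a component that does not declare is_differentiable counts as differentiable.
    combined["is_differentiable"] = all(
        getattr(c, "is_differentiable", True) for c in components
    )
    return combined


@dataclass
class ToyPrefillLoss:
    require_target_prefill: ClassVar[bool] = True


@dataclass
class ToyFluencyLoss:
    is_differentiable: ClassVar[bool] = False


print(combined_flags([ToyPrefillLoss(), ToyFluencyLoss()]))
```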
Available Parameter Names
The resolver matches your __call__ parameter names against fields on three classes (see tropt.common for the canonical definitions and full list).
From ModelOutput — produced by the model’s forward pass:
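| Parameter name | Type | Provided when |
|---|---|---|
| output_embeddings | Float[Tensor, "bsz d_model"] | Encoder models |
| … | … | LMs (full sequence) |
| … | … | Loss has … |
| … | … | Loss has … |
| … | … | Loss has … |
| … | … | Loss has … |
| … | … | Loss has … |
| output_class_logits | … | Classifier models |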
From ModelInput — assembled per step before the forward pass:
| Parameter name | Type | Description |
|---|---|---|
| input_texts | Annotated[List[str], "bsz"] | Full triggered input strings (one per candidate) |
| input_trigger_strs | … | Just the trigger string (one per candidate) |
| … | … | Current trigger token IDs |
| … | … | Position markers (trigger / before / after / appended / last token) |
From MessageTargets — per-template optimization targets, sliced from the Targets passed to optimize_trigger():
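| Parameter name | Type | Description |
|---|---|---|
| … | … | Raw target response text |
| target_response_toks | Int[Tensor, "response_seq_len"] | Tokenized target response |
| … | … | Reference (e.g. teacher) logits |
| target_vectors | Float[Tensor, "d_model"] | Target embedding vector |
| target_directions | Float[Tensor, "d_model"] | Target direction in activation space |
| … | … | Target class index (targeted attacks) |
| … | … | True class index (untargeted attacks) |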
Parameters with default values in your signature are treated as optional by the resolver — missing data does not raise. Use this for losses that can operate in multiple modes.
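Here is a minimal sketch of how a resolver could honor defaults, using inspect.Parameter.empty to tell required parameters from optional ones. resolve_with_defaults and the two-mode toy loss are hypothetical, and target_class / true_class are made-up stand-ins for the class-index fields listed above; TROPT's resolver may differ in its details.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional, Sequence


def resolve_with_defaults(loss: Callable[..., Any], available: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: match by name, skip missing optional params, fail on missing required ones."""
    kwargs: Dict[str, Any] = {}
    for name, param in inspect.signature(loss).parameters.items():
        if name in available:
            kwargs[name] = available[name]
        elif param.default is inspect.Parameter.empty:
            raise KeyError(f"required field {name!r} is missing")
        # Otherwise omit it, so Python falls back to the declared default.
    return kwargs


class ToyMisclassLoss:
    """Hypothetical two-mode loss: targeted if target_class is present, else untargeted."""

    def __call__(
        self,
        output_class_logits: Sequence[Sequence[float]],
        target_class: Optional[int] = None,
        true_class: Optional[int] = None,
    ) -> List[float]:
        # Raw logits keep the sketch short; a real loss would likely use log-probabilities.
        if target_class is not None:
            return [-row[target_class] for row in output_class_logits]  # raise target logit
        if true_class is not None:
            return [row[true_class] for row in output_class_logits]  # lower true-class logit
        raise ValueError("need either target_class or true_class")


loss = ToyMisclassLoss()
model_fields = {"output_class_logits": [[2.0, 0.5], [0.1, 3.0]]}
print(loss(**resolve_with_defaults(loss, {**model_fields, "target_class": 1})))  # [-0.5, -3.0]
print(loss(**resolve_with_defaults(loss, {**model_fields, "true_class": 0})))  # [2.0, 0.1]
```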
If the data you need isn’t exposed by any current field, add a new field to ModelOutput / MessageTargets / Targets in tropt.common and populate it from the relevant model backend; the resolver will then route it to any loss that names it.