import os
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
import pydantic
import torch
from jaxtyping import Float, Int
from torch import Tensor
# =========== Common constants and utilities ==========
#: Placeholder string marking where the optimized trigger is inserted in a text template.
OPTIMIZED_TRIGGER_PLACEHOLDER: str = "{{OPTIMIZED_TRIGGER}}"
#: Default initial trigger: 20 "!" characters separated by single spaces.
DEFAULT_INIT_TRIGGER = ("! " * 20).strip()
# ======================= Common input types =======================
TextTemplates = Annotated[
    List[str],
    pydantic.Field(min_length=1),
    "n_templates"
]
"""List of text templates, one per optimization target. Each must contain the trigger placeholder (`{{OPTIMIZED_TRIGGER}}`).
Length: n_templates.
"""
#: Token IDs of a single trigger, shape (1, trigger_seq_len).
#: Token IDs are integers, so this is ``Int`` (a ``Float`` annotation would make
#: jaxtyping's dtype check reject integer token tensors).
TokenTrigger = Int[Tensor, "1 trigger_seq_len"]
#: A trigger in plain-text form.
TextTrigger = str
#: Token IDs of a batch of candidate triggers, one row per candidate,
#: shape (n_candidates, trigger_seq_len).
TokenTriggerCandidates = Int[Tensor, "n_candidates trigger_seq_len"]
# ======================= Slice Keys Enum =======================
class SliceKey(str, Enum):
    """Standardized names for the slices of an input embedding sequence.

    Members mix in ``str``, so each compares equal to its plain string value
    (e.g. ``SliceKey.TRIGGER == "trigger"``).
    """

    TRIGGER = "trigger"
    """The optimized trigger tokens"""

    INPUT_BEFORE = "input_before"
    """Tokens before the trigger (including chat template, if exists)"""

    INPUT_AFTER = "input_after"
    """Tokens after the trigger (including chat template, if exists)"""

    INPUT_LAST_TOKEN = "input_last_token"
    """The last token of the input sequence (e.g., typically the end of the chat template before generation starts)"""

    APPENDED = "appended"
    """Optional tokens appended (=prefilled) at the end (e.g., target outputs for LMs)"""
class MessageTargets(pydantic.BaseModel):
    """Targets for a single selected message (one template).

    Per-template counterpart of ``Targets``: the field names are identical to
    the ``Targets`` field names (including the plural ``target_response_strs``),
    with the leading n_templates axis removed. Every field is optional; only the
    ones required by the chosen loss need to be set.
    """

    # Reject unknown field names (as ``Targets`` and ``ModelOutput`` do) so a
    # misspelled target raises a validation error instead of being silently dropped.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, extra='forbid')

    # ── LM response targets ────────────────────────────────────────────────
    # Consumed by prefill-based and generation-based losses on language models.
    target_response_strs: Optional[str] = None
    """Raw text target response for this message."""
    target_response_toks: Optional[Int[Tensor, "target_seq_len"]] = None
    """Tokenized target response for this message."""
    target_response_logits: Optional[Float[Tensor, "target_seq_len vocab_size"]] = None
    """Target logits from a reference (e.g., jailbroken) model, one per target position.
    Used by distillation-style losses (e.g., FLRT, https://arxiv.org/abs/2407.17447).
    """

    # ── Representation targets ─────────────────────────────────────────────
    # Consumed by losses that operate on embeddings or hidden-state directions.
    target_vectors: Optional[Float[Tensor, "d_model"]] = None
    """Target embedding vector for this message."""
    target_directions: Optional[Float[Tensor, "d_model"]] = None
    """Target direction in activation space for this message.
    Used by steering losses (e.g., representation engineering).
    """

    # ── Weight-gradient targets ────────────────────────────────────────────
    # Consumed by gradient-matching losses.
    target_gradient: Optional[Float[Tensor, "n_params"]] = None
    """Target weight-gradient to align with, flattened over the trainable params.
    Precompute the target weight-gradient externally and pass it here.
    """

    # ── Classifier targets ─────────────────────────────────────────────────
    # Consumed by losses that operate on classifier logits.
    target_class_idx: Optional[int] = None
    """Target class index for this message.
    Used by targeted-misclassification losses on classifier outputs.
    """
    true_class_idx: Optional[int] = None
    """True (current) class index for this message; the class to steer away from.
    Used by untargeted-misclassification losses on classifier outputs.
    """
class Targets(pydantic.BaseModel):
    """Targets for all templates. Each field has an initial n_templates dimension.

    Typically only one or two of these fields need to be provided depending
    on the loss function being used.
    For example, a standard LM jailbreak only needs `target_response_strs` to provide the target outputs.
    All populated fields must share the same leading length (checked by
    ``check_field_lengths``); unknown field names are rejected.
    """

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, extra='forbid')

    # ── LM response targets ────────────────────────────────────────────────
    # Consumed by prefill-based and generation-based losses on language models.
    target_response_strs: Optional[Annotated[List[str], "n_templates"]] = None
    """Raw text target outputs, one per template.
    List is of length n_templates.
    - Used by: Language models for target matching. Will be tokenized
      internally to produce `target_response_toks` if not provided directly.
    - If used with prefill-based losses, it will automatically run model computations with the prefilled target response.
    - Note: any special tokens that must precede the actual response for the
      target model are the caller's responsibility to include here. For example,
      thinking models (Qwen3, DeepSeek-R1, etc.) that were trained to begin every
      response with a `<think>...</think>` block typically need an empty block
      (e.g. ``"<think>\\n\\n</think>\\n\\n"``) prepended to the target string to
      suppress reasoning before the desired prefix.
    """
    target_response_toks: Optional[Int[Tensor, "n_templates target_seq_len"] | Annotated[List[Int[Tensor, "target_seq_len"]], "n_templates"]] = None
    """Tokenized target outputs, one per template.
    Shape: (n_templates, target_seq_len) OR List of length n_templates,
    each of (potentially different) shape (target_seq_len,)
    Used by: Language models for computing cross-entropy loss.
    """
    target_response_logits: Optional[Annotated[List[Float[Tensor, "target_seq_len vocab_size"]], "n_templates"]] = None
    """Target logits from a reference (e.g., jailbroken) model, one per target position.
    Used by distillation-style losses (e.g., FLRT, https://arxiv.org/abs/2407.17447); one tensor per template.
    List of length n_templates, each of (potentially different) shape (target_seq_len, vocab_size).
    """

    # ── Representation targets ─────────────────────────────────────────────
    # Consumed by losses that operate on embeddings or hidden-state directions.
    target_vectors: Optional[Float[Tensor, "n_templates d_model"]] = None
    """Target embedding vectors, one per template.
    Shape: (n_templates, d_model)
    Used by: Encoder models for similarity-based losses.
    """
    target_directions: Optional[Float[Tensor, "n_templates d_model"]] = None
    """Target directions in activation space, one per template.
    Shape: (n_templates, d_model)
    Used by: Steering losses (e.g., refusal suppression).
    Note: if you need per-layer directions, store as (n_templates, n_layers, d_model)
    and update this annotation accordingly.
    """

    # ── Weight-gradient targets ────────────────────────────────────────────
    # Consumed by gradient-matching losses.
    target_gradient: Optional[Float[Tensor, "n_templates n_params"]] = None
    """Target weight-gradients, one per template (flattened over the trainable params).
    Shape: (n_templates, n_params).
    Used by: gradient-matching losses. Precompute the per-template target
    weight-gradients externally and stack across templates.
    """

    # ── Classifier targets ─────────────────────────────────────────────────
    # Consumed by losses that operate on classifier logits.
    target_class_idx: Optional[Annotated[List[int], "n_templates"]] = None
    """Target class indices, one per template.
    List of length n_templates.
    Used by: targeted-misclassification losses on classifier outputs.
    """
    true_class_idx: Optional[Annotated[List[int], "n_templates"]] = None
    """True (current) class indices, one per template; the class to steer away from.
    List of length n_templates.
    Used by: untargeted-misclassification losses on classifier outputs.
    """

    @property
    def n_templates(self) -> int:
        """Number of templates covered by these targets.

        Returns the length of the first populated (non-``None``) field in
        declaration order, or 0 when no target field is set.
        """
        # Walk fields in declaration order instead of the unordered
        # ``model_fields_set`` so the result is deterministic even for instances
        # that skipped validation (e.g. created via ``model_copy(update=...)``).
        for _, value in self:
            if value is not None:
                return len(value)
        return 0

    @pydantic.model_validator(mode="after")
    def check_field_lengths(self) -> "Targets":
        """Ensure every populated field has the same leading (n_templates) length.

        Raises:
            ValueError: if populated fields disagree in length (pydantic surfaces
                this as a ``ValidationError`` from the constructor).
        """
        lengths = {name: len(value) for name, value in self if value is not None}
        if len(set(lengths.values())) > 1:
            # Include per-field lengths so the offending field is easy to spot.
            raise ValueError(f"All target fields must have the same length; got {lengths}")
        return self

    def select_message(self, idx: int) -> "MessageTargets":
        """Return a MessageTargets instance for the selected template index.

        Every populated field is indexed at ``idx``; unset fields are omitted.
        """
        return MessageTargets(
            **{k: v[idx] for k, v in self if v is not None}
        )

    def select_indices(self, indices: List[int]) -> "Targets":
        """Return a new Targets with only the selected template indices.

        Tensor fields are indexed with ``indices`` directly; list fields are
        rebuilt element by element. The result goes through the constructor,
        so it is re-validated.
        """
        updates = {}
        for k, v in self:
            if v is None:
                continue
            if isinstance(v, Tensor):
                updates[k] = v[indices]
            elif isinstance(v, list):
                updates[k] = [v[i] for i in indices]
        return Targets(**updates)

    def to_device(self, device: torch.device | str) -> "Targets":
        """Return a copy with tensor fields moved to ``device``.

        Tensor fields are moved with ``.to(device)``. List fields whose first
        element is a Tensor have every element moved. All other fields are
        carried over unchanged by ``model_copy``.
        """
        updates = {}
        for k, v in self:
            if isinstance(v, Tensor):
                updates[k] = v.to(device)
            elif isinstance(v, list) and v and isinstance(v[0], Tensor):
                updates[k] = [t.to(device) for t in v]
        return self.model_copy(update=updates)
# ======================= Model Input Wrapper =======================
# TODO optional additional validators to check shapes of inputs?
# ======================= Model Output Wrapper =======================
class ModelOutput(pydantic.BaseModel):
    """Standardized output container for all model types in TROPT.

    This renders a uniform interface for model outputs, that can then be used
    to compute different losses agnostic of the underlying model type/implementation.
    Every field is optional (defaults to ``None``); a model fills in only what it
    produces. Unknown field names are rejected (``extra='forbid'``).

    Shape Notation:
        - bsz: batch size (number of candidates or samples in a batch)
        - n_layers: number of model layers
        - n_heads: number of attention heads per layer
        - seq_len: total sequence length
        - response_seq_len: length of the prefilled response
        - response_len: length of generated response (variable per sample)
        - full_seq_len: full sequence length including prompt and generation
        - vocab_size: vocabulary size
        - d_model: model embedding dimension
        - n_classes: number of classifier output classes

    Examples:
        >>> # Encoder model output
        >>> encoder_output = ModelOutput(output_embeddings=torch.randn(4, 768))
        >>> # Language model output with logits
        >>> lm_output = ModelOutput(
        ...     full_logits=torch.randn(2, 50, 32000),
        ...     generated_response_strs=["Response 1", "Response 2"]
        ... )
        >>> # Full output with hidden states and attentions
        >>> logits = torch.randn(2, 5, 10)             # (bsz, seq_len, vocab_size)
        >>> hidden_states = torch.randn(2, 3, 5, 8)    # (bsz, n_layers, seq_len, d_model)
        >>> attentions = torch.randn(2, 3, 4, 5, 5)    # (bsz, n_layers, n_heads, seq_len, seq_len)
        >>> full_output = ModelOutput(
        ...     full_logits=logits,
        ...     full_hidden_states=hidden_states,
        ...     full_attentions=attentions,
        ...     generated_response_strs=["Response 1", "Response 2"]
        ... )
    """

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, extra='forbid')

    # === Embedding outputs (Encoder models) ===
    output_embeddings: Optional[Float[Tensor, "bsz d_model"]] = None
    """Pooled output embeddings from encoder models.
    Shape: (bsz, d_model)
    """

    # === Logits (Language models) ===
    # LM legend: all = input (incl trigger) + prefilled response + generated response tokens
    full_logits: Optional[Float[Tensor, "bsz seq_len vocab_size"]] = None
    """Full sequence logits from language models; including both inputs and outputs (prefilled and generated).
    """
    prefill_response_logits: Optional[Float[Tensor, "bsz response_seq_len vocab_size"]] = None
    """Logits corresponding to the *prefilled* response portion of the sequence.
    """

    # === Hidden states (Transformer models with output_hidden_states=True) ===
    full_hidden_states: Optional[Float[Tensor, "bsz n_layers seq_len d_model"]] = None
    """Hidden states from all layers.
    Note: Typically requires stacking tuple outputs from HuggingFace models:
    `torch.stack(outputs.hidden_states, dim=1)`
    """

    # === Attention weights (Transformer models with output_attentions=True) ===
    full_attentions: Optional[Float[Tensor, "bsz n_layers n_heads seq_len seq_len"]] = None
    """Attention weights from all layers.
    Note: Typically requires stacking tuple outputs from HuggingFace models:
    `torch.stack(outputs.attentions, dim=1)`
    """

    # === Generated responses (Language models with generation) ===
    generated_response_ids: Optional[List[Int[Tensor, "response_len"]]] = None
    """Generated token IDs from language model generation.
    Response lengths may vary across samples.
    """
    generated_response_strs: Optional[List[str]] = None
    """Generated text strings from language model generation.
    """
    generated_response_logits: Optional[List[Float[Tensor, "response_len vocab_size"]] | Float[Tensor, "bsz response_len vocab_size"]] = None
    """Logits for generated tokens from language model generation.
    Notably, this differs from `prefill_response_logits`, which holds the logits w.r.t. a prefilled (mostly target) response. In particular, this excludes any prefilled tokens.
    Response lengths may vary across samples (list form), or be padded to a common length (tensor form).
    """

    # === First-token logprobs (API and HF models with generation) ===
    response_first_token_logprobs: Optional[List[Dict[str, float]]] = None
    """Sparse log-probabilities for the first generated token.
    List of length bsz, where each element is a dict mapping token strings to
    their log-probability. For API models this is typically the top-k returned
    by the provider (e.g. top-20 from OpenAI); for HF models it can cover the
    full vocabulary.
    """

    # === Classification outputs (Classifier models) ===
    output_class_logits: Optional[Float[Tensor, "bsz n_classes"]] = None
    """Classification logits from classifier models (pre-softmax).
    Shape: (bsz, n_classes)
    """

    # === Full template ===
    full_ids: Optional[Int[Tensor, "bsz full_seq_len"]] = None
    """Full template token IDs (prompt + generation; includes optional padding).
    """
    full_strs: Optional[List[str]] = None
    """Full template strings (prompt + generation).
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary, omitting fields whose value is None."""
        return self.model_dump(exclude_none=True)
# TODO add validators to check shapes?
#----------------------------------------------------------------------------
# ======================= Other Utilities =======================
def is_debug_mode() -> bool:
    """Return whether developer debug mode is enabled (extra asserts/probes).

    Opt-in via the ``TROPT_DEBUG`` environment variable. The variable is read on
    every call. Its value is compared case-insensitively, after stripping
    surrounding whitespace, against ``1``, ``true``, ``yes`` and ``on``. Any
    other value, or an unset variable, means debug mode is off.
    """
    # Strip so values such as "1 " (common in .env files / shell exports) still count.
    value = os.environ.get("TROPT_DEBUG", "").strip().lower()
    return value in ("1", "true", "yes", "on")