from __future__ import annotations
"""
Random Search (PRS) optimizer.
Andriushchenko et al., "Jailbreaking Leading Safety-Aligned LLMs with Simple
Adaptive Attacks" (2024). https://arxiv.org/abs/2404.02151
Zeroth-order (gradient-free) token optimizer that mutates contiguous blocks
of tokens. A coarse-to-fine schedule shrinks the block size over time
(exploration -> exploitation), and a patience mechanism triggers random
restarts when the search stalls.
"""
import logging
from typing import Literal, Optional
import torch
from jaxtyping import Int
from torch import Tensor
from tropt.common import DEFAULT_INIT_TRIGGER, Targets, TextTemplates
from tropt.loss import BaseLoss
from tropt.model import BaseModel, LossTextAccessMixin
from tropt.model.model_base import BaseTokenizer
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
logger = logging.getLogger(__name__)
class RandomSearchOptimizer(BaseOptimizer):
    """RandomSearch: batched zeroth-order token optimization with block mutation.

    Per step:
        1. Compute block size from coarse-to-fine schedule
        2. For each candidate, pick a random start position and replace a
           contiguous block with random tokens from the allowed set
        3. Decode candidates to strings and evaluate via ``compute_loss_from_texts``
        4. Keep best if it improves current loss
        5. If no improvement for ``patience`` steps, restart from random init

    Implementation Notes:
        - Candidate evaluation is always text-based (``compute_loss_from_texts``);
          even for an HF model, we decode to strings and re-encode for model input.
        - A tokenizer is needed for the optimizer's token-level mutations; it
          should either be provided, or we fall back to the model's tokenizer
          if it has one.
        - The original implementation employs a "warm" initial trigger (e.g.
          another GCG suffix), and uses it as the starting point for all
          restarts. Here, we sample random triggers for all restarts for
          diversity.
        - The original implementation employs an LLM judge for early stopping;
          here we use a simple patience counter for restarts.
        - The original implementation mostly uses a loss-based scheduler. For
          generality (e.g., different potential loss values) we avoid using it.

    Reference implementation:
        - The original implementation:
          https://github.com/tml-epfl/llm-adaptive-attacks/blob/main/main.py
        - Another (more simplified) implementation:
          https://github.com/romovpa/claudini/blob/main/claudini/methods/original/prs/optimizer.py
    """

    model_requirements = (LossTextAccessMixin,)

    # Accepted values for the corresponding constructor arguments.
    _MUTATION_MODES = ("block_random", "single_cyclic")
    _SCHEDULES = ("fixed", "none")

    # Coarse-to-fine schedule as (last local step of the stage, divisor applied
    # to the initial block length). Steps past the last stage use the final
    # divisor below.
    _BLOCK_SCHEDULE = ((10, 1), (25, 2), (50, 4), (100, 8), (500, 16))
    _BLOCK_SCHEDULE_FINAL_DIVISOR = 32

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker=None,
        seed: Optional[int] = None,
        # optimization parameters:
        num_steps: int = 500,
        n_candidates: int = 128,
        mutation_mode: Literal["block_random", "single_cyclic"] = "block_random",
        # Block parameters
        schedule: Literal["fixed", "none"] = "fixed",
        initial_block_len: int = 4,
        # misc:
        patience: int = 25,
        token_constraints: Optional[TokenConstraints] = None,
        # external tokenizer (required when model has no tokenizer):
        tokenizer: Optional[BaseTokenizer] = None,
    ):
        """
        Args:
            model: Model to attack; forwarded to ``BaseOptimizer``.
            loss: Loss object; forwarded to ``BaseOptimizer``.
            tracker: Optional tracker; forwarded to ``BaseOptimizer``.
            seed: Optional seed; forwarded to ``BaseOptimizer``.
            num_steps: Total optimization steps.
            n_candidates: Number of mutated candidates per step (must be >= 1).
            mutation_mode:
                ``"block_random"`` for random contiguous block mutation
                (original PRS),
                ``"single_cyclic"`` for single-token mutations
                spread across positions (candidate ``i`` mutates position
                ``i % trigger_len``).
            schedule: Schedule for block size decay. Relevant for block mutation mode(s).
                ``"fixed"`` for step-based coarse-to-fine decay,
                ``"none"`` to keep block size constant.
            initial_block_len: Initial block size for mutation. Relevant for
                block mutation mode(s). Values below 1 are treated as 1.
            patience: Restart from random init after this many steps without
                improvement. Set to 0 to disable restarts.
            token_constraints: Constraints defining the tokens mutations may
                draw from. If None, a fresh default ``TokenConstraints()`` is
                created for this optimizer.
            tokenizer: Tokenizer for encoding/decoding trigger tokens.
                If None, uses model's tokenizer if it has one, else raises an error.

        Raises:
            ValueError: If ``schedule`` or ``mutation_mode`` is not a supported
                value, if ``n_candidates`` is smaller than 1, or if no
                tokenizer can be resolved.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # Explicit raises rather than ``assert``: assertions vanish under -O.
        if schedule not in self._SCHEDULES:
            raise ValueError(
                f"Unsupported schedule type: {schedule!r} "
                f"(expected one of {self._SCHEDULES})"
            )
        if mutation_mode not in self._MUTATION_MODES:
            raise ValueError(
                f"Unsupported mutation mode: {mutation_mode!r} "
                f"(expected one of {self._MUTATION_MODES})"
            )
        if n_candidates < 1:
            raise ValueError(f"n_candidates must be >= 1, got {n_candidates}")

        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.mutation_mode = mutation_mode

        # Block parameters:
        self.initial_block_len = initial_block_len
        self.schedule = schedule

        self.patience = patience
        # Build the default per instance so optimizers never share one object.
        self.token_constraints = (
            TokenConstraints() if token_constraints is None else token_constraints
        )

        # Resolve tokenizer: prefer explicit, fall back to the model's. A model
        # whose ``tokenizer`` attribute is None counts as having no tokenizer.
        resolved_tokenizer = (
            tokenizer if tokenizer is not None else getattr(model, "tokenizer", None)
        )
        if resolved_tokenizer is None:
            raise ValueError(
                "No tokenizer available. Pass an external tokenizer when the "
                "model does not expose one."
            )
        self.tokenizer = resolved_tokenizer

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run random search starting from ``initial_trigger``.

        Each step mutates ``n_candidates`` copies of the current trigger,
        scores their decoded texts with ``compute_loss_from_texts``, and adopts
        the best candidate only if it strictly lowers the current loss. After
        ``patience`` consecutive non-improving steps the current trigger is
        replaced by a random printable trigger of the same token length.

        Args:
            templates: Text templates handed to ``model.set_inputs_from_texts``.
            initial_trigger: Trigger text encoded to obtain the starting tokens.
            targets: Optional targets handed to ``model.set_inputs_from_texts``.

        Returns:
            The result produced by the ``RunningBest`` tracker, which is
            updated with the current trigger after every step.

        Raises:
            ValueError: If the token constraints leave no token to sample from.
        """
        # --- Setup ---
        self.model.set_inputs_from_texts(templates=templates, targets=targets)
        tokenizer = self.tokenizer
        device = self.model.device

        trigger_ids: Int[Tensor, "trigger_seq_len"] = (
            tokenizer.encode_trigger(initial_trigger).to(device)
        )
        trigger_len = trigger_ids.shape[0]

        valid_token_ids = self.token_constraints.get_whitelist_ids(
            tokenizer, tokenizer.vocab_size, device, return_tensor=True
        )
        if len(valid_token_ids) == 0:
            # Fail early with a clear message instead of inside torch.randint.
            raise ValueError(
                "Token constraints leave no valid tokens to sample mutations from."
            )

        best = RunningBest()

        # Initial loss
        trigger_str, current_loss = self._score_trigger(trigger_ids)
        self.log(loss=current_loss, trigger_str=trigger_str)
        best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)

        # Restart bookkeeping
        steps_without_improvement = 0  # consecutive steps without improvement
        step_of_curr_restart = 0  # the step of the most recent restart
        restart_count = 0  # number of restarts (for logging purposes)

        # --- Optimization loop ---
        for step_i in self.track_steps(range(self.num_steps)):
            # Check patience -- restart if stuck
            if self.patience > 0 and steps_without_improvement >= self.patience:
                restart_count += 1
                trigger_ids = get_printable_random_trigger(
                    trigger_len, return_ids=True, tokenizer=tokenizer
                ).to(device)
                trigger_str, current_loss = self._score_trigger(trigger_ids)
                steps_without_improvement = 0
                step_of_curr_restart = step_i
                logger.info(
                    "PRS restart #%d at step %d (loss=%.4f)",
                    restart_count, step_i, current_loss,
                )

            # --- Sample candidates ---
            # Clone so the in-place mutations never touch ``trigger_ids``.
            candidates = trigger_ids.unsqueeze(0).expand(self.n_candidates, -1).clone()
            if self.mutation_mode == "single_cyclic":
                block_len = self._mutate_single_cyclic(candidates, valid_token_ids)
            else:  # default to block_random
                # The schedule is measured from the most recent restart.
                local_step = step_i - step_of_curr_restart
                block_len = self._mutate_block_random(
                    candidates, valid_token_ids, local_step, trigger_len
                )

            # --- Evaluate candidates (as texts) ---
            candidate_strs = tokenizer.decode_triggers(candidates)
            losses = self.model.compute_loss_from_texts(
                candidate_strs, loss_func=self.loss_func
            )

            # If improved, update current trigger; else increment patience counter
            best_idx = losses.argmin()
            candidate_loss = losses[best_idx].item()
            if candidate_loss < current_loss:
                trigger_ids = candidates[best_idx]
                current_loss = candidate_loss
                steps_without_improvement = 0
            else:
                steps_without_improvement += 1

            # logging stuff:
            trigger_str = tokenizer.decode_trigger(trigger_ids)
            self.log(
                loss=current_loss,
                trigger_str=trigger_str,
                n_token_flip=block_len,
                restarts=restart_count,
                patience_counter=steps_without_improvement,
            )
            best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)

        # --- Finalize ---
        return best.to_result()

    def _score_trigger(self, trigger_ids: Int[Tensor, "trigger_seq_len"]):
        """Decode one trigger and return ``(trigger_str, loss)`` for its text."""
        trigger_str = self.tokenizer.decode_trigger(trigger_ids)
        loss_value = self.model.compute_loss_from_texts(
            [trigger_str], loss_func=self.loss_func
        ).item()
        return trigger_str, loss_value

    # ------------------------------------------------------------------ #
    # Mutation strategies
    # ------------------------------------------------------------------ #
    def _mutate_single_cyclic(
        self,
        candidates: Int[Tensor, "n_candidates trigger_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
    ) -> int:
        """Single-token mutation spread across positions round-robin.

        Candidate ``i`` mutates position ``i % trigger_len``. ``candidates`` is
        modified in place.

        Returns block_len (always 1).
        """
        n_rows, n_cols = candidates.shape[0], candidates.shape[1]
        device = candidates.device

        rows = torch.arange(n_rows, device=device)
        cols = rows % n_cols
        sampled = torch.randint(len(valid_token_ids), (n_rows,), device=device)
        candidates[rows, cols] = valid_token_ids[sampled]
        return 1

    def _mutate_block_random(
        self,
        candidates: Int[Tensor, "n_candidates trigger_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        local_step: int,
        trigger_len: int,
    ) -> int:
        """Contiguous block mutation at random positions.

        Every candidate row gets one block of ``block_len`` consecutive
        positions overwritten with tokens sampled uniformly from
        ``valid_token_ids``. ``candidates`` is modified in place.

        Returns the block length used.
        """
        n_rows = candidates.shape[0]
        device = candidates.device

        # The block can never be longer than the trigger itself.
        block_len = min(self._get_curr_block_len(local_step), trigger_len)
        last_start = trigger_len - block_len

        block_starts = torch.randint(0, last_start + 1, (n_rows,), device=device)
        sampled = torch.randint(
            len(valid_token_ids), (n_rows, block_len), device=device
        )

        # One indexed assignment for all offsets: ``rows`` (n_rows, 1)
        # broadcasts against ``cols`` (n_rows, block_len).
        rows = torch.arange(n_rows, device=device).unsqueeze(1)
        cols = block_starts.unsqueeze(1) + torch.arange(block_len, device=device)
        candidates[rows, cols] = valid_token_ids[sampled]
        return block_len

    # ------------------------------------------------------------------ #
    # Coarse-to-fine block-len schedule
    # ------------------------------------------------------------------ #
    def _get_curr_block_len(self, local_step: int) -> int:
        """Block size at a given step within the current restart.

        Follows ``schedule_n_to_change_fixed`` from the paper's official code.
        With the ``"fixed"`` schedule the initial block length is divided by
        1, 2, 4, 8, 16 for local steps up to 10, 25, 50, 100, 500 respectively,
        and by 32 afterwards. The result is never smaller than 1.
        """
        # Clamp so a non-positive ``initial_block_len`` still mutates one token.
        base_len = max(self.initial_block_len, 1)
        if self.schedule == "none":
            return base_len

        for last_step, divisor in self._BLOCK_SCHEDULE:
            if local_step <= last_step:
                return max(base_len // divisor, 1)
        return max(base_len // self._BLOCK_SCHEDULE_FINAL_DIVISOR, 1)