from __future__ import annotations
import logging
import math
from typing import Literal, Optional, Set
import torch
from jaxtyping import Float, Int
from torch import Tensor
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
GradientTokenAccessMixin,
LossTextAccessMixin,
LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.tracker import BaseTracker
# Module-level logger, named after this module so handlers/filters can target it.
# (A stray "[docs]" marker that followed this line was documentation-extraction
# residue; as a bare expression it raised NameError at import time.)
logger = logging.getLogger(__name__)
class PALOptimizer(BaseOptimizer):
    """Proxy-guided black-box optimizer for PAL and RAL attacks.

    Two-stage design:
        1. Candidate selection (on proxy) — gradient-based or random.
        2. Candidate evaluation (on target) — via text or token access.

    Skip-visited is always enabled. Optional proxy filtering narrows
    candidates by proxy loss before querying the target.

    References:
        - Paper: https://arxiv.org/abs/2402.09674
        - Official Codebase: https://github.com/chawins/pal

    Note:
        The original PAL implementation also supports fine-tuning the proxy
        model (to bring it closer to the target model in the "optimized
        area"); that is currently not supported here.
    """

    model_requirements = (LossTextAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # Proxy model for candidate selection (defaults to model):
        proxy_model: Optional[BaseModel] = None,
        # Differentiable loss for proxy gradient/loss computation (defaults to loss):
        proxy_loss: Optional[BaseLoss] = None,
        # Candidate selection strategy:
        candidate_selection: Literal["gradient", "random"] = "gradient",
        # Core GCG params:
        num_steps: int = 500,
        n_candidates: int = 512,
        sample_topk: int = 256,
        sample_n_replace: int = 1,
        candidate_oversample_factor: float = 1.5,
        token_constraints: TokenConstraints = TokenConstraints(),
        # PAL-specific: proxy filtering (set to None to disable):
        n_candidates_after_proxy_filter: Optional[int] = None,
    ):
        """
        Args:
            model: Target model to attack (used for candidate evaluation).
                By default the loss is computed from texts; when the target
                shares its tokenizer with the proxy model and offers
                token-level loss access, computation on tokens is preferred.
            loss: Loss function to optimize.
            tracker: Forwarded to the base optimizer.
            seed: Forwarded to the base optimizer.
            proxy_model: Model used for candidate selection (gradients/tokenizer).
                If None, defaults to `model` (self-proxy / white-box).
            proxy_loss: Loss function for proxy gradient/loss computation. Required
                when the main loss is non-differentiable (e.g. text-based) and
                candidate_selection="gradient". Defaults to `loss`.
            candidate_selection: "gradient" uses gradient-ranked top-k sampling,
                "random" uses uniform random token sampling.
            num_steps: Number of optimization iterations.
            n_candidates: Maximum number of candidates kept per step after filtering.
            sample_topk: Number of top tokens to consider for each position when
                sampling candidates based on gradients.
            sample_n_replace: Number of token positions to replace when generating
                each candidate.
            candidate_oversample_factor: Generate ceil(n_candidates * factor)
                candidates, then truncate to n_candidates after filtering. Values
                above 1.0 compensate for candidates dropped by the filters.
            token_constraints: Supplies the blacklisted / whitelisted token ids
                used when sampling replacement tokens.
            n_candidates_after_proxy_filter: If set, filter candidates down to top-K
                by proxy loss before evaluating on the target (PAL's proxy
                filtering step). Expected to be less than n_candidates. Set to
                None to disable. Ignored when the proxy is the target model itself.

        Raises:
            AssertionError: If the proxy model lacks the required access mixins.
            ValueError: If `candidate_selection` is unknown, or if it is
                "gradient" and `proxy_loss` is not differentiable.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # Anything other than "gradient" used to fall through silently to random
        # sampling, which hides typos such as "grad".
        if candidate_selection not in ("gradient", "random"):
            raise ValueError(
                f"Unknown candidate_selection={candidate_selection!r}; "
                f"expected 'gradient' or 'random'."
            )

        # Proxy model setup
        self.proxy_model = model if proxy_model is None else proxy_model
        assert isinstance(self.proxy_model, LossTokenAccessMixin), (
            "proxy_model must support LossTokenAccessMixin for candidate selection"
        )
        self.proxy_loss = proxy_loss if proxy_loss is not None else self.loss_func

        if candidate_selection == "gradient":
            assert isinstance(self.proxy_model, GradientTokenAccessMixin), (
                f"{candidate_selection=} requires proxy_model with GradientTokenAccessMixin"
            )
            if not self.proxy_loss.is_differentiable:
                raise ValueError(
                    f"candidate_selection='gradient' requires a differentiable proxy_loss, "
                    f"but {type(self.proxy_loss).__name__}.is_differentiable=False. "
                    f"Pass a differentiable proxy_loss (e.g. PrefillCELoss())."
                )

        # Prefer token-level target evaluation when proxy and target share the same
        # tokenizer. Otherwise (tokenizers differ, or the target's tokenizer is not
        # accessible) fall back to text-level loss computation. The isinstance
        # check comes first and the tokenizer is fetched defensively so that a
        # text-only target without a `tokenizer` attribute does not raise here.
        target_tokenizer = getattr(model, "tokenizer", None)
        use_token_input_for_loss = (
            isinstance(model, LossTokenAccessMixin)
            and target_tokenizer is not None
            and target_tokenizer == self.proxy_model.tokenizer
        )

        # Disable proxy filtering if proxy and target are the same model. Compare
        # against the *resolved* proxy so the default `proxy_model=None`
        # (self-proxy) case is covered as well.
        if self.proxy_model is model:
            n_candidates_after_proxy_filter = None

        # Save params
        self.candidate_selection = candidate_selection
        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.n_candidates_after_proxy_filter = n_candidates_after_proxy_filter
        self.sample_topk = sample_topk
        self.sample_n_replace = sample_n_replace
        self.token_constraints = token_constraints
        self.candidate_oversample_factor = candidate_oversample_factor
        self.use_token_input_for_loss = use_token_input_for_loss

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run the proxy-guided search and return the best trigger found.

        Each step samples candidates around the current trigger on the proxy
        model, drops candidates rejected by retokenization filtering or already
        visited, optionally keeps only those with the lowest proxy loss, and
        finally evaluates the survivors on the target model. The candidate with
        the lowest target loss becomes the next current trigger.

        Args:
            templates: Templates passed to the models' input setup.
            initial_trigger: Trigger string the search starts from.
            targets: Optional targets passed to the models' input setup.

        Returns:
            The result built from the running-best tracker, which is updated
            with the selected trigger of every non-skipped step.
        """
        # --- Initialization ---
        proxy_model = self.proxy_model
        proxy_tokenizer = proxy_model.tokenizer
        target_model = self.model

        # Proxy uses token access:
        proxy_model.set_inputs_from_tokens(templates=templates, targets=targets)
        # Target model access depends on `use_token_input_for_loss`:
        if self.use_token_input_for_loss:
            target_model.set_inputs_from_tokens(templates=templates, targets=targets)
        else:
            target_model.set_inputs_from_texts(templates=templates, targets=targets)

        trigger_ids: Int[Tensor, "trigger_seq_len"] = proxy_tokenizer.encode_trigger(
            initial_trigger
        ).to(proxy_model.device)

        vocab_size = proxy_model.vocab_size
        blacklist_ids = self.token_constraints.get_blacklist_ids(proxy_tokenizer, vocab_size)
        valid_token_ids = self.token_constraints.get_whitelist_ids(
            proxy_tokenizer, vocab_size, proxy_model.device, return_tensor=True
        )

        best = RunningBest()
        visited: Set[str] = set()  # track visited trigger strings to avoid re-eval

        # Number of candidates to sample per step (oversampled so that enough
        # remain after the filtering stages below).
        n_candidates_oversampled = math.ceil(
            self.n_candidates * self.candidate_oversample_factor
        )

        # Initial loss
        current_loss = self._evaluate_candidates_on_target_model(
            trigger_ids.unsqueeze(0)
        ).item()
        trigger_str = proxy_tokenizer.decode_trigger(trigger_ids)
        visited.add(trigger_str)
        self.log(loss=current_loss, trigger_str=trigger_str)
        # NOTE(review): the initial trigger is logged but never passed to `best`;
        # if every step is skipped (or num_steps == 0) `best` is never updated —
        # confirm RunningBest.to_result() handles that case.

        for _ in self.track_steps(range(self.num_steps)):
            # === Stage 1: Candidate Selection (on proxy) ===
            if self.candidate_selection == "gradient":
                trigger_grad = proxy_model.compute_grad_from_tokens(
                    candidate_trigger_ids=trigger_ids.unsqueeze(0),
                    loss_func=self.proxy_loss,
                    normalize_grads=True,
                ).squeeze(0)
                candidate_trigger_ids = self._sample_ids_from_grad(
                    trigger_ids=trigger_ids,
                    trigger_grad=trigger_grad,
                    blacklist_ids=blacklist_ids,
                    _n_candidates=n_candidates_oversampled,
                )
            else:  # "random"
                candidate_trigger_ids = self._sample_ids_at_random(
                    trigger_ids=trigger_ids,
                    valid_token_ids=valid_token_ids,
                    _n_candidates=n_candidates_oversampled,
                )

            # === Filtering ===
            candidate_trigger_ids = retokenize_filtering(
                candidate_trigger_ids, proxy_tokenizer
            )
            # Skip visited (always on)
            candidate_trigger_ids = self._filter_visited(candidate_trigger_ids, visited)
            # Truncate to n_candidates (if longer after oversample + filter)
            candidate_trigger_ids = candidate_trigger_ids[: self.n_candidates]

            if len(candidate_trigger_ids) == 0:
                logger.warning("Filtered out all candidates (after tokenization and visited filtering), skipping step.")
                continue

            # === Proxy filtering ===
            # PAL adds another step to further narrow the candidate list
            if self.n_candidates_after_proxy_filter is not None:
                proxy_losses = proxy_model.compute_loss_from_tokens(
                    candidate_trigger_ids, loss_func=self.proxy_loss
                )  # shape: (n_candidates,)
                # Filtering above may leave fewer candidates than requested;
                # topk() raises if k exceeds the number of elements.
                n_keep = min(
                    self.n_candidates_after_proxy_filter, len(candidate_trigger_ids)
                )
                keep_indices = proxy_losses.topk(n_keep, largest=False).indices
                candidate_trigger_ids = candidate_trigger_ids[
                    keep_indices.to(candidate_trigger_ids.device)
                ]

            # === Stage 2: Candidate Evaluation (on target) ===
            losses = self._evaluate_candidates_on_target_model(candidate_trigger_ids)
            # A plain int index works regardless of which device `losses` is on.
            best_idx = int(losses.argmin())
            current_loss = losses[best_idx].item()
            trigger_ids = candidate_trigger_ids[best_idx]
            trigger_str = proxy_tokenizer.decode_trigger(trigger_ids)
            visited.add(trigger_str)

            best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)
            self.log(loss=current_loss, trigger_str=trigger_str)

        # --- Finalize ---
        return best.to_result()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _evaluate_candidates_on_target_model(
        self,
        candidate_trigger_ids: Int[Tensor, "n_candidates trigger_seq_len"],
    ) -> Float[Tensor, "n_candidates"]:
        """Evaluate candidates on the target model. Returns per-candidate loss.

        With token access the ids are passed to the target directly; otherwise
        they are decoded with the proxy tokenizer and evaluated as texts.
        """
        if self.use_token_input_for_loss:
            return self.model.compute_loss_from_tokens(
                candidate_trigger_ids, loss_func=self.loss_func
            )
        candidate_strs = self.proxy_model.tokenizer.decode_triggers(candidate_trigger_ids)
        return self.model.compute_loss_from_texts(
            candidate_strs, loss_func=self.loss_func
        )

    @staticmethod
    def _sample_replace_positions(
        n_candidates: int,
        trigger_seq_len: int,
        n_replace: int,
        device: torch.device,
    ) -> Int[Tensor, "n_candidates n_replace"]:
        """Pick `n_replace` distinct random positions for every candidate.

        Argsorting i.i.d. uniform noise gives a random permutation of the
        positions per row; its first `n_replace` entries are distinct positions.
        """
        noise = torch.rand(n_candidates, trigger_seq_len, device=device)
        return noise.argsort(dim=-1)[..., :n_replace]

    def _sample_ids_from_grad(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"],
        blacklist_ids: list,
        _n_candidates: int,
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """
        Follows GCG candidate sampling method.

        For each candidate, `sample_n_replace` random positions are replaced by a
        token drawn uniformly from that position's `sample_topk` tokens with the
        most negative gradient (blacklisted tokens are excluded from the ranking).

        - `_n_candidates` may be an oversampled number of triggers (ie possibly
          larger than `self.n_candidates`), which will be truncated after
          retokenization filtering.
        - `sample_topk` is clamped to the vocabulary size and `sample_n_replace`
          to the trigger length, so oversized settings do not crash.
        """
        trigger_seq_len = trigger_grad.shape[0]
        vocab_size = trigger_grad.shape[-1]
        device = trigger_grad.device

        topk = min(self.sample_topk, vocab_size)
        n_replace = min(self.sample_n_replace, trigger_seq_len)

        # Negation allocates a new tensor, so the caller's gradient is untouched.
        # Larger score == more promising replacement token.
        scores = -trigger_grad
        scores[:, blacklist_ids] = float("-inf")
        topk_ids = scores.topk(topk, dim=-1).indices  # (trigger_seq_len, topk)

        # repeat() already returns fresh memory, safe for in-place scatter_.
        candidate_trigger_ids = trigger_ids.repeat(_n_candidates, 1)

        sampled_ids_pos = self._sample_replace_positions(
            _n_candidates, trigger_seq_len, n_replace, device
        )
        # (n_candidates, n_replace, topk): top-k list of every chosen position.
        relevant_topk_lists = topk_ids[sampled_ids_pos]
        rand_k_indices = torch.randint(
            0, topk, (_n_candidates, n_replace, 1), device=device
        )
        sampled_ids_val = torch.gather(
            input=relevant_topk_lists,
            dim=-1,
            index=rand_k_indices,
        ).squeeze(-1)

        return candidate_trigger_ids.scatter_(
            dim=-1,
            index=sampled_ids_pos,
            src=sampled_ids_val,
        )

    def _sample_ids_at_random(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        _n_candidates: int,
    ) -> Int[Tensor, "n_candidates_oversampled trigger_seq_len"]:
        """Sample candidates by randomly replacing tokens (RAL-style).

        For each candidate, `sample_n_replace` random positions (clamped to the
        trigger length) are overwritten with tokens drawn uniformly from
        `valid_token_ids`.
        """
        trigger_seq_len = trigger_ids.shape[0]
        device = trigger_ids.device
        n_replace = min(self.sample_n_replace, trigger_seq_len)

        # repeat() already returns fresh memory, safe for in-place scatter_.
        candidate_trigger_ids = trigger_ids.repeat(_n_candidates, 1)

        sampled_ids_pos = self._sample_replace_positions(
            _n_candidates, trigger_seq_len, n_replace, device
        )
        rand_indices = torch.randint(
            0, len(valid_token_ids), (_n_candidates, n_replace), device=device
        )
        random_tokens = valid_token_ids[rand_indices]

        return candidate_trigger_ids.scatter_(
            dim=-1,
            index=sampled_ids_pos,
            src=random_tokens,
        )

    def _filter_visited(
        self,
        candidate_trigger_ids: Int[Tensor, "n_candidates_oversampled trigger_seq_len"],
        visited: Set[str],
    ) -> Int[Tensor, "n_filtered trigger_seq_len"]:
        """Filter out candidates whose decoded string is in the visited set."""
        tokenizer = self.proxy_model.tokenizer
        keep_mask = torch.tensor(
            [
                tokenizer.decode_trigger(candidate_ids) not in visited
                for candidate_ids in candidate_trigger_ids
            ],
            # Explicit dtype: for an empty candidate batch torch.tensor([]) would
            # default to float32, which is not a valid index/mask tensor.
            dtype=torch.bool,
            device=candidate_trigger_ids.device,
        )
        return candidate_trigger_ids[keep_mask]