"""Source code for tropt.optimizer.gcgplus_optimizer."""

from __future__ import annotations
import logging
import math
import random
from typing import Literal, Optional, Tuple, Union

import torch
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import (
    DEFAULT_INIT_TRIGGER,
    Targets,
    TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
    BaseModel,
    GradientTokenAccessMixin,
    LossTextAccessMixin,
    LossTokenAccessMixin,
    TokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.buffer import TriggerBuffer
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
from tropt.tracker import BaseTracker

logger = logging.getLogger(__name__)


class GCGPlusOptimizer(BaseOptimizer):
    """Flexible GCG implementation, supporting tricks from GCG, QCG, GASLITE, and UAT.

    Two-stage design:
        1. Candidate selection (on proxy model) — gradient-based, random, or focused.
        2. Candidate evaluation (on target model) — via text or token access.

    References:
        - GCG: https://arxiv.org/abs/2307.15043
        - QCG: https://arxiv.org/abs/2402.12329
        - PAL: https://arxiv.org/abs/2402.09674
        - GASLITE: https://arxiv.org/abs/2412.20953
        - UAT: https://arxiv.org/abs/1908.07125
    """

    model_requirements = (LossTextAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # Proxy model for candidate selection (defaults to model):
        proxy_model: Optional[BaseModel] = None,
        # Differentiable loss for proxy gradient computation (defaults to loss):
        proxy_loss: Optional[BaseLoss] = None,
        # Candidate selection strategy:
        candidate_selection: Literal["gradient", "random", "focused"] = "gradient",
        # Core GCG params:
        num_steps: int = 500,
        n_candidates: int = 512,
        sample_topk: int = 256,
        token_constraints: TokenConstraints = TokenConstraints(),
        use_retokenize: bool = True,
        # Later tricks:
        sample_n_replace: Union[int, Tuple[int, int]] = (1, 1),
        candidate_oversample_factor: float = 1.1,
        momentum: float = 0.0,
        # Trigger buffer size:
        buffer_size: Optional[int] = None,
        n_grad_avg: int = 1,
        # Per-step batch sampling:
        template_batch_size: Optional[int] = None,
    ):
        """
        Args:
            model: Target model on which candidates are evaluated (Stage 2).
            loss: Loss used to score candidates on the target model.
            tracker: Optional tracker, forwarded to ``BaseOptimizer``.
            seed: Optional seed, forwarded to ``BaseOptimizer``.
            proxy_model: Model used for candidate selection (gradients/tokenizer).
                If None, defaults to `model` (self-proxy / white-box).
            proxy_loss: Loss function for proxy gradient computation. Required
                when the main loss is non-differentiable and
                candidate_selection="gradient". Defaults to `loss`.
            candidate_selection: "gradient" uses gradient-ranked top-k sampling;
                "random" uses uniform random token sampling; "focused" probes all
                positions with target model loss then focuses on the best one
                (from QCG paper).
            num_steps: Number of optimization steps.
            n_candidates: Number of candidates evaluated per step (after
                oversampling and retokenization filtering).
            sample_topk: Number of highest-scoring tokens per position that
                gradient-based sampling draws from (clamped to the vocab size).
            token_constraints: Provides the token blacklist (excluded from
                gradient sampling) and whitelist (used for random/focused
                sampling and gradient-averaging perturbations).
            use_retokenize: If True, drop candidates that do not survive a
                decode/encode round trip (``retokenize_filtering``).
            sample_n_replace: (start, end) number of token positions to replace
                per candidate. Linearly interpolated over optimization steps;
                from PAL paper. Defaults to (1, 1) for single-token flips only.
                Not used by the "focused" strategy, which always flips one position.
            candidate_oversample_factor: Craft n_candidates * factor candidates
                (>= 1.0), then truncate after retokenization filtering, to fill
                the required number of candidates. From PAL paper. Only applied
                when use_retokenize is True. Defaults to 1.1 (10% oversampling).
            momentum: Gradient momentum coefficient. When > 0, enables momentum:
                m = mu*m + (1-mu)*grad for candidate ranking instead of raw
                gradient. Defaults to 0.0 (no momentum).
                Reference: https://arxiv.org/abs/2405.01229 .
            buffer_size: If set, maintain a buffer of the best triggers seen
                (from QCG paper). Each step starts from the best buffer entry and
                updates it with improved candidates. Defaults to None (no buffer).
            n_grad_avg: Number of trigger perturbations to average gradients
                over. When > 1, flips a random position per copy (from
                ARCA/GASLITE papers). Defaults to 1 (no averaging, like GCG).
            template_batch_size: If set, sample this many templates (and their
                targets) per optimization step instead of using all templates
                simultaneously. Useful for large template sets.
                Reference: https://arxiv.org/abs/1908.07125 .
                Defaults to None (use all templates).
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # Proxy model setup
        self.proxy_model = model if proxy_model is None else proxy_model
        assert isinstance(self.proxy_model, TokenAccessMixin), (
            "proxy_model must support TokenAccessMixin (tokenizer access)"
        )
        self.proxy_loss = proxy_loss if proxy_loss is not None else self.loss_func

        if candidate_selection == "gradient":
            assert isinstance(self.proxy_model, GradientTokenAccessMixin), (
                "candidate_selection='gradient' requires proxy_model with GradientTokenAccessMixin"
            )
            if not self.proxy_loss.is_differentiable:
                raise ValueError(
                    f"candidate_selection='gradient' requires a differentiable proxy_loss, "
                    f"but {type(self.proxy_loss).__name__}.is_differentiable=False. "
                    f"Pass a differentiable proxy_loss (e.g. PrefillCELoss())."
                )
        assert candidate_oversample_factor >= 1.0, "candidate_oversample_factor must be >= 1.0"

        # Prefer token-level target evaluation when proxy and target share the same
        # tokenizer. Otherwise (tokenizers differ, or the target model offers no
        # token-level loss access), fall back to text-level loss computation.
        # The isinstance check comes first so a text-only target model is never
        # asked for a `.tokenizer` attribute it may not have.
        use_token_input_for_loss = bool(
            isinstance(model, LossTokenAccessMixin)
            and model.tokenizer == self.proxy_model.tokenizer
        )

        # Normalize sample_n_replace to a (start, end) tuple
        if isinstance(sample_n_replace, int):
            sample_n_replace = (sample_n_replace, sample_n_replace)

        # Save params
        self.candidate_selection = candidate_selection
        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.sample_topk = sample_topk
        self.sample_n_replace = sample_n_replace
        self.token_constraints = token_constraints
        self.use_retokenize = use_retokenize
        self.candidate_oversample_factor = candidate_oversample_factor
        self.momentum = momentum
        self.use_token_input_for_loss = use_token_input_for_loss
        self.buffer_size = buffer_size
        self.n_grad_avg = n_grad_avg
        self.template_batch_size = template_batch_size

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run the two-stage GCG+ loop starting from ``initial_trigger``.

        Args:
            templates: Templates the trigger is optimized for. Set as inputs on
                both the proxy and the target model.
            initial_trigger: Starting trigger string. ``None`` falls back to
                ``DEFAULT_INIT_TRIGGER``.
            targets: Optional targets, set alongside the templates. When batch
                sampling is enabled, they are subset via ``select_indices``.

        Returns:
            The result built by ``RunningBest.to_result()`` from every
            (loss, trigger) pair recorded, including the initial trigger.
        """
        if initial_trigger is None:
            initial_trigger = DEFAULT_INIT_TRIGGER

        # --- Initialization ---
        proxy_model = self.proxy_model
        proxy_tokenizer = proxy_model.tokenizer

        # Batch sampling only kicks in when it actually selects a strict subset.
        use_batch_sampling = (
            self.template_batch_size is not None
            and self.template_batch_size < len(templates)
        )

        self._set_inputs(templates, targets)

        trigger_ids: Int[Tensor, "trigger_seq_len"] = proxy_tokenizer.encode_trigger(
            initial_trigger
        ).to(proxy_model.device)
        vocab_size = proxy_model.vocab_size
        blacklist_ids = self.token_constraints.get_blacklist_ids(proxy_tokenizer, vocab_size)
        valid_token_ids = self.token_constraints.get_whitelist_ids(
            proxy_tokenizer, vocab_size, proxy_model.device, return_tensor=True
        )

        best = RunningBest()
        momentum_buffer: Optional[Tensor] = None

        # Buffer initialization
        buffer: Optional[TriggerBuffer] = None
        if self.buffer_size is not None:
            buffer = self._init_buffer(trigger_ids)

        # Number of candidates to generate (oversample only if retokenization
        # filtering may discard some of them).
        n_candidates_oversampled = self.n_candidates
        if self.use_retokenize and self.candidate_oversample_factor > 1.0:
            n_candidates_oversampled = math.ceil(
                self.n_candidates * self.candidate_oversample_factor
            )

        # Initial loss. Inputs were set from the full template set above, so this
        # is always the full-set loss, even when batch sampling is enabled.
        current_loss = self._evaluate_candidates(trigger_ids.unsqueeze(0)).item()
        trigger_str = proxy_tokenizer.decode_trigger(trigger_ids)
        # Record the starting point. Without this, the result would be empty when
        # no step runs (num_steps=0, or every step filtered out). In non-buffer
        # mode, the reported best could also be worse than the input trigger.
        best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)
        self.log(loss=current_loss, trigger_str=trigger_str)

        n_replace_start, n_replace_end = self.sample_n_replace
        for step_i in self.track_steps(range(self.num_steps)):
            # --- batch sampling: re-set model inputs with a random subset ---
            if use_batch_sampling:
                batch_indices = random.sample(
                    range(len(templates)), self.template_batch_size
                )
                batch_templates = [templates[i] for i in batch_indices]
                batch_targets = (
                    targets.select_indices(batch_indices) if targets is not None else None
                )
                self._set_inputs(batch_templates, batch_targets)

            # Linearly interpolate sample_n_replace over steps
            cur_n_replace = round(
                n_replace_start
                + (n_replace_end - n_replace_start) * step_i / max(self.num_steps - 1, 1)
            )

            # Buffer mode: start each step from the best buffer entry
            if buffer is not None:
                trigger_ids = buffer.get_best_trigger()

            # === Stage 1: Candidate Selection (on proxy) ===
            if self.candidate_selection == "gradient":
                grad_triggers = self._get_grad_trigger_variations(trigger_ids, valid_token_ids)
                trigger_grad = proxy_model.compute_grad_from_tokens(
                    candidate_trigger_ids=grad_triggers,
                    loss_func=self.proxy_loss,
                    normalize_grads=True,
                ).mean(dim=0)  # average over n_grad_avg variations

                # Apply momentum (EMA of gradients across steps)
                if self.momentum > 0:
                    if momentum_buffer is None:
                        momentum_buffer = trigger_grad.clone()
                    else:
                        momentum_buffer = (
                            self.momentum * momentum_buffer
                            + (1 - self.momentum) * trigger_grad
                        )
                    trigger_grad = momentum_buffer

                candidate_trigger_ids = self._sample_ids_from_grad(
                    trigger_ids=trigger_ids,
                    trigger_grad=trigger_grad,
                    blacklist_ids=blacklist_ids,
                    n_candidates=n_candidates_oversampled,
                    n_replace=cur_n_replace,
                )
            elif self.candidate_selection == "focused":
                candidate_trigger_ids = self._sample_focused_candidates(
                    trigger_ids=trigger_ids,
                    valid_token_ids=valid_token_ids,
                    n_candidates=n_candidates_oversampled,
                )
            else:  # "random"
                candidate_trigger_ids = self._sample_random_candidates(
                    trigger_ids=trigger_ids,
                    valid_token_ids=valid_token_ids,
                    n_candidates=n_candidates_oversampled,
                    n_replace=cur_n_replace,
                )

            # === Retokenization filtering ===
            if self.use_retokenize:
                candidate_trigger_ids = retokenize_filtering(
                    candidate_trigger_ids, proxy_tokenizer
                )

            # Truncate to n_candidates (after oversample + filter)
            candidate_trigger_ids = candidate_trigger_ids[: self.n_candidates]

            if len(candidate_trigger_ids) == 0:
                logger.warning("All candidates filtered out, skipping step.")
                continue

            # === Stage 2: Candidate Evaluation (on target) ===
            losses = self._evaluate_candidates(candidate_trigger_ids)

            # === Update best / buffer ===
            if buffer is not None:
                # Buffer mode: offer every evaluated candidate to the buffer.
                # A single tolist() avoids one device sync per candidate.
                for cand_ids, cand_loss in zip(candidate_trigger_ids, losses.tolist()):
                    buffer.add_if_better(cand_ids, cand_loss)
                # Track overall best from buffer for logging
                current_loss = buffer.get_lowest_loss()
                trigger_ids = buffer.get_best_trigger()
            else:
                # Plain GCG: always move to this step's best candidate.
                current_loss = losses.min().item()
                trigger_ids = candidate_trigger_ids[losses.argmin()]
            trigger_str = proxy_tokenizer.decode_trigger(trigger_ids)

            best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)
            self.log(loss=current_loss, trigger_str=trigger_str)

        # --- Finalize ---
        return best.to_result()

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    def _set_inputs(
        self,
        templates: TextTemplates,
        targets: Optional[Targets],
    ) -> None:
        """Set the inputs on the proxy model (token-level) and the target model.

        The target model gets token-level inputs when
        ``self.use_token_input_for_loss`` is True, and text-level inputs otherwise.
        """
        self.proxy_model.set_inputs_from_tokens(templates=templates, targets=targets)
        if self.use_token_input_for_loss:
            self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        else:
            self.model.set_inputs_from_texts(templates=templates, targets=targets)

    def _init_buffer(
        self,
        initial_trigger_ids: Int[Tensor, "trigger_seq_len"],
    ) -> TriggerBuffer:
        """Initialize the buffer with the initial trigger and printable-random variants.

        The buffer holds ``self.buffer_size`` triggers: the initial one plus
        ``buffer_size - 1`` random printable triggers. All of them are
        evaluated on the target model.
        """
        assert self.buffer_size is not None
        trigger_seq_len = initial_trigger_ids.shape[0]
        device = initial_trigger_ids.device
        tokenizer = self.proxy_model.tokenizer
        blacklist_ids = self.token_constraints.get_blacklist_ids(
            tokenizer, self.proxy_model.vocab_size
        )

        triggers_list: list[Tensor] = [initial_trigger_ids]
        for _ in range(self.buffer_size - 1):
            # Oversample length (re-encoding can drift shorter), then trim to
            # trigger_seq_len. NOTE(review): if re-encoding still yields fewer
            # than trigger_seq_len tokens, torch.stack below raises.
            rand_str = get_printable_random_trigger(
                trigger_len=2 * trigger_seq_len,
                tokenizer=tokenizer,
                blacklist_ids=blacklist_ids,
            )
            ids = tokenizer.encode_trigger(rand_str).to(device)[:trigger_seq_len]
            triggers_list.append(ids)

        all_triggers = torch.stack(triggers_list, dim=0)
        all_losses = self._evaluate_candidates(all_triggers)
        return TriggerBuffer(
            triggers=list(all_triggers),
            losses=all_losses.tolist(),
        )

    # ------------------------------------------------------------------
    # Stage 1: Candidate Selection
    # ------------------------------------------------------------------

    def _sample_ids_from_grad(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"],
        blacklist_ids: list,
        n_candidates: int,
        n_replace: int = 1,
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """Sample candidate token sequences via GCG-style random multi-position replacement.

        For each position, the ``sample_topk`` tokens with the most negative
        gradient are the replacement pool; blacklisted tokens are excluded. Each
        candidate flips ``n_replace`` distinct random positions, each to a
        uniformly chosen token from that position's pool.
        """
        trigger_seq_len = trigger_grad.shape[0]
        device = trigger_grad.device
        n_replace = min(n_replace, trigger_seq_len)

        # Score = -grad (a new tensor, so the caller's momentum buffer is untouched);
        # blacklisted tokens can never win the top-k.
        scores = -trigger_grad
        scores[:, blacklist_ids] = float("-inf")
        # Clamp k so topk cannot fail when sample_topk exceeds the vocabulary.
        topk = min(self.sample_topk, scores.shape[-1])
        topk_ids = scores.topk(topk, dim=-1).indices

        # Tensor.repeat already returns a fresh copy.
        candidate_trigger_ids = trigger_ids.repeat(n_candidates, 1)

        # Random distinct positions to flip per candidate
        sampled_ids_pos = torch.rand(
            n_candidates, trigger_seq_len, device=device
        ).argsort(dim=-1)[..., :n_replace]

        # Select relevant top-k lists and sample one token from each
        relevant_topk_lists = topk_ids[sampled_ids_pos]
        rand_k_indices = torch.randint(
            0, topk, (n_candidates, n_replace, 1), device=device
        )
        sampled_ids_val = torch.gather(
            input=relevant_topk_lists,
            dim=-1,
            index=rand_k_indices,
        ).squeeze(-1)

        return candidate_trigger_ids.scatter_(
            dim=-1,
            index=sampled_ids_pos,
            src=sampled_ids_val,
        )

    def _get_grad_trigger_variations(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
    ) -> Int[Tensor, "n_grad_avg trigger_seq_len"]:
        """Create trigger variations for gradient averaging (GASLITE-style).

        With n_grad_avg <= 1, returns the trigger as-is (with a leading batch
        dim). With n_grad_avg > 1, flips one random position to a random valid
        token in each copy, keeping the first copy intact.
        """
        if self.n_grad_avg <= 1:
            return trigger_ids.unsqueeze(0)

        device = trigger_ids.device
        trigger_seq_len = trigger_ids.shape[0]
        grad_triggers = trigger_ids.repeat(self.n_grad_avg, 1).clone()

        # Keep first copy intact, perturb the rest
        for idx in range(1, self.n_grad_avg):
            pos = torch.randint(0, trigger_seq_len, (1,), device=device).item()
            tok = valid_token_ids[
                torch.randint(0, len(valid_token_ids), (1,), device=device)
            ].item()
            assert isinstance(pos, int) and isinstance(tok, int)
            grad_triggers[idx, pos] = tok

        return grad_triggers

    def _sample_random_candidates(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        n_candidates: int,
        n_replace: int = 1,
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """Sample candidates by randomly replacing tokens (RAL-style).

        Each candidate flips ``n_replace`` distinct random positions to tokens
        drawn uniformly from ``valid_token_ids``.
        """
        trigger_seq_len = trigger_ids.shape[0]
        device = trigger_ids.device
        n_replace = min(n_replace, trigger_seq_len)

        candidate_trigger_ids = trigger_ids.repeat(n_candidates, 1).clone()

        # Random positions to flip
        sampled_ids_pos = torch.rand(
            n_candidates, trigger_seq_len, device=device
        ).argsort(dim=-1)[..., :n_replace]

        # Sample uniformly from valid tokens
        rand_indices = torch.randint(
            0, len(valid_token_ids), (n_candidates, n_replace), device=device
        )
        random_tokens = valid_token_ids[rand_indices]

        candidate_trigger_ids = candidate_trigger_ids.scatter_(
            dim=-1,
            index=sampled_ids_pos,
            src=random_tokens,
        )

        return candidate_trigger_ids

    def _sample_focused_candidates(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        n_candidates: int,
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """Focused position sampling (from QCG paper).

        Phase 1: Probe each position with one random token replacement and
        evaluate all probes on the target (via _evaluate_candidates) to find
        the most promising position. This costs one extra target evaluation of
        ``trigger_seq_len`` probes per step.

        Phase 2: Generate n_candidates candidates that differ from the trigger
        only at that best position.
        """
        trigger_seq_len = trigger_ids.shape[0]
        device = trigger_ids.device

        # Phase 1: Probe each position with one random token
        probe_candidates = trigger_ids.repeat(trigger_seq_len, 1).clone()
        probe_tokens = valid_token_ids[
            torch.randint(0, len(valid_token_ids), (trigger_seq_len,), device=device)
        ]
        # Replace position j in candidate j
        probe_candidates[
            torch.arange(trigger_seq_len, device=device),
            torch.arange(trigger_seq_len, device=device),
        ] = probe_tokens

        probe_losses = self._evaluate_candidates(probe_candidates)
        best_pos = probe_losses.argmin().item()
        assert isinstance(best_pos, int)

        # Phase 2: Generate candidates at best position only
        candidates = trigger_ids.repeat(n_candidates, 1).clone()
        random_tokens = valid_token_ids[
            torch.randint(0, len(valid_token_ids), (n_candidates,), device=device)
        ]
        candidates[:, best_pos] = random_tokens

        return candidates

    # ------------------------------------------------------------------
    # Stage 2: Candidate Evaluation
    # ------------------------------------------------------------------

    def _evaluate_candidates(
        self,
        candidate_trigger_ids: Int[Tensor, "n_candidates trigger_seq_len"],
    ) -> Float[Tensor, "n_candidates"]:
        """Evaluate candidates on the target model and return per-candidate losses.

        When use_token_input_for_loss is False, candidates are decoded to text
        with proxy_model.tokenizer and scored via compute_loss_from_texts.
        """
        if self.use_token_input_for_loss:
            losses = self.model.compute_loss_from_tokens(
                candidate_trigger_ids, loss_func=self.loss_func
            )
        else:
            candidate_strs = self.proxy_model.tokenizer.decode_triggers(candidate_trigger_ids)
            losses = self.model.compute_loss_from_texts(
                candidate_strs, loss_func=self.loss_func
            )

        # Reduce to (n_candidates,) if needed.
        # NOTE(review): assumes a 2-D result is laid out as
        # (n_templates, n_candidates), so mean(dim=0) averages over templates;
        # confirm against the model implementations.
        if losses.dim() > 1:
            losses = losses.mean(dim=0)
        return losses