# Source code for tropt.optimizer.gbda_optimizer

from __future__ import annotations
import logging
from typing import Callable, Literal, Optional

import torch
import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import (
    DEFAULT_INIT_TRIGGER,
    Targets,
    TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
    BaseModel,
    GradientTokenAccessMixin,
    LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.tracker import BaseTracker

logger = logging.getLogger(__name__)


class GBDAOptimizer(BaseOptimizer):
    """
    Gradient-Based Distributional Attack (GBDA).

    Paper: https://arxiv.org/abs/2104.13733
    Reference implementation:
    https://github.com/facebookresearch/text-adversarial-attack/blob/main/whitebox_attack.py

    Optimizes a continuous logit matrix theta of shape (L, vocab_size), where L is the trigger
    sequence length, from which discrete triggers can be sampled with Gumbel-softmax.
    At every step theta is replicated `n_grad_samples` times and handed to the model with
    Gumbel-softmax enabled, so each sample acts as a soft token distribution (a weighted sum of
    input embeddings). The per-sample gradients w.r.t. theta are averaged and theta is updated
    with a torch optimizer.

    After each step the argmax of theta is scored as a discrete trigger. At the end, the
    lowest-loss trigger among the final argmax, the running best, and `n_final_gumbel_samples`
    hard Gumbel samples of theta is returned.
    """

    model_requirements = (LossTokenAccessMixin, GradientTokenAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # GBDA-specific parameters:
        num_steps: int = 100,
        n_grad_samples: int = 10,
        n_final_gumbel_samples: int = 100,
        # Init parameters:
        initial_coeff: float = 15.0,
        init_mode: Literal["from_trigger", "random"] = "from_trigger",
        init_noise_scale: float = 2.0,
        # Temperature schedule parameters:
        temp_schedule: Literal["linear", "gradual"] = "linear",
        temp_start: float = 1.0,
        temp_end: float = 0.1,
        # Optimization parameters:
        gd_optimizer: Callable[..., torch.optim.Optimizer] = torch.optim.Adam,
        use_lr_schedule: bool = True,
        learning_rate: float = 0.3,
        grad_clip_norm: Optional[float] = None,
    ):
        """
        Args:
            model: Model providing the tokenizer, losses and gradients; passed to
                ``BaseOptimizer.__init__``.
            loss: Loss object; passed to ``BaseOptimizer.__init__``.
            tracker: Optional tracker; passed to ``BaseOptimizer.__init__``.
            seed: Optional seed; passed to ``BaseOptimizer.__init__``.

            GBDA-specific parameters:
            num_steps: Number of optimization steps.
            n_grad_samples: Gumbel-softmax samples per gradient step (must be >= 1); their
                gradients are averaged.
            n_final_gumbel_samples: Number of hard Gumbel samples drawn from the final theta as
                extra candidates for the final selection. If 0, the final selection compares
                only the argmax of the final theta and the running best trigger.

            Init parameters:
            initial_coeff: Logit value placed at each initial trigger token position
                (used when init_mode="from_trigger"; all other logits start at 0).
            init_mode: How to initialize the logit matrix. "from_trigger" sets `initial_coeff`
                at the initial trigger token positions; "random" samples i.i.d. from
                N(0, init_noise_scale**2). In both modes the trigger length is the token
                length of the initial trigger.
            init_noise_scale: Std of the random initialization (used only with init_mode="random").

            Temperature schedule parameters:
            temp_schedule: Temperature schedule type. "linear" anneals linearly from
                `temp_start` to `temp_end`. "gradual" is a fixed 3-phase schedule
                (explore 2.5->1.0, refine 1.0->0.5, discretize 0.5 toward 0.01) that ignores
                `temp_start` and `temp_end`.
            temp_start: Starting Gumbel-softmax temperature (used only with "linear").
            temp_end: Final Gumbel-softmax temperature for "linear". It is also passed as `tau`
                to the final hard Gumbel sampling, whichever schedule is selected.

            Optimization parameters:
            gd_optimizer: Factory/class of the torch optimizer; called as
                ``gd_optimizer([theta], lr=learning_rate)``.
            use_lr_schedule: If True, apply cosine annealing (T_max=num_steps) to the learning
                rate; otherwise keep it constant.
            learning_rate: Learning rate passed to `gd_optimizer`.
            grad_clip_norm: If set, clip the gradient norm to this value before each optimizer step.

        Raises:
            ValueError: If `temp_schedule` or `init_mode` is not a supported value, or if
                `n_grad_samples` < 1.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        if temp_schedule not in ("linear", "gradual"):
            raise ValueError(f"Unknown temp_schedule: {temp_schedule!r}. Must be 'linear' or 'gradual'.")
        # Validate like temp_schedule: an unknown value would otherwise silently fall back
        # to the "from_trigger" initialization.
        if init_mode not in ("from_trigger", "random"):
            raise ValueError(f"Unknown init_mode: {init_mode!r}. Must be 'from_trigger' or 'random'.")
        # Averaging over zero samples yields a NaN gradient that would corrupt theta.
        if n_grad_samples < 1:
            raise ValueError(f"n_grad_samples must be >= 1, got {n_grad_samples}.")

        self.num_steps = num_steps
        self.n_grad_samples = n_grad_samples
        self.learning_rate = learning_rate
        self.initial_coeff = initial_coeff
        self.temp_schedule = temp_schedule
        self.temp_start = temp_start
        self.temp_end = temp_end
        self.n_final_gumbel_samples = n_final_gumbel_samples
        self.GDOptimizer = gd_optimizer
        self.use_lr_schedule = use_lr_schedule
        self.grad_clip_norm = grad_clip_norm
        self.init_mode = init_mode
        self.init_noise_scale = init_noise_scale

    def _get_temperature(self, step: int) -> float:
        """
        Return the Gumbel-softmax temperature for `step` under the configured schedule.

        For "linear", the temperature reaches exactly `temp_end` at step `num_steps - 1`.

        Raises:
            ValueError: If `self.temp_schedule` is not "linear" or "gradual".
        """
        if self.temp_schedule == "gradual":
            return self._gradual_temperature(step)
        if self.temp_schedule == "linear":
            progress = min(1.0, step / max(1, self.num_steps - 1))
            return self.temp_start + progress * (self.temp_end - self.temp_start)
        raise ValueError(f"Unknown temp_schedule: {self.temp_schedule!r}. Must be 'linear' or 'gradual'.")

    def _gradual_temperature(self, step: int) -> float:
        """
        3-phase schedule (50/25/25% of `num_steps`): explore -> refine -> discretize.

        - explore: linear 2.5 -> 1.0
        - refine: linear 1.0 -> 0.5
        - discretize: exponential decay from 0.5 toward 0.01

        Progress is `step / num_steps`, which stays below 1 for `step < num_steps`. The value
        0.01 is therefore approached but not reached within a run.
        """
        ratio = step / max(1, self.num_steps)
        if ratio < 0.5:
            p = ratio / 0.5
            return 2.5 - 1.5 * p  # 2.5 -> 1.0
        elif ratio < 0.75:
            p = (ratio - 0.5) / 0.25
            return 1.0 - 0.5 * p  # 1.0 -> 0.5
        else:
            p = (ratio - 0.75) / 0.25
            return 0.5 * (0.01 / 0.5) ** p  # 0.5 -> 0.01 (exponential decay)

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Optimize the trigger logit matrix with GBDA and return the best discrete trigger.

        Args:
            templates: Templates forwarded to ``model.set_inputs_from_tokens``.
            initial_trigger: Trigger string whose tokens define the trigger length and, with
                init_mode="from_trigger", the initial logits.
            targets: Optional targets forwarded to ``model.set_inputs_from_tokens``.

        Returns:
            OptimizerResult containing:
            - the lowest-loss trigger (ids, string and loss) among the final argmax, the
              running best and the final hard Gumbel samples;
            - the per-step loss/trigger history collected by `RunningBest`;
            - the final (detached) logit matrix.
        """
        # --- Initialization ---
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        device = self.model.device
        vocab_size = self.model.vocab_size

        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(initial_trigger).to(device)
        trigger_seq_len = trigger_ids.shape[0]

        # Logit matrix theta (named `trigger_probs` here; it holds unnormalized logits).
        trigger_probs: Float[Tensor, "trigger_seq_len vocab_size"]
        if self.init_mode == "random":
            trigger_probs = (
                torch.randn(trigger_seq_len, vocab_size, device=device, dtype=self.model.dtype)
                * self.init_noise_scale
            )
        else:  # "from_trigger"
            trigger_probs = torch.zeros(trigger_seq_len, vocab_size, device=device, dtype=self.model.dtype)
            # Set theta[i, trigger_ids[i]] = initial_coeff for every position i in one indexed write.
            positions = torch.arange(trigger_seq_len, device=device)
            trigger_probs[positions, trigger_ids.long()] = self.initial_coeff

        # Optimizer and learning-rate schedule (ConstantLR with factor=1.0 leaves the LR unchanged).
        optimizer = self.GDOptimizer([trigger_probs], lr=self.learning_rate)
        scheduler = (
            torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.num_steps)
            if self.use_lr_schedule
            else torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
        )
        best = RunningBest()

        # --- Optimization of `trigger_probs` ---
        for step in self.track_steps(range(self.num_steps), desc="GBDA Optimization"):
            temperature = self._get_temperature(step)
            optimizer.zero_grad()

            # --- Gradient step ---
            # Replicate theta so that, with Gumbel-softmax enabled, the model draws
            # `n_grad_samples` samples from the same (optimized) distribution in one call.
            trigger_probs_samples = trigger_probs.unsqueeze(0).repeat(self.n_grad_samples, 1, 1)
            trigger_grad = self.model.compute_grad_from_tokens(
                candidate_trigger_probs=trigger_probs_samples,  # (n_grad_samples, trigger_seq_len, vocab_size)
                loss_func=self.loss_func,
                do_gumbel_softmax=True,
                gumbel_softmax_temp=temperature,
                normalize_grads=False,
            )  # (n_grad_samples, trigger_seq_len, vocab_size)

            # Average the per-sample gradients and install them manually as theta's gradient
            # (theta is not part of an autograd graph here).
            trigger_probs.grad = trigger_grad.mean(dim=0)  # -> (trigger_seq_len, vocab_size)
            if self.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_([trigger_probs], self.grad_clip_norm)
            optimizer.step()
            scheduler.step()

            # --- Evaluate the discrete (argmax) trigger ---
            with torch.no_grad():
                current_trigger_ids: Int[Tensor, "trigger_seq_len"] = trigger_probs.argmax(dim=-1)
                current_loss = self.model.compute_loss_from_tokens(
                    current_trigger_ids.unsqueeze(0),
                    loss_func=self.loss_func,
                ).item()
                current_trigger_str = tokenizer.decode_trigger(current_trigger_ids)
                best.update(loss=current_loss, trigger_ids=current_trigger_ids, trigger_str=current_trigger_str)
                self.log(
                    loss=current_loss,
                    trigger_str=current_trigger_str,
                    temperature=temperature,
                    lr=scheduler.get_last_lr()[0],
                )

        # --- Final selection ---
        # This is pure evaluation, so run it without autograd, like the per-step evaluation.
        # Otherwise, scoring all candidates at once could build a large graph if any model
        # parameters require grad.
        with torch.no_grad():
            # Candidates: argmax of the final theta, the running best, and hard Gumbel samples.
            candidates = [trigger_probs.argmax(dim=-1)]
            if best.trigger_ids is not None:
                candidates.append(best.trigger_ids)
            for _ in range(self.n_final_gumbel_samples):
                # With hard=True the chosen token is argmax((theta + Gumbel noise) / tau). Dividing
                # by tau > 0 preserves that argmax up to floating-point rounding, so temp_end has
                # little effect on which token is drawn.
                sampled_probs = F.gumbel_softmax(
                    trigger_probs.unsqueeze(0),
                    tau=self.temp_end,
                    hard=True,
                    dim=-1,
                )
                candidates.append(sampled_probs.argmax(dim=-1).squeeze(0))
            candidate_ids = torch.stack(candidates, dim=0)  # (n_candidates, trigger_seq_len)
            candidate_losses = self.model.compute_loss_from_tokens(
                candidate_ids,
                loss_func=self.loss_func,
            )

        final_best_idx = candidate_losses.argmin().item()
        final_best_ids = candidate_ids[final_best_idx]
        final_best_str = tokenizer.decode_trigger(final_best_ids)
        final_best_loss = candidate_losses[final_best_idx].item()

        return OptimizerResult(
            best_loss=final_best_loss,
            best_trigger_str=final_best_str,
            best_trigger_ids=final_best_ids,
            losses=best.losses,
            trigger_strs=best.trigger_strs,
            best_trigger_probs=trigger_probs.detach(),
        )