Source code for tropt.optimizer.gasliteplus_optimizer

from __future__ import annotations
import logging
import time
from typing import Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import (
    DEFAULT_INIT_TRIGGER,
    Targets,
    TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
    BaseModel,
    GradientTokenAccessMixin,
    LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.buffer import TriggerBuffer
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.scheduler import (
    ConstantScheduler,
    LinearScheduler,
    NFlipScheduler,
)
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
from tropt.tracker import BaseTracker

logger = logging.getLogger(__name__)


class GASLITEPlusOptimizer(BaseOptimizer):
    """
    Implements the GASLITE optimization algorithm (Algorithm 1) from the paper:
    "GASLITEing the Retrieval: Exploring Vulnerabilities in Dense Embedding-based Search"
    (https://arxiv.org/abs/2412.20953)

    GASLITEPlus extends the paper's algorithm with (i) multiple random trigger
    initializations, (ii) a trigger buffer, (iii) flipping a bulk of positions at
    once, and (iv) early stopping. Thus, it effectively includes most of the
    enhancements described in Haize's ACG.

    References:
        - GASLITE: https://arxiv.org/abs/2412.20953
        - ACG: https://www.haizelabs.com/blog/making-a-sota-adversarial-attack-on-llms-38x-faster
        - QCG? PAL? RAL? (open note kept from the original author)
    """

    model_requirements = (LossTokenAccessMixin, GradientTokenAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        num_steps: int = 100,
        n_grad: int = 50,
        n_flip: int = 20,
        n_candidates: int = 128,
        token_constraints: Optional[TokenConstraints] = None,
        use_retokenize: bool = True,
        use_random_gradient: bool = False,
        buffer_size: int = 10,
        decline_n_flip_from_step: Optional[int | float] = None,
        early_stopping_patience: Optional[int] = None,
        early_stopping_threshold: float = 0.005,  # relative improvement threshold
        n_bulk_flips: int = 5,
        flip_pos_method: str = "random",  # "random" or "ordered"
        time_limit: Optional[float] = None,
        n_flip_scheduler: Optional[NFlipScheduler] = None,
        **kwargs,
    ):
        """
        Initializes the GASLITE Optimizer.

        Args:
            model (BaseModel): The model to be attacked.
            loss (BaseLoss): The loss function to be optimized.
            tracker (BaseTracker, optional): Tracker forwarded to `BaseOptimizer`.
            seed (int, optional): Random seed for reproducibility.
            num_steps (int): Number of optimization iterations.
            n_grad (int): Number of trigger variations (the trigger itself plus
                random single-token flips) whose gradients are averaged.
                Set to 1 to disable averaging.
            n_flip (int): Number of token positions to greedily optimize per step.
            n_candidates (int): Number of top candidate tokens to evaluate for
                each position.
            token_constraints (TokenConstraints, optional): An object to manage
                token blacklisting. If None, a fresh default `TokenConstraints()`
                is created for this optimizer.
            use_retokenize (bool): Whether to filter candidates that are not
                reversible by the tokenizer.
            use_random_gradient (bool): If True, uses random gradients instead of
                model gradients (for ablation).
            buffer_size (int): Size of the trigger buffer to maintain.
            decline_n_flip_from_step (int | float, optional): If set, linearly
                declines `n_flip` to 1 starting from this step (int) or fraction
                of total steps (float).
            early_stopping_patience (int, optional): If set, enables early
                stopping when no sufficient improvement is seen for this many
                consecutive steps.
            early_stopping_threshold (float): Relative improvement threshold for
                early stopping.
            n_bulk_flips (int): The positions sampled in a step are split into
                (at most) this many chunks; each chunk is flipped with a single
                batched loss evaluation (lower => fewer sequential model calls,
                faster).
            flip_pos_method (str): "ordered" visits the sampled positions in
                ascending order; any other value (e.g. "random") keeps the
                random sampling order.
            time_limit (float, optional): If set, optimization stops at the end
                of the first step after which this many seconds have elapsed.
            n_flip_scheduler (NFlipScheduler, optional): A scheduler object to
                control `n_flip`. If provided, overrides `decline_n_flip_from_step`.
            **kwargs: Accepted for call-site compatibility; not used by this class.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # save params:
        self.num_steps = num_steps
        self.n_grad = n_grad
        self.n_flip = n_flip
        self.n_candidates = n_candidates
        # Build the default here rather than in the signature, so that instances
        # never share one (possibly stateful) TokenConstraints object.
        self.token_constraints = (
            TokenConstraints() if token_constraints is None else token_constraints
        )
        self.use_retokenize = use_retokenize
        self.use_random_gradient = use_random_gradient
        self.buffer_size = buffer_size
        self.decline_n_flip_from_step = decline_n_flip_from_step

        # early stopping params
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold  # relative improvement threshold

        self.n_bulk_flips = n_bulk_flips
        self.flip_pos_method = flip_pos_method
        self.time_limit = time_limit

        # An explicit scheduler wins over `decline_n_flip_from_step`.
        if n_flip_scheduler is not None:
            self.n_flip_scheduler = n_flip_scheduler
        elif decline_n_flip_from_step is not None:
            self.n_flip_scheduler = LinearScheduler(
                initial_n_flip=n_flip,
                total_steps=num_steps,
                decline_start=decline_n_flip_from_step,
            )
        else:
            # default: constant n_flip
            self.n_flip_scheduler = ConstantScheduler(n_flip)

    def _get_trigger_variations(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
    ) -> Int[Tensor, "n_grad trigger_seq_len"]:
        """
        Creates `n_grad` trigger variations, stacked as rows.

        The first row is the original trigger, and each of the remaining rows is
        the original trigger with one random position replaced by a random token
        drawn from `valid_token_ids`.
        """
        trigger_seq_len = len(trigger_ids)
        device = self.model.device

        trigger_vars_ids = trigger_ids.repeat(self.n_grad, 1)  # shape: (n_grad, trigger_seq_len)

        for idx in range(1, self.n_grad):  # (keep the first intact)
            # select a random position and a random token
            pos_to_flip = int(
                torch.randint(0, trigger_seq_len, (1,), device=device).item()
            )
            tok_to_flip_to = int(
                valid_token_ids[
                    torch.randint(0, len(valid_token_ids), (1,), device=device)
                ].item()
            )
            # apply the flip
            trigger_vars_ids[idx, pos_to_flip] = tok_to_flip_to

        return trigger_vars_ids

    def _build_initial_buffer(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        tokenizer,
    ) -> TriggerBuffer:
        """Builds the trigger buffer from the initial trigger plus `buffer_size - 1` random triggers."""
        device = self.model.device
        trigger_seq_len = len(trigger_ids)

        triggers_for_buffer = [trigger_ids]
        for _ in range(self.buffer_size - 1):
            random_trigger_ids = get_printable_random_trigger(
                trigger_seq_len, tokenizer=tokenizer, return_ids=True
            ).to(device)
            triggers_for_buffer.append(random_trigger_ids)

        # Compute losses for initial triggers
        losses = self.model.compute_loss_from_tokens(
            torch.stack(triggers_for_buffer),
            self.loss_func,
        )  # (n_cands,)

        # Index by the number of triggers actually built (not `buffer_size`), so
        # the initial trigger is always kept even if `buffer_size < 1`.
        return TriggerBuffer(
            triggers=triggers_for_buffer,
            losses=[losses[i].item() for i in range(len(triggers_for_buffer))],
        )

    def _estimate_trigger_grad(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        valid_token_ids: Int[Tensor, "n_valid"],
        vocab_size: int,
    ) -> Float[Tensor, "trigger_seq_len vocab_size"]:
        """Returns the loss gradient w.r.t. the trigger tokens, averaged over `n_grad` variations."""
        if self.use_random_gradient:
            # Replace model gradient with random values (ablation)
            return torch.randn(
                (len(trigger_ids), vocab_size), device=self.model.device
            )

        # Compute grad over a list of `n_grad` triggers one-flip away from the current
        trigger_vars = self._get_trigger_variations(trigger_ids, valid_token_ids)
        grads = self.model.compute_grad_from_tokens(
            candidate_trigger_ids=trigger_vars,
            loss_func=self.loss_func,
            normalize_grads=True,
        )  # (n_trigger_vars, trigger_seq_len, vocab_size)
        # Average the gradients to get the final approximation
        return grads.mean(dim=0)

    def _build_bulk_candidates(
        self,
        current_trigger_ids: Int[Tensor, "trigger_seq_len"],
        topk_ids: Int[Tensor, "trigger_seq_len n_candidates"],
        bulk_pos: Int[Tensor, "n_bulk_pos"],
    ) -> Int[Tensor, "n_unique trigger_seq_len"]:
        """Builds the de-duplicated candidate triggers obtained by flipping all positions in `bulk_pos`."""
        n_candidates = self.n_candidates
        n_top = n_candidates // 2
        candidate_triggers = current_trigger_ids.repeat(n_candidates, 1)

        # Inject candidate tokens at all positions in the bulk
        for pos in bulk_pos:
            head_tokens = torch.cat([
                current_trigger_ids[pos].unsqueeze(0),  # keep the "no flip" option
                topk_ids[pos, :n_top],
            ])
            n_more = n_candidates - len(head_tokens)
            if len(bulk_pos) > 1:
                # If the bulk has multiple positions, we add more candidates to
                # increase diversity: sample more token ids, with replacement
                more_cand_indices = torch.randint(
                    high=topk_ids[pos].size(0),
                    size=(n_more,),
                    device=topk_ids.device,
                )
            else:
                # Single position: simply take the rest of the top-k list
                more_cand_indices = torch.arange(
                    n_top, n_top + n_more, device=topk_ids.device
                )
            candidate_triggers[:, pos] = torch.cat(
                [head_tokens, topk_ids[pos, more_cand_indices]]
            )

        # keep only unique candidates
        return torch.unique(candidate_triggers, dim=0)

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Runs the GASLITE+ optimization loop and returns the best trigger found.

        Each step takes the best trigger from the buffer, estimates a per-token
        gradient, picks the top-`n_candidates` replacement tokens per position,
        and greedily flips the sampled positions bulk by bulk, keeping the
        lowest-loss candidate and feeding the best candidates back into the buffer.

        Args:
            templates (TextTemplates): Templates passed to `model.set_inputs_from_tokens`.
            initial_trigger (str, optional): Trigger text the optimization starts from.
            targets (Targets, optional): Targets passed to `model.set_inputs_from_tokens`.

        Returns:
            OptimizerResult: The result produced by the `RunningBest` tracker that
            is updated once per optimization step.
        """
        # Initialization:
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        device = self.model.device

        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(
            initial_trigger
        ).to(device)
        vocab_size = self.model.vocab_size
        blacklist_ids = self.token_constraints.get_blacklist_ids(tokenizer, vocab_size)
        valid_token_ids = self.token_constraints.get_whitelist_ids(
            tokenizer, vocab_size, return_tensor=True
        ).to(device)
        trigger_seq_len = len(trigger_ids)

        best = RunningBest()
        # Monotonic clock: the time limit is an elapsed-time budget and must not
        # be affected by wall-clock adjustments.
        start_time = time.monotonic()

        # Form `buffer_size` initial triggers and create the buffer
        buffer = self._build_initial_buffer(trigger_ids, tokenizer)
        trigger_str = tokenizer.decode_trigger(buffer.get_best_trigger())
        self.log(loss=buffer.get_lowest_loss(), trigger_str=trigger_str)

        # Early-stopping state; the baseline is set by the first completed step.
        best_loss_global: Optional[float] = None
        steps_without_improvement = 0

        for step in self.track_steps(range(self.num_steps), desc="Optimizing with GASLITE..."):
            n_flip = self.n_flip_scheduler.get_n_flip(step)

            # Get the best trigger from the buffer, together with its loss, so
            # the reported loss matches the trigger even if every bulk below is
            # skipped.
            trigger_ids = buffer.get_best_trigger()
            current_loss = buffer.get_lowest_loss()

            # --- Gradient and candidate selection step ---
            trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"]
            trigger_grad = self._estimate_trigger_grad(
                trigger_ids, valid_token_ids, vocab_size
            )
            trigger_grad = -trigger_grad  # we want to minimize the loss

            # Get Top-k Candidates *per position*
            trigger_grad[:, blacklist_ids] = float("-inf")
            topk_ids: Int[Tensor, "trigger_seq_len n_candidates"]
            topk_ids = trigger_grad.topk(self.n_candidates, dim=-1).indices

            # --- Greedy coordinate ascent step ---
            current_trigger_ids = trigger_ids.clone()

            # Sample `n_flip` unique positions to optimize
            sampled_positions = torch.randperm(trigger_seq_len, device=device)[:n_flip]
            if self.flip_pos_method == "ordered":
                sampled_positions, _ = sampled_positions.sort()

            # Perform bulk flips in chunks
            for bulk_pos in torch.chunk(sampled_positions, self.n_bulk_flips):
                candidate_triggers = self._build_bulk_candidates(
                    current_trigger_ids, topk_ids, bulk_pos
                )

                # (Optional) Retokenize filtering
                if self.use_retokenize:
                    candidate_triggers = retokenize_filtering(candidate_triggers, tokenizer)
                    if len(candidate_triggers) == 0:
                        logger.warning(
                            f"Retokenize filtering removed all candidates for pos {bulk_pos.tolist()}. Skipping step."
                        )
                        continue  # Keep `current_trigger_ids` as is for this bulk

                # Compute losses on candidate flips
                losses = self.model.compute_loss_from_tokens(
                    candidate_triggers,
                    self.loss_func,
                    keep_message_dim=True,  # Get per-message loss
                ).mean(dim=0)  # Average over messages -> (n_cands,)

                # Find the best candidate for this bulk
                losses_sorted_indices = torch.argsort(losses)
                best_candidate_idx = losses_sorted_indices[0]

                # Carry the winner over to the next iteration of the greedy (inner) loop
                current_trigger_ids = candidate_triggers[best_candidate_idx].clone()
                current_loss = losses[best_candidate_idx].item()

                # Update the trigger buffer: offer it the `buffer.size` best candidates
                for cand_idx in losses_sorted_indices[: buffer.size]:
                    buffer.add_if_better(
                        candidate_triggers[cand_idx].clone(),
                        losses[cand_idx].item(),
                    )

            # After the inner loop, `current_trigger_ids` is the best trigger for this *entire* step
            trigger_ids = current_trigger_ids
            trigger_str = tokenizer.decode_trigger(trigger_ids)

            # Logging:
            self.log(loss=current_loss, trigger_str=trigger_str)
            best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)

            if self.time_limit is not None:
                if time.monotonic() - start_time > self.time_limit:
                    logger.info(f"Time limit of {self.time_limit}s reached. Stopping optimization.")
                    break

            # (Optional) Early stopping if no improvement
            if self.early_stopping_patience is not None:
                if best_loss_global is None:
                    # First completed step only establishes the baseline; it has
                    # nothing to be compared against, so it is not counted as a
                    # step without improvement.
                    best_loss_global = current_loss
                else:
                    # define the relative improvement
                    denominator = abs(best_loss_global) if best_loss_global != 0 else 1.0
                    relative_improvement = (best_loss_global - current_loss) / denominator

                    # Check if improvement is greater than the relative threshold
                    if relative_improvement > self.early_stopping_threshold:
                        steps_without_improvement = 0
                    else:
                        steps_without_improvement += 1

                    if steps_without_improvement >= self.early_stopping_patience:
                        # `:.2%` already scales by 100, so pass the raw fraction.
                        logger.info(f"Early stopping triggered at step {step+1}. No relative improvement (of > {self.early_stopping_threshold:.2%}) in the last {self.early_stopping_patience} steps.")
                        break

                    # Update the best loss globally
                    best_loss_global = min(best_loss_global, current_loss)

        return best.to_result()