# Source code for tropt.optimizer.rasliteplus_optimizer

from __future__ import annotations
import logging
import math
from typing import Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import (
    DEFAULT_INIT_TRIGGER,
    Targets,
    TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
    BaseModel,
    LMBaseModel,
    LogitsTokenAccessMixin,
    LossTextAccessMixin,
    TokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.buffer import TriggerBuffer
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
from tropt.tracker import BaseTracker

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# TODO: add an outer beam search (as in BEAST).


class RASLITEPlusOptimizer(BaseOptimizer):
    """
    Implements the RASLITEPlus optimization algorithm, which essentially runs GASLITE
    against a black-box model.

    A utility LM (`util_model`) provides the tokenizer and, unless
    `use_random_logits=True`, per-position token logits that are used to shortlist
    candidate substitutions. Strategies from GASLITEPlus are reused (trigger buffer,
    early stopping, declining n_flip, bulk flips). All loss computations are done at
    the text level against the black-box target model.

    Builds on the paper: "GASLITEing the Retrieval: Exploring Vulnerabilities in
    Dense Embedding-based Search" (https://arxiv.org/abs/2412.20953)
    """

    model_requirements = (LossTextAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        num_steps: int = 100,
        n_logit_samples: Optional[int] = None,
        n_flip: int | float = 20,
        n_candidates: int = 128,
        token_constraints: Optional[TokenConstraints] = None,
        use_retokenize: bool = True,
        util_model: Optional[LMBaseModel] = None,  # for logits calc
        use_random_logits: bool = False,  # for possible ablation
        buffer_size: int = 10,
        decline_n_flip_from_step: Optional[int | float] = None,
        early_stopping_patience: Optional[int] = None,
        early_stopping_threshold: float = 0.005,  # relative improvement threshold
        n_bulk_flips: int = 5,
        flip_pos_method: str = "random",  # "random" or "ordered"
        **kwargs,
    ):
        """
        Initializes the RASLITEPlus Optimizer.

        Args:
            model (BaseModel): The target model to be attacked (black-box).
            loss (BaseLoss): The loss function to be optimized.
            tracker (BaseTracker, optional): Tracker passed through to BaseOptimizer.
            seed (int, optional): Random seed for reproducibility.
            num_steps (int): Number of optimization iterations.
            n_logit_samples (int, optional): Number of trigger variations (the
                original plus random single-token flips) whose util-model logits are
                averaged. Averaging only happens when this is > 1.
            n_flip (int | float): Number of token positions to optimize per step. A
                float is read as a fraction of the trigger length (rounded up).
            n_candidates (int): Number of top-scoring candidate tokens kept per
                position, and the number of candidate triggers built per bulk (before
                deduplication).
            token_constraints (TokenConstraints, optional): Object that manages token
                blacklisting. If None, a fresh `TokenConstraints()` is created for this
                optimizer.
            use_retokenize (bool): Whether to filter out candidates that do not survive
                a decode/encode round trip through the util tokenizer. Defaults to
                True. Disabling it is reasonable here, because the target model is only
                ever evaluated on strings and there are no token-level gradients.
            util_model (LMBaseModel, optional): Utility model for tokenization and, if
                use_random_logits=False, for logits calculation. If None, this is only
                allowed with use_random_logits=True, in which case `model` itself is
                used and must provide tokenization.
            use_random_logits (bool): If True, use random scores instead of logits
                from `util_model`.
            buffer_size (int): Size of the trigger buffer to maintain.
            decline_n_flip_from_step (int | float, optional): Step from which n_flip
                declines linearly toward 1 by the final step. A float is read as a
                fraction of num_steps.
            early_stopping_patience (int, optional): Number of consecutive steps
                without sufficient relative improvement before stopping.
            early_stopping_threshold (float): Relative improvement threshold for early
                stopping (e.g. 0.005 == 0.5%).
            n_bulk_flips (int): Number of chunks the sampled positions are split into
                per step. The positions in a chunk are substituted jointly and
                evaluated as one batch.
            flip_pos_method (str): How to order the sampled positions ("random" or
                "ordered").
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            ValueError: If util_model is None while use_random_logits=False, or if
                flip_pos_method is not "random"/"ordered".
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        # save params:
        self.num_steps = num_steps
        self.n_logit_samples = n_logit_samples
        self.n_flip = n_flip
        self.n_candidates = n_candidates
        # Avoid a shared mutable default instance across optimizers.
        self.token_constraints = (
            token_constraints if token_constraints is not None else TokenConstraints()
        )
        self.use_retokenize = use_retokenize
        self.use_random_logits = use_random_logits

        if util_model is not None:
            self.util_model = util_model
        else:
            # No util model provided: fall back to the target's tokenizer when
            # only random logits are needed; otherwise we genuinely need a util LM.
            if not self.use_random_logits:
                raise ValueError(
                    "RASLITEPlus requires a util_model with LM-logits access when "
                    "use_random_logits=False. Pass util_model=<an LMBaseModel that "
                    "implements LogitsTokenAccessMixin>, or set use_random_logits=True."
                )
            self.util_model = self.model

        # Ensure tokenization capability on the chosen util_model.
        assert isinstance(self.util_model, TokenAccessMixin), (
            "RASLITEPlus needs tokenization for its search. Neither the target model "
            "nor any util_model implements TokenAccessMixin -- pass util_model=<a model "
            "with a tokenizer> (e.g. EncoderOpenAIModel, LMHFModel)."
        )
        # Ensure logits access capability (if not using random logits).
        if not self.use_random_logits:
            assert isinstance(self.util_model, LMBaseModel) and isinstance(
                self.util_model, LogitsTokenAccessMixin
            ), "RASLITEPlus requires util_model to be LM with token logits access"

        if flip_pos_method not in ("random", "ordered"):
            raise ValueError(
                f"flip_pos_method must be 'random' or 'ordered', got {flip_pos_method!r}"
            )

        self.buffer_size = buffer_size
        self.decline_n_flip_from_step = decline_n_flip_from_step
        # early stopping params
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.n_bulk_flips = n_bulk_flips
        self.flip_pos_method = flip_pos_method

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Optimize a trigger string against the black-box target model.

        Candidate generation operates in the util_model's token space; every candidate
        is decoded to text and scored with `self.model.compute_loss_from_texts`.

        Args:
            templates: Templates the trigger is inserted into (passed to both models).
            initial_trigger: Starting trigger text; it seeds the buffer together with
                `buffer_size - 1` random printable triggers of the same token length.
            targets: Optional targets passed to both models.

        Returns:
            OptimizerResult produced by `RunningBest.to_result()`.
        """
        # Initialization:
        # We prepare inputs for both models.
        # The optimization (candidates, buffer) operates on `util_trigger_ids` (token space of util_model).
        # The assessment operates on text-level via `model` (text space of target model).
        self.model.set_inputs_from_texts(templates=templates, targets=targets)
        self.util_model.set_inputs_from_tokens(templates=templates, targets=targets)

        util_tokenizer = self.util_model.tokenizer
        device = self.util_model.device
        util_trigger_ids: Int[Tensor, "trigger_seq_len"] = util_tokenizer.encode_trigger(
            initial_trigger
        ).to(device)
        util_vocab_size = self.util_model.vocab_size
        util_blacklist_ids = self.token_constraints.get_blacklist_ids(
            util_tokenizer, util_vocab_size
        )
        trigger_seq_len = len(util_trigger_ids)

        # Resolve a fractional n_flip into an absolute count once; the decline
        # schedule scales this resolved value (not the raw, possibly fractional one).
        initial_n_flip = self.n_flip
        if isinstance(initial_n_flip, float):
            initial_n_flip = math.ceil(initial_n_flip * trigger_seq_len)
        n_flip = initial_n_flip

        # The decline start step does not change across iterations; compute it once.
        decline_step: Optional[int] = None
        if self.decline_n_flip_from_step is not None:
            if isinstance(self.decline_n_flip_from_step, float):
                decline_step = int(self.num_steps * self.decline_n_flip_from_step)
            else:
                decline_step = int(self.decline_n_flip_from_step)

        best = RunningBest()

        # Form buffer_size initial triggers: the given one plus random printable ones.
        triggers_for_buffer = [util_trigger_ids]
        for _ in range(self.buffer_size - 1):
            random_trigger_ids = get_printable_random_trigger(
                trigger_seq_len, tokenizer=util_tokenizer, return_ids=True
            ).to(device)
            triggers_for_buffer.append(random_trigger_ids)

        # Compute losses for initial triggers (requires text conversion).
        trigger_strs_buffer = [
            util_tokenizer.decode_trigger(t_ids) for t_ids in triggers_for_buffer
        ]
        losses = self.model.compute_loss_from_texts(
            trigger_strs_buffer,
            self.loss_func,
        )  # (n_cands,)

        buffer = TriggerBuffer(
            triggers=[triggers_for_buffer[i] for i in range(self.buffer_size)],
            losses=[losses[i].item() for i in range(self.buffer_size)],
        )
        self.log(
            loss=buffer.get_lowest_loss(),
            trigger_str=util_tokenizer.decode_trigger(buffer.get_best_trigger()),
        )

        # Early-stopping state. The baseline is the initial buffer loss, so step 0 is
        # judged against the starting point instead of always counting as "no improvement".
        best_loss_global = buffer.get_lowest_loss()
        steps_without_improvement = 0

        for step in self.track_steps(range(self.num_steps), desc="Optimizing with RASLITEPlus..."):
            # Start from the buffer's best trigger and keep its loss paired with it, so
            # that a step whose bulks are all filtered out reports a matching loss.
            util_trigger_ids = buffer.get_best_trigger()
            current_loss = buffer.get_lowest_loss()

            # --- Candidate selection step (logit-based) ---
            token_scores = self._compute_token_scores(util_trigger_ids, util_vocab_size)
            # Get Top-k Candidates *per position*
            token_scores[:, util_blacklist_ids] = float("-inf")
            topk_ids: Int[Tensor, "trigger_seq_len n_candidates"]
            topk_ids = token_scores.topk(self.n_candidates, dim=-1).indices

            # --- Greedy coordinate ascent step ---
            current_trigger_ids = util_trigger_ids.clone()
            # Sample `n_flip` unique positions to optimize.
            sampled_positions = torch.randperm(trigger_seq_len, device=device)[:n_flip]
            if self.flip_pos_method == "ordered":
                sampled_positions, _ = sampled_positions.sort()

            # Perform bulk flips in chunks.
            for bulk_pos in torch.chunk(sampled_positions, self.n_bulk_flips):
                candidate_triggers = self._build_bulk_candidates(
                    current_trigger_ids, topk_ids, bulk_pos
                )
                # keep only unique candidates
                candidate_triggers = torch.unique(candidate_triggers, dim=0)

                # (Optional) Retokenize filtering
                if self.use_retokenize:
                    candidate_triggers = retokenize_filtering(
                        candidate_triggers, util_tokenizer
                    )
                if len(candidate_triggers) == 0:
                    logger.warning(
                        "Retokenize filtering removed all candidates for bulk positions %s. "
                        "Skipping this bulk.",
                        bulk_pos.tolist(),
                    )
                    continue

                # Compute losses on candidate flips (on TARGET model via text).
                candidate_strs = util_tokenizer.decode_triggers(candidate_triggers)
                losses = self.model.compute_loss_from_texts(
                    candidate_strs,
                    self.loss_func,
                )  # (n_cands,)

                losses_sorted_indices = torch.argsort(losses)
                best_candidate_idx = losses_sorted_indices[0]
                current_trigger_ids = candidate_triggers[best_candidate_idx].clone()
                current_loss = losses[best_candidate_idx].item()

                # Offer the best-ranked candidates to the buffer.
                for j in range(min(buffer.size, len(losses))):
                    cand_idx = losses_sorted_indices[j]
                    buffer.add_if_better(
                        candidate_triggers[cand_idx].clone(),
                        losses[cand_idx].item(),
                    )

            # After the bulk loop, `current_trigger_ids` is the best trigger for this step.
            util_trigger_ids = current_trigger_ids
            trigger_str = util_tokenizer.decode_trigger(util_trigger_ids)

            # (Optional) linearly decline n_flip toward 1 by the final step.
            if decline_step is not None and step >= decline_step:
                decline_duration = self.num_steps - decline_step
                if decline_duration > 0:
                    ratio = (self.num_steps - step) / decline_duration
                    n_flip = max(1, math.ceil(initial_n_flip * ratio))

            # Logging:
            self.log(loss=current_loss, trigger_str=trigger_str)
            best.update(loss=current_loss, trigger_ids=util_trigger_ids, trigger_str=trigger_str)

            # (Optional) Early stopping
            if self.early_stopping_patience is not None:
                denominator = abs(best_loss_global) if best_loss_global != 0 else 1.0
                relative_improvement = (best_loss_global - current_loss) / denominator
                if relative_improvement > self.early_stopping_threshold:
                    steps_without_improvement = 0
                else:
                    steps_without_improvement += 1
                if steps_without_improvement >= self.early_stopping_patience:
                    # `:.2%` already scales by 100 (0.005 -> "0.50%").
                    logger.info(
                        f"Early stopping triggered at step {step + 1}. No relative improvement "
                        f"(of > {self.early_stopping_threshold:.2%}) in the last "
                        f"{self.early_stopping_patience} steps."
                    )
                    break
                best_loss_global = min(best_loss_global, current_loss)

        # Return the best trigger found
        result = best.to_result()
        logger.info(f"Best loss: {result.best_loss} | Best trigger: {result.best_trigger_str}")
        return result

    def _compute_token_scores(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        vocab_size: int,
    ) -> Float[Tensor, "trigger_seq_len vocab_size"]:
        """Score every (position, token) pair; higher scores are shortlisted first."""
        device = self.util_model.device
        if self.use_random_logits:
            # Random scores (ablation).
            return torch.rand(len(trigger_ids), vocab_size, device=device)
        if self.n_logit_samples is not None and self.n_logit_samples > 1:
            # Average logits over random single-token variations of the trigger.
            trigger_vars = self._get_trigger_variations(trigger_ids, vocab_size, device=device)
            logits = self.util_model.compute_logits_from_tokens(
                trigger_vars,
                return_trigger_logits_only=True,
                keep_message_dim=False,
            )  # (n_vars, seq_len, vocab_size)
            return logits.mean(dim=0)
        return self.util_model.compute_logits_from_tokens(
            trigger_ids.unsqueeze(0),
            return_trigger_logits_only=True,
            keep_message_dim=False,
        ).squeeze(0)

    def _build_bulk_candidates(
        self,
        current_trigger_ids: Int[Tensor, "trigger_seq_len"],
        topk_ids: Int[Tensor, "trigger_seq_len n_candidates"],
        bulk_pos: Int[Tensor, "n_pos"],
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """Build `n_candidates` trigger copies with every position in `bulk_pos` substituted."""
        n_cands = self.n_candidates
        n_top = n_cands // 2
        candidate_triggers = current_trigger_ids.repeat(n_cands, 1)
        for pos in bulk_pos:
            # Row 0 keeps the current token (the "no flip" option), followed by the
            # top half of the shortlisted tokens for this position.
            all_candidate_tokens = torch.cat([
                current_trigger_ids[pos].unsqueeze(0),
                topk_ids[pos, :n_top],
            ])
            n_more = n_cands - len(all_candidate_tokens)
            if len(bulk_pos) > 1:
                # Multi-position bulk: fill the rest by sampling shortlisted tokens with
                # replacement, so rows mix substitutions across positions.
                more_cand_indices = torch.randint(
                    high=topk_ids[pos].size(0),
                    size=(n_more,),
                    device=topk_ids.device,
                )
            else:
                # Single position: continue down the shortlist in rank order.
                more_cand_indices = torch.arange(
                    n_top, n_top + n_more, device=topk_ids.device
                )
            all_candidate_tokens = torch.cat([
                all_candidate_tokens,
                topk_ids[pos, more_cand_indices],
            ])
            candidate_triggers[:, pos] = all_candidate_tokens
        return candidate_triggers

    def _get_trigger_variations(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        vocab_size: int,
        device: torch.device,
    ) -> Int[Tensor, "n_logit_samples trigger_seq_len"]:
        """
        Create `n_logit_samples` trigger variations.

        Row 0 is the original trigger. Each other row has one uniformly chosen
        position replaced by a uniformly chosen token id in [0, vocab_size).
        """
        trigger_seq_len = len(trigger_ids)
        trigger_vars_ids = trigger_ids.repeat(
            self.n_logit_samples, 1
        )  # shape: (n_logit_samples, trigger_seq_len)
        for idx in range(1, self.n_logit_samples):  # (keep the first intact)
            # select a random position and a random token
            pos_to_flip = int(torch.randint(0, trigger_seq_len, (1,), device=device).item())
            tok_to_flip_to = int(torch.randint(0, vocab_size, (1,), device=device).item())
            # apply the flip
            trigger_vars_ids[idx, pos_to_flip] = tok_to_flip_to
        return trigger_vars_ids