from __future__ import annotations
import logging
import math
from typing import Optional
import torch
from jaxtyping import Float, Int
from torch import Tensor
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
LMBaseModel,
LogitsTokenAccessMixin,
LossTextAccessMixin,
TokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.buffer import TriggerBuffer
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
from tropt.tracker import BaseTracker
logger = logging.getLogger(__name__)
# TODO: add an outer beam search (like BEAST)
class RASLITEPlusOptimizer(BaseOptimizer):
    """
    Implements the RASLITEPlus optimization algorithm, which basically runs GASLITE against a
    black-box model.

    A utility LM (`util_model`) supplies the tokenizer and the per-position token logits that are
    used to propose candidate tokens, together with strategies from GASLITEPlus (trigger buffer,
    early stopping, declining `n_flip`, etc.). The loss itself is always computed on text level
    against the black-box target model (`model`).

    Builds on the paper: "GASLITEing the Retrieval: Exploring Vulnerabilities in Dense
    Embedding-based Search" (https://arxiv.org/abs/2412.20953)
    """

    model_requirements = (LossTextAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        num_steps: int = 100,
        n_logit_samples: Optional[int] = None,
        n_flip: int | float = 20,
        n_candidates: int = 128,
        token_constraints: Optional[TokenConstraints] = None,
        use_retokenize: bool = True,
        util_model: Optional[LMBaseModel] = None,  # for logits calc
        use_random_logits: bool = False,  # for possible ablation
        buffer_size: int = 10,
        decline_n_flip_from_step: Optional[int | float] = None,
        early_stopping_patience: Optional[int] = None,
        early_stopping_threshold: float = 0.005,  # relative improvement threshold
        n_bulk_flips: int = 5,
        flip_pos_method: str = "random",  # "random" or "ordered"
        **kwargs
    ):
        """
        Initializes the RASLITEPlus Optimizer.

        Args:
            model (BaseModel): The target model to be attacked (black-box).
            loss (BaseLoss): The loss function to be optimized.
            tracker (BaseTracker, optional): Tracker forwarded to the base optimizer.
            seed (int, optional): Random seed for reproducibility.
            num_steps (int): Number of optimization iterations.
            n_logit_samples (int, optional): Number of trigger variations (the original plus random
                single-token flips) whose `util_model` logits are averaged. `None` or <= 1 means the
                logits of the current trigger alone are used.
            n_flip (int | float): Number of token positions to greedily optimize per step. A float
                is interpreted as a fraction of the trigger length (rounded up).
            n_candidates (int): Number of top candidate tokens to evaluate for each position.
            token_constraints (TokenConstraints, optional): An object to manage token blacklisting.
                If `None`, a fresh default `TokenConstraints()` is created for this optimizer.
            use_retokenize (bool): Whether to filter candidates that are not reversible by the
                tokenizer. Defaults to True.
            util_model (LMBaseModel, optional): Utility model for tokenization, and also optionally
                for logits calculation (if use_random_logits=False). If `None`, `model` is used,
                which is only allowed when use_random_logits=True.
            use_random_logits (bool): If True, use random logits instead of actual logits from
                `util_model`.
            buffer_size (int): Size of the trigger buffer to maintain (must be >= 1).
            decline_n_flip_from_step (int | float, optional): Step from which `n_flip` declines
                linearly towards 1. A float is interpreted as a fraction of `num_steps`.
            early_stopping_patience (int, optional): Early stopping patience steps.
            early_stopping_threshold (float): Relative improvement threshold for early stopping.
            n_bulk_flips (int): Number of chunks ("bulks") the sampled positions are split into per
                step; all positions of a bulk are flipped together.
            flip_pos_method (str): Method to select positions to flip ("random" or "ordered").
            **kwargs: Accepted for call-site compatibility; not used by this optimizer.

        Raises:
            ValueError: If `buffer_size` < 1, or if no `util_model` is given while
                use_random_logits=False.
            AssertionError: If the chosen util model lacks tokenization or (when needed) logits
                access.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        if buffer_size < 1:
            # The buffer always holds at least the initial trigger; an empty buffer would only
            # fail later (and less clearly) when the best trigger is requested.
            raise ValueError(f"buffer_size must be >= 1, got {buffer_size}.")

        # save params:
        self.num_steps = num_steps
        self.n_logit_samples = n_logit_samples
        self.n_flip = n_flip
        self.n_candidates = n_candidates
        # Built per instance rather than as a default argument, so optimizers never share one
        # constraints object created at import time.
        self.token_constraints = (
            token_constraints if token_constraints is not None else TokenConstraints()
        )
        self.use_retokenize = use_retokenize
        self.use_random_logits = use_random_logits

        if util_model is not None:
            self.util_model = util_model
        else:
            # No util model provided: fall back to the target's tokenizer when
            # only random logits are needed; otherwise we genuinely need a util LM.
            if not self.use_random_logits:
                raise ValueError(
                    "RASLITEPlus requires a util_model with LM-logits access when "
                    "use_random_logits=False. Pass util_model=<an LMBaseModel that "
                    "implements LogitsTokenAccessMixin>, or set use_random_logits=True."
                )
            self.util_model = self.model

        # Ensure tokenization capability on the chosen util_model.
        assert isinstance(self.util_model, TokenAccessMixin), (
            "RASLITEPlus needs tokenization for its search. Neither the target model "
            "nor any util_model implements TokenAccessMixin -- pass util_model=<a model "
            "with a tokenizer> (e.g. EncoderOpenAIModel, LMHFModel)."
        )
        # Ensure logits access capability (if not using random logits).
        if not self.use_random_logits:
            assert isinstance(self.util_model, LMBaseModel) and isinstance(
                self.util_model, LogitsTokenAccessMixin
            ), "RASLITEPlus requires util_model to be LM with token logits access"

        self.buffer_size = buffer_size
        self.decline_n_flip_from_step = decline_n_flip_from_step
        # early stopping params
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.n_bulk_flips = n_bulk_flips
        self.flip_pos_method = flip_pos_method

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Runs the RASLITEPlus search and returns the best trigger found.

        The search (candidates, buffer) operates on token ids of `util_model`; every loss is
        computed by decoding candidates to text and querying the target `model`.

        Args:
            templates (TextTemplates): Templates passed to both models' input setup.
            initial_trigger (str, optional): Trigger text to start from. `None` falls back to
                `DEFAULT_INIT_TRIGGER`.
            targets (Targets, optional): Targets passed to both models' input setup.

        Returns:
            OptimizerResult: The result built from the lowest-loss trigger recorded, including the
            initially evaluated buffer.
        """
        if initial_trigger is None:
            initial_trigger = DEFAULT_INIT_TRIGGER

        # Initialization: prepare inputs for both models.
        self.model.set_inputs_from_texts(templates=templates, targets=targets)
        self.util_model.set_inputs_from_tokens(templates=templates, targets=targets)

        device = self.util_model.device
        util_tokenizer = self.util_model.tokenizer
        util_trigger_ids: Int[Tensor, "trigger_seq_len"] = util_tokenizer.encode_trigger(
            initial_trigger
        ).to(device)
        util_vocab_size = self.util_model.vocab_size
        util_blacklist_ids = self.token_constraints.get_blacklist_ids(
            util_tokenizer, util_vocab_size
        )
        trigger_seq_len = len(util_trigger_ids)

        # Resolve a fractional n_flip once; the decline schedule scales this resolved count.
        initial_n_flip = self.n_flip
        if isinstance(initial_n_flip, float):
            initial_n_flip = math.ceil(initial_n_flip * trigger_seq_len)
        n_flip = initial_n_flip

        # Form buffer_size initial triggers: the given one plus random printable ones.
        triggers_for_buffer = [util_trigger_ids]
        for _ in range(self.buffer_size - 1):
            random_trigger_ids = get_printable_random_trigger(
                trigger_seq_len, tokenizer=util_tokenizer, return_ids=True
            ).to(device)
            triggers_for_buffer.append(random_trigger_ids)

        # Compute losses for initial triggers (requires text conversion)
        trigger_strs_buffer = [
            util_tokenizer.decode_trigger(t_ids) for t_ids in triggers_for_buffer
        ]
        losses = self.model.compute_loss_from_texts(
            trigger_strs_buffer,
            self.loss_func,
        )  # (n_cands,)

        # Create the buffer:
        buffer = TriggerBuffer(
            triggers=list(triggers_for_buffer),
            losses=[losses[i].item() for i in range(len(triggers_for_buffer))],
        )

        # Record the best initial trigger, so it is never lost (e.g. with num_steps=0, or when
        # retokenize filtering drops the "no flip" candidate).
        best = RunningBest()
        initial_best_ids = buffer.get_best_trigger()
        initial_best_loss = buffer.get_lowest_loss()
        trigger_str = util_tokenizer.decode_trigger(initial_best_ids)
        self.log(loss=initial_best_loss, trigger_str=trigger_str)
        best.update(loss=initial_best_loss, trigger_ids=initial_best_ids, trigger_str=trigger_str)

        # Early-stopping state (only used when a patience is configured).
        best_loss_global = None
        steps_without_improvement = 0

        for step in self.track_steps(
            range(self.num_steps), desc="Optimizing with RASLITEPlus..."
        ):
            # Start from the best trigger in the buffer; its loss is the step's baseline, so a step
            # in which every bulk is skipped still reports a loss matching its trigger.
            util_trigger_ids = buffer.get_best_trigger()
            current_loss = buffer.get_lowest_loss()

            # --- Candidate selection step (logit-based) ---
            token_scores = self._compute_token_scores(util_trigger_ids, util_vocab_size, device)
            # Get Top-k Candidates *per position*
            token_scores[:, util_blacklist_ids] = float("-inf")
            topk_ids: Int[Tensor, "trigger_seq_len n_candidates"]
            topk_ids = token_scores.topk(self.n_candidates, dim=-1).indices

            # --- Greedy coordinate ascent step ---
            current_trigger_ids = util_trigger_ids.clone()
            # Sample `n_flip` unique positions to optimize
            sampled_positions = torch.randperm(trigger_seq_len, device=device)[:n_flip]
            if self.flip_pos_method == "ordered":
                sampled_positions, _ = sampled_positions.sort()

            # Perform bulk flips in chunks
            for bulk_pos in torch.chunk(sampled_positions, self.n_bulk_flips):
                candidate_triggers = self._build_bulk_candidates(
                    current_trigger_ids, bulk_pos, topk_ids
                )

                # (Optional) Retokenize filtering
                if self.use_retokenize:
                    candidate_triggers = retokenize_filtering(candidate_triggers, util_tokenizer)
                    if len(candidate_triggers) == 0:
                        logger.warning(
                            "Retokenize filtering removed all candidates for positions %s. "
                            "Skipping this bulk.",
                            bulk_pos.tolist(),
                        )
                        continue

                # Compute losses on candidate flips (on TARGET model via text)
                candidate_strs = util_tokenizer.decode_triggers(candidate_triggers)
                losses = self.model.compute_loss_from_texts(
                    candidate_strs,
                    self.loss_func,
                )  # (n_cands,)

                # Adopt the lowest-loss candidate of this bulk as the new current trigger
                losses_sorted_indices = torch.argsort(losses)
                best_candidate_idx = losses_sorted_indices[0]
                current_trigger_ids = candidate_triggers[best_candidate_idx].clone()
                current_loss = losses[best_candidate_idx].item()

                # Offer the top candidates (at most `buffer.size` of them) to the buffer
                for cand_idx in losses_sorted_indices[: buffer.size]:
                    buffer.add_if_better(
                        candidate_triggers[cand_idx].clone(),
                        losses[cand_idx].item(),
                    )

            # After the inner loop, `current_trigger_ids` is the trigger selected by this step
            util_trigger_ids = current_trigger_ids
            trigger_str = util_tokenizer.decode_trigger(util_trigger_ids)

            # (Optional) update n_flip if needed (linear scheduling); takes effect next step
            n_flip = self._scheduled_n_flip(step, initial_n_flip, n_flip)

            # Logging:
            self.log(loss=current_loss, trigger_str=trigger_str)
            best.update(loss=current_loss, trigger_ids=util_trigger_ids, trigger_str=trigger_str)

            # (Optional) Early stopping
            if self.early_stopping_patience is not None:
                if best_loss_global is None:
                    # First step: the reference is the step's own loss.
                    best_loss_global = current_loss
                denominator = abs(best_loss_global) if best_loss_global != 0 else 1.0
                relative_improvement = (best_loss_global - current_loss) / denominator
                if relative_improvement > self.early_stopping_threshold:
                    steps_without_improvement = 0
                else:
                    steps_without_improvement += 1
                if steps_without_improvement >= self.early_stopping_patience:
                    # `.2%` already scales by 100, so the raw threshold is formatted directly.
                    logger.info(
                        f"Early stopping triggered at step {step+1}. No relative improvement "
                        f"(of > {self.early_stopping_threshold:.2%}) in the last "
                        f"{self.early_stopping_patience} steps."
                    )
                    break
                best_loss_global = min(best_loss_global, current_loss)

        # Return the best trigger found
        result = best.to_result()
        logger.info(f"Best loss: {result.best_loss} | Best trigger: {result.best_trigger_str}")
        return result

    def _compute_token_scores(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        vocab_size: int,
        device: torch.device,
    ) -> Float[Tensor, "trigger_seq_len vocab_size"]:
        """
        Returns the per-position token scores used to rank candidate tokens.

        Uses uniform random scores when `use_random_logits` is set; otherwise `util_model` logits,
        averaged over trigger variations when `n_logit_samples` > 1.
        """
        if self.use_random_logits:
            return torch.rand(len(trigger_ids), vocab_size, device=device)

        if self.n_logit_samples is not None and self.n_logit_samples > 1:
            # Average logits over variations
            trigger_vars = self._get_trigger_variations(trigger_ids, vocab_size, device=device)
            logits = self.util_model.compute_logits_from_tokens(
                trigger_vars,
                return_trigger_logits_only=True,
                keep_message_dim=False,
            )  # (n_vars, seq_len, vocab_size)
            return logits.mean(dim=0)

        return self.util_model.compute_logits_from_tokens(
            trigger_ids.unsqueeze(0),
            return_trigger_logits_only=True,
            keep_message_dim=False,
        ).squeeze(0)

    def _build_bulk_candidates(
        self,
        current_trigger_ids: Int[Tensor, "trigger_seq_len"],
        bulk_pos: Int[Tensor, "n_bulk_pos"],
        topk_ids: Int[Tensor, "trigger_seq_len n_candidates"],
    ) -> Int[Tensor, "n_unique_candidates trigger_seq_len"]:
        """
        Builds the unique candidate triggers obtained by flipping all positions in `bulk_pos`.

        For every position, the candidate column is: the current token (the "no flip" option),
        the top `n_candidates // 2` tokens, and then either the next-ranked tokens (single-position
        bulk) or tokens sampled with replacement from the top-k (multi-position bulk).
        """
        n_candidates = self.n_candidates
        n_top = n_candidates // 2
        candidate_triggers = current_trigger_ids.repeat(n_candidates, 1)

        # Inject candidate tokens at all positions in the bulk
        for pos in bulk_pos:
            column = torch.cat([
                current_trigger_ids[pos].unsqueeze(0),  # keep the "no flip" option
                topk_ids[pos, :n_top],
            ])
            n_extra = n_candidates - len(column)
            if len(bulk_pos) > 1:
                # sample more token ids, with replacements if bulk is large
                extra_indices = torch.randint(
                    high=topk_ids[pos].size(0),
                    size=(n_extra,),
                    device=topk_ids.device,
                )
            else:
                extra_indices = torch.arange(n_top, n_top + n_extra, device=topk_ids.device)
            candidate_triggers[:, pos] = torch.cat([column, topk_ids[pos, extra_indices]])

        # keep only unique candidates
        return torch.unique(candidate_triggers, dim=0)

    def _scheduled_n_flip(self, step: int, initial_n_flip: int, current_n_flip: int) -> int:
        """
        Returns the `n_flip` to use after `step`, declining linearly from `initial_n_flip` to 1.

        `current_n_flip` is returned unchanged when no decline is configured, the decline has not
        started yet, or the decline window is empty.
        """
        if self.decline_n_flip_from_step is None:
            return current_n_flip

        # Determine start step (a float is a fraction of the total number of steps)
        if isinstance(self.decline_n_flip_from_step, float):
            decline_step = int(self.num_steps * self.decline_n_flip_from_step)
        else:
            decline_step = int(self.decline_n_flip_from_step)

        decline_duration = self.num_steps - decline_step
        if step < decline_step or decline_duration <= 0:
            return current_n_flip

        ratio = (self.num_steps - step) / decline_duration
        return max(1, math.ceil(initial_n_flip * ratio))

    def _get_trigger_variations(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        vocab_size: int,
        device: torch.device,
    ) -> Int[Tensor, "n_logit_samples trigger_seq_len"]:
        """
        Creates a tensor of `n_logit_samples` trigger variations. The first row is the
        original trigger, and each other row is the original with one random single-token flip.
        """
        trigger_seq_len = len(trigger_ids)
        trigger_vars_ids = trigger_ids.repeat(
            self.n_logit_samples, 1
        )  # shape: (n_logit_samples, trigger_seq_len)
        for idx in range(1, self.n_logit_samples):  # (keep the first intact)
            # select a random position and a random token
            pos_to_flip = int(torch.randint(0, trigger_seq_len, (1,), device=device).item())
            tok_to_flip_to = int(torch.randint(0, vocab_size, (1,), device=device).item())
            # apply the flip
            trigger_vars_ids[idx, pos_to_flip] = tok_to_flip_to
        return trigger_vars_ids