from __future__ import annotations
import logging
from typing import Optional
import torch
from tropt.common import Targets, TextTemplates
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
LMBaseModel,
LogitsTokenAccessMixin,
)
from tropt.model.model_mixins import LossTextAccessMixin
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.tracker import BaseTracker
logger = logging.getLogger(__name__)
[docs]
class BeamSearchOptimizer(BaseOptimizer):
    """
    An LM beam search-based optimizer.

    The general idea is to sample tokens while generating from a util LM, and
    steer the generation towards the desired objective(s) on the target model.

    Combines the implementations of BEAST and AdvDecoding optimizers:
    - BEAST optimizer: https://arxiv.org/abs/2402.15570
    - AdvDecoding optimizer: https://arxiv.org/abs/2410.02163
    """

    # Mixins the attacked model is expected to provide; the loss is queried
    # through text-level access (`compute_loss_from_texts`).
    # NOTE(review): presumably enforced by `BaseOptimizer` -- confirm.
    model_requirements = (LossTextAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        util_lm: Optional[LMBaseModel] = None,  # if None, use the same as `model`
        util_lm_prefix: Optional[str] = None,
        num_steps: int = 40,
        beam_size: int = 15,
        branching_factor: int = 15,
        top_k: Optional[int] = None,
        temperature: float = 1.0,
        token_constraints: Optional[TokenConstraints] = None,
    ):
        """
        Initializes the BeamSearch Optimizer.

        Args:
            model (BaseModel): The model to be attacked.
            loss (BaseLoss): The loss function to be optimized.
            tracker (BaseTracker, optional): Tracker forwarded to `BaseOptimizer`.
            seed (int, optional): Random seed for reproducibility.
            util_lm (LMBaseModel, optional): Utility LM for generating candidates.
                If None, internally uses a shallow copy of `model` (as in original BEAST paper);
                in that case `model` must itself be an `LMBaseModel`.
            util_lm_prefix (str, optional): A prefix to seed the util LM for next-token computation of the trigger.
                If None, defaults to the same `templates` provided to `optimize_trigger()`.
            num_steps (int): Number of optimization iterations (L in paper); also represents the length of the crafted trigger. Default: 40
            beam_size (int): Number of beams to maintain (k1 in paper). Default: 15
            branching_factor (int): Number of candidates (=tokens) per beam (k2 in paper). Default: 15
            top_k (int, optional): Optional top-k filtering before multinomial sampling.
                If None, samples from full distribution (as in original BEAST paper).
                When smaller than `beam_size` / `branching_factor`, fewer beams /
                candidates per beam are sampled.
            temperature (float): Sampling temperature. Default: 1.0 (as in paper)
            token_constraints (TokenConstraints, optional): An object to manage token blacklisting.
                If None, a fresh default `TokenConstraints()` is created for this optimizer.

        Raises:
            ValueError: If `beam_size`, `branching_factor` or `top_k` is smaller
                than 1, or if `temperature` is not strictly positive.
            AssertionError: If the (explicit or derived) util LM is not an
                `LMBaseModel` with `LogitsTokenAccessMixin`.
        """
        # Fail early on values that would otherwise surface as opaque torch
        # errors (or silent inf/NaN probabilities) in the middle of a run.
        if beam_size < 1:
            raise ValueError(f"beam_size must be >= 1, got {beam_size}")
        if branching_factor < 1:
            raise ValueError(f"branching_factor must be >= 1, got {branching_factor}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be None or >= 1, got {top_k}")
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # define the util LM model (defaults to the attacked model)
        if util_lm is None:
            # shallow copies the model object:
            # (we do this so each model will have its own state, specifically for input-management)
            import copy

            assert isinstance(
                model, LMBaseModel
            ), "model must be an LMBaseModel when no util_lm is provided"
            util_lm = copy.copy(model)
        self.util_lm = util_lm
        assert isinstance(self.util_lm, LMBaseModel) and isinstance(
            self.util_lm, LogitsTokenAccessMixin
        ), "BEAST requires util_lm to be LM with token logits access"

        self.util_lm_prefix = util_lm_prefix
        self.num_steps = num_steps
        self.beam_size = beam_size
        self.branching_factor = branching_factor
        self.top_k = top_k
        self.temperature = temperature
        # A default instance built in the signature would be created once and
        # shared by every optimizer; build a fresh one per instance instead.
        self.token_constraints = (
            token_constraints if token_constraints is not None else TokenConstraints()
        )

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = None,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Optimize the trigger using BEAST algorithm.

        Args:
            templates (TextTemplates): List of text templates to optimize the trigger against.
            initial_trigger (Optional[str], optional): Accepted but not used by
                this optimizer; the search always starts from an empty trigger.
            targets (Optional[Targets], optional): Target values for the loss function.

        Returns:
            OptimizerResult: The result produced by `RunningBest.to_result()`
            after all steps have been tracked.

        Implementation notes:
            - We use the auxiliary LM (`util_lm`) to sample candidate tokens for the trigger.
              (Note that in the original BEAST it was the same as the attacked LM; other attacks use a separate utility LM)
            - Then, we evaluate the candidate triggers on the targeted model (`model`) to compute the losses.
            - This loss evaluation against the target model is done in a black-box manner using text-level
              access (i.e., we query the model with the full text including the decoded candidate triggers),
              to enable the attack of fully black-box models. This implementation always uses the
              text-level path (`compute_loss_from_texts`).
            - The first trigger token is sampled without being scored; scoring only happens inside
              the `num_steps - 1` loop iterations.
              NOTE(review): with `num_steps <= 1` the loop body never runs, so the running best is
              never updated -- confirm what `RunningBest.to_result()` returns in that case.
        """
        # Prepare inputs for both target model and util LM
        util_templates = (
            [self.util_lm_prefix] * len(templates)
            if self.util_lm_prefix is not None
            else templates
        )
        self.util_lm.set_inputs_from_tokens(templates=util_templates, targets=targets)
        self.model.set_inputs_from_texts(templates=templates, targets=targets)

        util_tokenizer = self.util_lm.tokenizer
        util_blacklist_ids = self.token_constraints.get_blacklist_ids(
            util_tokenizer, self.util_lm.vocab_size
        )

        # starts with an empty trigger; it is consumed by the util LM, so prefer
        # the util LM's device and fall back to the attacked model's device.
        util_device = getattr(self.util_lm, "device", self.model.device)
        util_trigger_ids = torch.zeros((1, 0), dtype=torch.long, device=util_device)

        # Initialize by sampling beam_size diverse starting tokens
        # (BEAST: Algorithm 1 in paper; lines 2-7)
        # Get logits for first token position
        initial_logits = self.util_lm.compute_logits_from_tokens(
            util_trigger_ids,  # empty trigger
            return_after_trigger_logits_only=True,
        ).squeeze(1)  # (1, vocab_size)

        # Zero out disallowed tokens and renormalize before sampling
        initial_probs = torch.softmax(initial_logits / self.temperature, dim=-1)
        initial_probs[:, util_blacklist_ids] = 0
        initial_probs = initial_probs / initial_probs.sum(dim=-1, keepdim=True)

        # Sample beam_size different tokens using multinomial sampling (without replacement)
        initial_tokens = self._sample_multinomial(
            initial_probs,
            return_tokens=self.beam_size,
            top_k=self.top_k,
        )  # (1, <= beam_size)

        # Initialize beam with different starting tokens
        beam_trigger_ids = initial_tokens.t()  # (<= beam_size, 1)

        best = RunningBest()

        # Iterate for num_steps steps (BEAST: Algorithm 1 in paper; lines 8-23)
        # We already have 1 token, so iterate num_steps - 1 times
        for _step in self.track_steps(range(self.num_steps - 1), desc="Beam Search optimization"):
            # 1. Get logits for the next trigger token (adv[-1]'s)
            next_token_logits = self.util_lm.compute_logits_from_tokens(
                beam_trigger_ids, return_after_trigger_logits_only=True
            )  # (beam, 1, vocab_size)
            next_token_logits = next_token_logits.squeeze(1)  # (beam, vocab_size)

            # 2. Sample candidate next trigger tokens using multinomial sampling
            next_token_logits[:, util_blacklist_ids] = float("-inf")  # Block disallowed tokens before softmax
            probs = torch.softmax(next_token_logits / self.temperature, dim=-1)
            # Sample up to branching_factor candidates per beam (without replacement)
            candidate_next_tokens = self._sample_multinomial(
                probs,
                return_tokens=self.branching_factor,
                top_k=self.top_k,
            )  # (beam, candidates_per_beam)

            # 3. Expand beam: append each sampled token to a copy of its beam.
            # The sampler may return fewer than branching_factor columns (e.g.
            # when top_k < branching_factor), so repeat by the actual count to
            # keep the rows aligned for the concatenation below.
            candidates_per_beam = candidate_next_tokens.size(-1)
            repeated_triggers = beam_trigger_ids.repeat_interleave(
                candidates_per_beam, dim=0
            )  # (beam, len) -> (beam * candidates_per_beam, len)
            flat_next_tokens = candidate_next_tokens.reshape(
                -1, 1
            )  # (beam, candidates_per_beam) -> (beam * candidates_per_beam, 1)
            candidate_triggers = torch.cat(
                [repeated_triggers, flat_next_tokens], dim=-1
            )  # append candidate tokens -> (beam * candidates_per_beam, len+1)

            # 4. Compute losses for all candidate triggers (text-level query of the target model)
            candidate_triggers_text = util_tokenizer.decode_triggers(candidate_triggers)
            losses = self.model.compute_loss_from_texts(
                candidate_triggers_text,
                loss_func=self.loss_func,
            )

            # 5. Select (up to) beam_size candidates with lowest loss, keep their trigger ids.
            # Clamp k: there can be fewer candidates than beam_size when top_k is small.
            keep = min(self.beam_size, losses.size(-1))
            top_losses, top_indices = torch.topk(losses, keep, largest=False)
            # Losses may live on the target model's device; index on the candidates' device.
            beam_trigger_ids = candidate_triggers[top_indices.to(candidate_triggers.device)]
            current_loss = top_losses[0].item()  # topk returns sorted values: [0] is the lowest loss

            # 6. Track best trigger
            best_trigger_ids = beam_trigger_ids[0]
            trigger_str = util_tokenizer.decode_trigger(best_trigger_ids)
            best.update(loss=current_loss, trigger_ids=best_trigger_ids, trigger_str=trigger_str)
            self.log(loss=current_loss, trigger_str=trigger_str)

        return best.to_result()

    @staticmethod
    @torch.no_grad()
    def _sample_multinomial(probs, return_tokens=0, top_k=None):
        """
        Sample tokens from probability distribution using multinomial sampling.
        Optionally filter to top-k tokens before sampling.

        Params:
            probs: probability distribution over the last dimension (should sum to 1);
                if the rows do not sum to roughly 1, a softmax is applied first.
            return_tokens: number of tokens to sample per row (without replacement)
            top_k: if provided, only sample from top-k highest probability tokens

        Return:
            sampled tokens: (batch_size, n) where n is `return_tokens` clamped to
            the number of tokens available to sample from (vocabulary size, or
            top-k size when `top_k` is given).
        """
        # If probs do not sum to (roughly) 1, apply softmax
        row_sums = probs.sum(dim=-1)
        if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-3):
            probs = torch.softmax(probs, dim=-1)

        vocab_size = probs.size(-1)

        if top_k is None:
            # Sample directly from full distribution
            return torch.multinomial(
                probs,
                num_samples=min(return_tokens, vocab_size),
                replacement=False,
            )

        # Filter to top-k tokens (k cannot exceed the vocabulary size)
        k = min(top_k, vocab_size)
        top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Sample from top-k distribution; cannot draw more than k without replacement
        sampled_indices = torch.multinomial(
            top_k_probs,
            num_samples=min(return_tokens, k),
            replacement=False,
        )

        # Map back to original vocabulary indices
        return torch.gather(top_k_indices, -1, sampled_indices)