from __future__ import annotations
import logging
from typing import List, Optional
import torch
from jaxtyping import Float, Int
from torch import Tensor
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
GradientTokenAccessMixin,
LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.tracker import BaseTracker
logger = logging.getLogger(__name__)
[docs]
class GCGOptimizer(BaseOptimizer):
    """
    Greedy Coordinate Gradient (GCG) optimizer for adversarial text triggers.

    Each step computes the gradient of the loss with respect to the trigger
    tokens, samples candidate triggers that differ from the current one in a
    few positions (replacement tokens are drawn from the per-position top-k
    tokens by negative gradient), evaluates the loss of every candidate, and
    greedily moves to the lowest-loss candidate.

    Reference: https://arxiv.org/abs/2307.15043
    """

    # Mixins the wrapped model must implement; this optimizer calls
    # `compute_loss_from_tokens` and `compute_grad_from_tokens` on the model.
    # NOTE(review): presumably enforced by BaseOptimizer -- confirm.
    model_requirements = (LossTokenAccessMixin, GradientTokenAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        num_steps: int = 500,
        n_candidates: int = 512,
        sample_topk: int = 256,
        sample_n_replace: int = 1,
        # NOTE(review): this default is evaluated once at definition time, so
        # every optimizer built without an explicit value shares one instance.
        # Harmless if TokenConstraints is immutable -- confirm.
        token_constraints: TokenConstraints = TokenConstraints(),
        use_retokenize: bool = True,
    ):
        """
        Implements the Greedy Coordinate Gradient (GCG) optimization algorithm
        for finding adversarial text triggers.

        Args:
            model (BaseModel): The language model to be attacked.
            loss (BaseLoss): The loss function to be optimized. Must be
                differentiable.
            tracker (BaseTracker, optional): An optional tracker for logging
                optimization progress (forwarded to ``BaseOptimizer``).
            seed (int, optional): Random seed for reproducibility (forwarded
                to ``BaseOptimizer``).

            # Attack parameters:
            num_steps (int): Number of optimization steps to perform.
            n_candidates (int): Number of candidate sequences to generate
                per step.
            sample_topk (int): Number of top tokens to consider for each
                position.
            sample_n_replace (int): Number of token positions to update per
                candidate.
            token_constraints (TokenConstraints): Source of the blacklisted
                token ids (via ``get_blacklist_ids``) that are pushed to the
                bottom of the per-position token ranking.
            use_retokenize (bool): If True, sampled candidates are passed
                through ``retokenize_filtering`` before being scored.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)

        # save params:
        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.sample_topk = sample_topk
        self.sample_n_replace = sample_n_replace
        self.token_constraints = token_constraints
        self.use_retokenize = use_retokenize

        # validations:
        # NOTE(review): `assert` is stripped under `python -O`; kept as-is so
        # the raised exception type does not change for existing callers.
        assert self.loss_func.is_differentiable, (
            "GCGOptimizer requires a differentiable loss function."
        )

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        # objective-specific args:
        targets: Optional[Targets] = None,  # depends on the objective
    ) -> OptimizerResult:
        """
        Run ``num_steps`` GCG iterations starting from ``initial_trigger``.

        Args:
            templates (TextTemplates): Templates handed to
                ``model.set_inputs_from_tokens``.
            initial_trigger (str, optional): Trigger text the search starts
                from; encoded with the model tokenizer's ``encode_trigger``.
            targets (Targets, optional): Objective-specific targets, handed
                to ``model.set_inputs_from_tokens``.

        Returns:
            OptimizerResult: Built by ``RunningBest.to_result()`` from the
            per-step winners. The initial trigger's loss is logged but is not
            recorded in the running best.
        """
        # Initialization:
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(
            initial_trigger
        ).to(self.model.device)
        vocab_size = self.model.vocab_size
        blacklist_ids = self.token_constraints.get_blacklist_ids(tokenizer, vocab_size)
        best = RunningBest()

        # Compute loss before optimization
        current_loss = self.model.compute_loss_from_tokens(
            trigger_ids.unsqueeze(0), loss_func=self.loss_func
        ).item()
        self.log(loss=current_loss, trigger_str=initial_trigger)

        for _ in self.track_steps(range(self.num_steps)):
            # Compute the trigger gradient
            trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"] = (
                self.model.compute_grad_from_tokens(
                    candidate_trigger_ids=trigger_ids.unsqueeze(0),
                    loss_func=self.loss_func,
                    normalize_grads=True,
                ).squeeze(0)  # take the only trigger
            )  # shape: (trigger_seq_len, vocab_size)

            # Sample candidate token sequences based on the token gradient
            candidate_trigger_ids: Int[Tensor, "n_candidates trigger_seq_len"] = (
                self._sample_ids_from_grad(
                    trigger_ids=trigger_ids,
                    trigger_grad=trigger_grad,
                    blacklist_ids=blacklist_ids,
                )
            )
            if self.use_retokenize:
                candidate_trigger_ids = retokenize_filtering(
                    candidate_trigger_ids, tokenizer
                )

            # Reducing over zero candidates would raise in `min()`/`argmin()`.
            # Sampling is random, so keep the current trigger and retry on the
            # next step instead of aborting the whole run.
            if candidate_trigger_ids.shape[0] == 0:
                logger.warning(
                    "No candidate triggers left to evaluate in this step; "
                    "keeping the current trigger."
                )
                continue

            # Compute loss on all candidate sequences
            losses = self.model.compute_loss_from_tokens(
                candidate_trigger_ids, loss_func=self.loss_func
            )  # shape: (n_candidates,)

            # Greedy move: adopt the lowest-loss candidate, even when it is
            # not better than the previous step's trigger.
            current_loss = losses.min().item()
            trigger_ids = candidate_trigger_ids[losses.argmin()]
            trigger_str = tokenizer.decode_trigger(trigger_ids)

            self.log(loss=current_loss, trigger_str=trigger_str)
            best.update(loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str)

        return best.to_result()

    def _sample_ids_from_grad(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"],
        blacklist_ids: Optional[List[int]] = None,
    ) -> Int[Tensor, "n_candidates trigger_seq_len"]:
        """
        Samples `n_candidates` combinations of token ids based on the token gradient.

        Every candidate is a copy of `trigger_ids` in which up to
        `sample_n_replace` distinct, randomly chosen positions are overwritten
        with a token drawn uniformly from that position's `sample_topk`
        tokens with the most negative gradient.

        Args:
            trigger_ids (Tensor): shape = (trigger_seq_len,)
                The sequence of token ids being optimized.
            trigger_grad (Tensor): shape = (trigger_seq_len, vocab_size)
                The gradient of the loss with respect to the one-hot token
                embeddings. Not modified by this method.
            blacklist_ids (list of int, optional): Token ids (vocabulary
                column indices) ranked last at every position. They can only
                be sampled if fewer than `sample_topk` other tokens exist.

        Returns:
            Tensor: shape = (n_candidates, trigger_seq_len)
                Sampled token ids for each candidate.
        """
        trigger_seq_len, vocab_size = trigger_grad.shape
        device = trigger_grad.device
        if blacklist_ids is None:
            blacklist_ids = []

        # `repeat` already materializes a fresh tensor, so the in-place
        # scatter below never writes through to `trigger_ids`.
        candidate_trigger_ids = trigger_ids.repeat(self.n_candidates, 1)

        # Rank tokens by negative gradient. The mask is applied to the negated
        # copy so the caller's gradient tensor is left untouched.
        neg_grad = -trigger_grad
        neg_grad[:, blacklist_ids] = float("-inf")

        # `topk` raises if k exceeds the size of the dimension being ranked.
        topk = min(self.sample_topk, vocab_size)
        topk_ids: Int[Tensor, "trigger_seq_len sample_topk"] = (
            neg_grad.topk(topk, dim=-1).indices
        )

        # Random permutation of positions per candidate; its first columns are
        # the (distinct) positions to flip. Slicing caps the count at
        # trigger_seq_len when `sample_n_replace` is larger.
        sampled_ids_pos = torch.rand(
            self.n_candidates, trigger_seq_len, device=device
        ).argsort(dim=-1)[
            ..., : self.sample_n_replace
        ]  # shape: (n_candidates, n_replace)
        n_replace = sampled_ids_pos.shape[-1]

        # Select the relevant lists of top-k tokens for each candidate and position
        relevant_topk_lists = topk_ids[sampled_ids_pos]

        # Randomly choose one token from each of the top-k lists
        rand_k_indices = torch.randint(
            0,
            topk,
            (self.n_candidates, n_replace, 1),
            device=device,
        )

        # Gather the selected token ids using the random indices
        sampled_ids_val = torch.gather(
            input=relevant_topk_lists,  # shape: (n_candidates, n_replace, topk)
            dim=-1,
            index=rand_k_indices,  # shape: (n_candidates, n_replace, 1)
        ).squeeze(-1)  # shape: (n_candidates, n_replace)

        # Scatter the sampled token ids in the selected positions, within the
        # trigger (=apply the flips). `scatter_` requires `src` to share the
        # destination dtype; the cast is a no-op when they already match.
        candidate_trigger_ids.scatter_(
            dim=-1,  # -> trigger_seq_len dimension
            index=sampled_ids_pos,
            src=sampled_ids_val.to(candidate_trigger_ids.dtype),
        )
        return candidate_trigger_ids