from __future__ import annotations
import logging
from typing import List, Optional
import torch
from jaxtyping import Float, Int
from torch import Tensor
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
GradientTokenAccessMixin,
LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.retokenization import retokenize_transform
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.tracker import BaseTracker
logger = logging.getLogger(__name__)
# (Stray "[docs]" Sphinx source-link marker removed; it is not part of the module code.)
class HotFlipOptimizer(BaseOptimizer):
    """
    HotFlip: White-Box Adversarial Examples for Text Classification.

    https://arxiv.org/abs/1712.06751

    Uses first-order Taylor approximation of the loss to greedily select
    token substitutions. Each flip is chosen as the (position, token) pair
    that maximally decreases the estimated loss, without requiring a forward
    pass for candidate evaluation. We implement the greedy variant introduced
    in the paper.
    """

    model_requirements = (LossTokenAccessMixin, GradientTokenAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # attack parameters:
        num_steps: int = 500,
        token_constraints: Optional[TokenConstraints] = None,
        use_retokenize: bool = True,
    ):
        """
        Args:
            model: Model to attack; forwarded to the base optimizer.
            loss: Loss to minimize; forwarded to the base optimizer.
            tracker: Optional tracker; forwarded to the base optimizer.
            seed: Optional seed; forwarded to the base optimizer.
            num_steps: Number of optimization steps (gradient is recomputed
                each step).
            token_constraints: Token blacklist constraints. ``None`` (the
                default) means a fresh ``TokenConstraints()`` is created for
                this optimizer. Was not originally included in the paper.
            use_retokenize: Retokenize after flipping for decode/encode
                consistency. Was not originally included in the paper.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        self.num_steps = num_steps
        # Build the default per instance: a call expression in the signature
        # is evaluated once at definition time and would be shared by every
        # optimizer constructed without an explicit ``token_constraints``.
        self.token_constraints = (
            TokenConstraints() if token_constraints is None else token_constraints
        )
        self.use_retokenize = use_retokenize

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run greedy HotFlip for ``num_steps`` steps and return the best trigger.

        Every step recomputes the gradient with respect to the trigger tokens,
        applies the single best first-order flip, optionally retokenizes the
        trigger, then evaluates and logs the actual loss.

        Args:
            templates: Text templates handed to the model together with
                ``targets`` before optimization starts.
            initial_trigger: Trigger string the optimization starts from.
            targets: Optional targets handed to the model.

        Returns:
            The result produced by the running-best tracker over the triggers
            evaluated after each step.
        """
        # Initialization
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(
            initial_trigger
        ).to(self.model.device)
        blacklist_ids = self.token_constraints.get_blacklist_ids(
            tokenizer, self.model.vocab_size
        )
        best = RunningBest()

        # Compute loss before optimization (logged only).
        # NOTE(review): the initial trigger is never passed to ``best.update``,
        # so the returned result only covers post-flip triggers; with
        # ``num_steps == 0`` ``best`` stays empty. Confirm this is intended.
        current_loss = self.model.compute_loss_from_tokens(
            trigger_ids.unsqueeze(0), loss_func=self.loss_func
        ).item()
        self.log(loss=current_loss, trigger_str=initial_trigger)

        for _ in self.track_steps(range(self.num_steps)):
            # Step 1: Compute gradient wrt trigger (batch of one, squeezed away)
            trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"] = (
                self.model.compute_grad_from_tokens(
                    candidate_trigger_ids=trigger_ids.unsqueeze(0),
                    loss_func=self.loss_func,
                ).squeeze(0)
            )

            # Step 2: Flip a token via first-order approximation
            trigger_ids = self._apply_best_flip(trigger_ids, trigger_grad, blacklist_ids)

            # Step 3: Retokenize for decode/encode consistency
            if self.use_retokenize:
                trigger_ids = retokenize_transform(trigger_ids, tokenizer)

            # Step 4: Evaluate actual loss (the flip was chosen from an estimate)
            current_loss = self.model.compute_loss_from_tokens(
                trigger_ids.unsqueeze(0), loss_func=self.loss_func
            ).item()
            trigger_str = tokenizer.decode_trigger(trigger_ids)
            self.log(loss=current_loss, trigger_str=trigger_str)
            best.update(
                loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str
            )

        return best.to_result()

    def _apply_best_flip(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"],
        blacklist_ids: List[int],
    ) -> Int[Tensor, "trigger_seq_len"]:
        """Select and apply a single token flip using the first-order Taylor approximation.

        The estimated loss change from flipping position i (token a_i -> b) is:

            delta[i, b] = grad[i, b] - grad[i, a_i]

        The (position, token) pair with the smallest ``delta`` is applied,
        after excluding blacklisted tokens and the no-op of re-selecting the
        current token.

        Args:
            trigger_ids: Current trigger token ids.
            trigger_grad: Gradient of the loss for every
                (position, vocabulary token) pair.
            blacklist_ids: Token ids that must never be chosen as a replacement.

        Returns:
            A copy of ``trigger_ids`` with exactly one position replaced; the
            input tensor is not modified.
        """
        trigger_seq_len = trigger_ids.shape[0]
        positions = torch.arange(trigger_seq_len, device=trigger_ids.device)

        # Compute the flips' derivatives for every (position, replacement_token) pair.
        # Implementer note: this subtraction is not really required, as it doesn't
        # affect the candidate ranking; indeed, later adaptations of HotFlip
        # (such as GCG) omit this subtraction.
        current_grad = trigger_grad[positions, trigger_ids]
        # The subtraction allocates a new tensor, so the in-place masking below
        # does not touch ``trigger_grad``.
        delta = trigger_grad - current_grad.unsqueeze(1)  # (trigger_seq_len, vocab_size)

        # Mask out blacklisted and current (no-op) tokens: +inf never wins a min.
        delta[:, blacklist_ids] = float("inf")
        delta[positions, trigger_ids] = float("inf")

        # Best replacement token and score for each position
        best_delta_per_pos, best_token_per_pos = delta.min(dim=1)

        # Greedy: single best flip across all positions
        best_pos = int(best_delta_per_pos.argmin().item())

        new_trigger_ids = trigger_ids.clone()
        new_trigger_ids[best_pos] = best_token_per_pos[best_pos]
        return new_trigger_ids