Source code for tropt.optimizer.pez_optimizer

from __future__ import annotations
import logging
from typing import Optional

import torch
import torch.nn.functional as F
from jaxtyping import Float, Int
from torch import Tensor

from tropt.common import (
    DEFAULT_INIT_TRIGGER,
    Targets,
    TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
    BaseModel,
    GradientEmbedAccessMixin,
    LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.tracker import BaseTracker

logger = logging.getLogger(__name__)


class PEZOptimizer(BaseOptimizer):
    """
    Optimizer of PEZ: optimizes a continuous trigger (AKA soft trigger), but projects it
    to discrete tokens at every optimization step.

    Paper: https://arxiv.org/abs/2302.03668
    Reference implementation: https://github.com/YuxinWenRick/hard-prompts-made-easy/blob/main/optim_utils.py
    """

    model_requirements = (LossTokenAccessMixin, GradientEmbedAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # PEZ-specific parameters:
        num_steps: int = 300,
        learning_rate: float = 0.1,
        weight_decay: float = 0.1,
        gd_optimizer: type = torch.optim.SGD,
    ):
        """
        Args:
            model (BaseModel): The language model to be attacked.
            loss (BaseLoss): The loss function to be optimized.
            tracker (BaseTracker, optional): An optional tracker for logging optimization progress.
            seed (int, optional): Random seed for reproducibility.
            num_steps: Number of optimization iterations.
            learning_rate: Learning rate for the optimizer.
            weight_decay: Weight decay for the optimizer.
            gd_optimizer: Torch optimizer class to use. It is instantiated with the
                continuous trigger embeddings plus ``lr`` and ``weight_decay`` keywords.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        self.num_steps = num_steps
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.GDOptimizer = gd_optimizer

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Run PEZ and return the result tracked over all evaluated discrete triggers.

        Every step projects the continuous trigger onto the vocabulary, computes the
        gradient at the projected (discrete) embeddings, and applies that gradient to
        the continuous embeddings. After the last step the continuous trigger is
        projected once more and scored from its token IDs.

        Args:
            templates: Text templates forwarded to ``model.set_inputs_from_tokens``.
            initial_trigger: Trigger string whose token embeddings initialize the
                continuous trigger.
            targets: Optional targets forwarded to ``model.set_inputs_from_tokens``.

        Returns:
            The ``OptimizerResult`` built by ``RunningBest.to_result()`` from the
            per-step projections and the final projection.
        """
        # Initialization
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        trigger_ids = tokenizer.encode_trigger(initial_trigger).to(self.model.device)
        embedding_matrix = self.model.embedding_matrix  # (vocab_size, embed_dim)

        # Initialize continuous embeddings from the initial trigger tokens.
        # The embedding-layer output is detached and cloned so that it is a leaf
        # tensor: torch optimizers refuse non-leaf tensors (which is what the layer
        # returns whenever its weights require grad), and the clone guarantees the
        # in-place optimizer updates never touch storage owned by the model.
        trigger_embeds = (
            self.model._embedding_layer(trigger_ids.unsqueeze(0))
            .detach()
            .clone()
            .requires_grad_(True)
        )  # (1, trigger_seq_len, embed_dim)

        # Initialize optimizer on continuous embeddings
        optimizer = self.GDOptimizer(
            [trigger_embeds], lr=self.learning_rate, weight_decay=self.weight_decay
        )

        best = RunningBest()

        for step in self.track_steps(range(self.num_steps), desc="PEZ Optimization"):
            optimizer.zero_grad()

            # Forward projection: project continuous embeddings to nearest vocab tokens
            projected_ids, projected_embeds = self._project_to_vocab(
                trigger_embeds.squeeze(0), embedding_matrix
            )

            # Compute gradient w.r.t. the projected embeddings.
            # (this means the gradient is only calculated w.r.t. actual vocab tokens,
            # but is then applied to the continuous embeddings `trigger_embeds`)
            trigger_grad, curr_loss = self.model.compute_grad_from_embeds(
                loss_func=self.loss_func,
                candidate_trigger_embeds=projected_embeds.unsqueeze(0),  # (1, trigger_seq_len, embed_dim)
                normalize_grads=False,
                return_loss=True,
            )  # grad: (1, trigger_seq_len, embed_dim); loss: (1,)
            curr_loss = curr_loss.item()

            # Set gradient on the continuous embeddings and step
            trigger_embeds.grad = trigger_grad
            optimizer.step()

            # Decode current discrete trigger for tracking
            current_trigger_str = tokenizer.decode_trigger(projected_ids)
            self.log(loss=curr_loss, trigger_str=current_trigger_str)
            best.update(
                loss=curr_loss,
                trigger_ids=projected_ids.cpu(),
                trigger_str=current_trigger_str,
            )

        # Final projection and evaluation on discrete tokens
        final_ids, _ = self._project_to_vocab(
            trigger_embeds.squeeze(0), embedding_matrix
        )
        final_trigger_str = tokenizer.decode_trigger(final_ids)
        final_loss = self.model.compute_loss_from_tokens(
            final_ids.unsqueeze(0),
            loss_func=self.loss_func,
        ).item()
        best.update(
            loss=final_loss,
            trigger_ids=final_ids.cpu(),
            trigger_str=final_trigger_str,
        )

        return best.to_result()

    @torch.no_grad()
    def _project_to_vocab(
        self,
        embeds: Float[Tensor, "trigger_seq_len embed_dim"],
        embedding_matrix: Float[Tensor, "vocab_size embed_dim"],
    ) -> tuple[Int[Tensor, "trigger_seq_len"], Float[Tensor, "trigger_seq_len embed_dim"]]:
        """Project each embedding vector to its nearest neighbor in the vocabulary.

        Uses cosine similarity (normalized dot product) as in the paper. Runs under
        ``torch.no_grad()``, so the returned tensors carry no autograd history.

        Args:
            embeds: Continuous trigger embeddings, one row per trigger position.
            embedding_matrix: Vocabulary embedding matrix, one row per token.

        Returns:
            Tuple of (projected token IDs, projected embeddings), where the projected
            embeddings are the rows of ``embedding_matrix`` at the projected IDs.
        """
        # Normalize both queries and vocabulary embeddings
        embeds_norm = F.normalize(embeds, dim=-1)  # (trigger_seq_len, embed_dim)
        matrix_norm = F.normalize(embedding_matrix, dim=-1)  # (vocab_size, embed_dim)

        # Cosine similarity via dot product of normalized vectors
        # (trigger_seq_len, vocab_size)
        cosim = embeds_norm @ matrix_norm.T

        nearest_ids = cosim.argmax(dim=-1)  # (trigger_seq_len,)
        projected_embeds = embedding_matrix[nearest_ids]  # (trigger_seq_len, embed_dim)

        return nearest_ids, projected_embeds