Source code for tropt.optimizer.soft_optimizer
from __future__ import annotations
import logging
from typing import Callable, Optional
import torch
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
GradientEmbedAccessMixin,
)
from tropt.optimizer.base import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.running_best import RunningBest
from tropt.tracker import BaseTracker
# Module-level logger named after this module (standard `logging.getLogger(__name__)` idiom).
logger: logging.Logger = logging.getLogger(__name__)
[docs]
class SoftPromptOptimizer(BaseOptimizer):
    """
    Optimize a continuous ("soft") prompt directly in embedding space.

    The initial trigger string is tokenized and embedded once with the model's
    embedding layer. The resulting embedding tensor is then treated as the free
    parameter and updated by a torch optimizer, using gradients returned by
    ``model.compute_grad_from_embeds``. The result is therefore a sequence of
    embedding vectors rather than token ids.
    """

    model_requirements = (GradientEmbedAccessMixin,)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        # Soft prompt optimization parameters:
        num_steps: int = 100,
        learning_rate: float = 0.001,
        gd_optimizer: Callable[..., torch.optim.Optimizer] = torch.optim.Adam,
    ):
        """
        Args:
            model: The target model to attack (must support gradient computation)
            loss: The loss function to optimize
            tracker: Experiment tracker for logging
            seed: Random seed for reproducibility
            num_steps: Number of optimization iterations
            learning_rate: Learning rate for the gradient descent optimizer
            gd_optimizer: The gradient descent optimizer Torch class to use (e.g., Adam, SGD).
                It is called as ``gd_optimizer([params], lr=learning_rate)``.
        """
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        self.num_steps = num_steps
        self.learning_rate = learning_rate
        self.GDOptimizer = gd_optimizer

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """
        Run soft-prompt optimization and return the best embeddings observed.

        Args:
            templates: Text templates the trigger is inserted into; forwarded to
                ``model.set_inputs_from_tokens``.
            initial_trigger: Trigger string whose token embeddings initialize the
                soft prompt. ``None`` falls back to ``DEFAULT_INIT_TRIGGER``.
            targets: Optional targets forwarded to ``model.set_inputs_from_tokens``.

        Returns:
            The ``OptimizerResult`` built by ``RunningBest.to_result()`` from the
            lowest-loss trigger embedding seen during the ``num_steps`` iterations.
        """
        if initial_trigger is None:
            initial_trigger = DEFAULT_INIT_TRIGGER

        # Initialization
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        trigger_ids = tokenizer.encode_trigger(initial_trigger).to(self.model.device)

        # Detach + clone so the optimized tensor is a standalone leaf. The raw
        # embedding-layer output is a non-leaf when the embedding weights require
        # grad (torch optimizers reject non-leaf params). It must also not stay
        # attached to the model's autograd graph.
        trigger_embeds = (
            self.model._embedding_layer(trigger_ids.unsqueeze(0)).detach().clone()
        )  # (1, trigger_seq_len, embed_dim)

        optimizer = self.GDOptimizer([trigger_embeds], lr=self.learning_rate)
        best = RunningBest()

        for _ in self.track_steps(range(self.num_steps), desc="Soft Prompt Optimization"):
            optimizer.zero_grad()

            # Gradient and loss are both evaluated at the *current* embeddings.
            trigger_grad, curr_loss = self.model.compute_grad_from_embeds(
                loss_func=self.loss_func,
                candidate_trigger_embeds=trigger_embeds,
                normalize_grads=False,
                return_loss=True,
            )  # grad: (1, trigger_seq_len, embed_dim); loss: (1,)
            curr_loss = curr_loss.item()

            # Record the embeddings this loss was measured on *before* the optimizer
            # step mutates them in place. Clone so later in-place updates cannot
            # overwrite the stored snapshot through shared storage.
            best.update(
                loss=curr_loss,
                trigger_emb=trigger_embeds.detach().squeeze(0).clone(),
            )

            # The gradient comes from the model helper rather than from backward(),
            # so attach it manually before stepping the configured optimizer.
            trigger_embeds.grad = trigger_grad
            optimizer.step()

            self.log(
                loss=curr_loss,
                lr=optimizer.param_groups[0]["lr"],
                grad_norm=trigger_grad.norm().item(),
            )

        return best.to_result()