from __future__ import annotations
import logging
from typing import Optional
import torch
from jaxtyping import Float, Int
from torch import Tensor
from tropt.common import (
DEFAULT_INIT_TRIGGER,
Targets,
TextTemplates,
)
from tropt.loss import BaseLoss
from tropt.model import (
BaseModel,
GradientTokenAccessMixin,
LossTokenAccessMixin,
)
from tropt.optimizer import BaseOptimizer, OptimizerResult
from tropt.optimizer.utils.retokenization import retokenize_filtering
from tropt.optimizer.utils.running_best import RunningBest
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.tracker import BaseTracker
# Module-level logger named after this module.
# NOTE(review): not referenced anywhere else in the code visible here.
logger = logging.getLogger(__name__)
class AutoPromptOptimizer(BaseOptimizer):
    """Gradient-based discrete prompt optimization (Shin et al., 2020).

    Each step picks a single random trigger position. It ranks replacement
    tokens for that position by the negated, normalized token gradient of the
    loss, then evaluates the ``min(n_candidates, sample_topk)``
    highest-ranked single-token substitutions. Blacklisted tokens are never
    proposed.

    The trigger always moves to the lowest-loss candidate of the step, even
    when that loss is higher than the current one. The best trigger accepted
    over the run is tracked separately in a ``RunningBest``.

    Reference: https://arxiv.org/abs/2010.15980

    Args:
        model: Model providing loss and gradient access over token ids
            (see ``model_requirements``).
        loss: Loss object, forwarded to ``BaseOptimizer``.
        tracker: Optional tracker, forwarded to ``BaseOptimizer``.
        seed: Optional seed, forwarded to ``BaseOptimizer``.
        num_steps: Number of optimization steps.
        n_candidates: Maximum number of candidates evaluated per step.
        sample_topk: Number of gradient-ranked tokens considered at the
            chosen position. Only the first ``n_candidates`` are evaluated.
        token_constraints: Source of the token-id blacklist. ``None`` (the
            default) uses a fresh ``TokenConstraints()``.
        use_retokenize: If True, candidates are passed through
            ``retokenize_filtering`` before evaluation.
    """

    model_requirements = (LossTokenAccessMixin, GradientTokenAccessMixin)

    def __init__(
        self,
        model: BaseModel,
        loss: BaseLoss,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
        num_steps: int = 500,
        n_candidates: int = 512,
        sample_topk: int = 256,
        token_constraints: Optional[TokenConstraints] = None,
        use_retokenize: bool = True,
    ):
        super().__init__(model, loss=loss, tracker=tracker, seed=seed)
        self.num_steps = num_steps
        self.n_candidates = n_candidates
        self.sample_topk = sample_topk
        # Build the default per instance. A `TokenConstraints()` default in
        # the signature would be a single object shared by every optimizer.
        self.token_constraints = (
            token_constraints if token_constraints is not None else TokenConstraints()
        )
        self.use_retokenize = use_retokenize

    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = DEFAULT_INIT_TRIGGER,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Optimize a trigger string against the given templates and targets.

        Args:
            templates: Passed to ``model.set_inputs_from_tokens``.
            initial_trigger: Starting trigger, encoded with
                ``tokenizer.encode_trigger``.
            targets: Passed to ``model.set_inputs_from_tokens``.

        Returns:
            ``best.to_result()`` of a ``RunningBest`` that receives the
            accepted trigger of every non-skipped step. The starting trigger
            is logged but not recorded in it.
            NOTE(review): what ``to_result`` does when no step was recorded
            (``num_steps == 0`` or every step skipped) is not visible here.
        """
        self.model.set_inputs_from_tokens(templates=templates, targets=targets)
        tokenizer = self.model.tokenizer
        trigger_ids: Int[Tensor, "trigger_seq_len"] = tokenizer.encode_trigger(
            initial_trigger
        ).to(self.model.device)
        blacklist_ids = self.token_constraints.get_blacklist_ids(
            tokenizer, self.model.vocab_size
        )
        best = RunningBest()

        # Loss of the starting trigger (logged only, not fed to `best`).
        current_loss = self.model.compute_loss_from_tokens(
            trigger_ids.unsqueeze(0), loss_func=self.loss_func
        ).item()
        trigger_str = initial_trigger
        self.log(loss=current_loss, trigger_str=trigger_str)

        for _ in self.track_steps(range(self.num_steps)):
            # One random position per step, drawn from torch's RNG.
            pos = int(
                torch.randint(
                    0, trigger_ids.shape[0], (1,), device=trigger_ids.device
                ).item()
            )

            candidate_trigger_ids = self._position_candidates(
                trigger_ids, pos, blacklist_ids
            )
            if self.use_retokenize and candidate_trigger_ids.shape[0] > 0:
                candidate_trigger_ids = retokenize_filtering(
                    candidate_trigger_ids, tokenizer
                )

            # min()/argmin() on an empty tensor raise, so skip the step and
            # keep the current trigger when nothing is left to evaluate.
            if candidate_trigger_ids.shape[0] == 0:
                logger.warning(
                    "AutoPrompt: no valid candidates at position %d; skipping step",
                    pos,
                )
                self.log(loss=current_loss, trigger_str=trigger_str)
                continue

            losses = self.model.compute_loss_from_tokens(
                candidate_trigger_ids, loss_func=self.loss_func
            )
            if losses.dim() > 1:
                # NOTE(review): assumes a leading (e.g. per-target) axis with
                # candidates on the last axis; confirm against the model API.
                losses = losses.mean(dim=0)

            # Always move to the step's best candidate (not improvement-only).
            current_loss = losses.min().item()
            trigger_ids = candidate_trigger_ids[losses.argmin()]
            trigger_str = tokenizer.decode_trigger(trigger_ids)
            best.update(
                loss=current_loss, trigger_ids=trigger_ids, trigger_str=trigger_str
            )
            self.log(loss=current_loss, trigger_str=trigger_str)

        return best.to_result()

    def _position_candidates(
        self,
        trigger_ids: Int[Tensor, "trigger_seq_len"],
        pos: int,
        blacklist_ids,
    ) -> Int[Tensor, "n_cands trigger_seq_len"]:
        """Return copies of ``trigger_ids`` with ``pos`` set to each top-ranked token.

        Tokens are ranked by the negated normalized gradient at ``pos``.
        Blacklisted ids are excluded, so the result may have fewer than
        ``min(n_candidates, sample_topk)`` rows, possibly zero.
        """
        trigger_grad: Float[Tensor, "trigger_seq_len vocab_size"] = (
            self.model.compute_grad_from_tokens(
                candidate_trigger_ids=trigger_ids.unsqueeze(0),
                loss_func=self.loss_func,
                normalize_grads=True,
            ).squeeze(0)
        )
        # Only the chosen position's row is needed. Negating out of place
        # leaves the tensor returned by the model untouched.
        scores = -trigger_grad[pos]
        scores[blacklist_ids] = float("-inf")

        # Clamp k so topk does not raise when the vocabulary is smaller than
        # sample_topk / n_candidates.
        k = min(self.n_candidates, self.sample_topk, scores.shape[-1])
        top = scores.topk(k)
        # topk can still return -inf (blacklisted) entries when fewer than k
        # tokens are allowed; drop them.
        token_ids = top.indices[top.values > float("-inf")]

        # repeat() already allocates a fresh tensor, so no clone is needed.
        candidates = trigger_ids.repeat(token_ids.shape[0], 1)
        candidates[:, pos] = token_ids
        return candidates