Source code for tropt.optimizer.utils.running_best
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional
from jaxtyping import Float, Int
from torch import Tensor
[docs]
@dataclass
class RunningBest:
    """
    An auxiliary object for optimizers:
    accumulates per-step losses and tracks the best trigger found so far.

    Stores:
        loss: the best loss found so far (``float("inf")`` until the first update)
        trigger_ids: the token IDs of the best trigger found so far
        trigger_str: the string of the best trigger found so far
        trigger_emb: the embedding of the best trigger found so far
            (for gradient-based optimizers; optional)
        step: the 0-based step at which the best trigger was found
            (``-1`` until the first update)
        losses: a list of all losses observed at each step
        trigger_strs: a list of all trigger strings observed at each step
            (``None`` for steps where no string was supplied)

    Tensors stored as the best trigger are detached copies, so later in-place
    edits by the optimizer do not alter them and they hold no autograd history.
    """

    loss: float = float("inf")
    trigger_ids: Optional[Int[Tensor, "trigger_seq_len"]] = None
    trigger_str: Optional[str] = None
    trigger_emb: Optional[Float[Tensor, "trigger_seq_len embed_dim"]] = None
    step: int = -1
    losses: List[float] = field(default_factory=list)
    trigger_strs: List[Optional[str]] = field(default_factory=list)

    @staticmethod
    def _snapshot(tensor: Optional[Tensor]) -> Optional[Tensor]:
        """Return a detached, independent copy of ``tensor`` (or ``None``)."""
        # detach() first so the copy is not recorded by autograd; clone() so the
        # optimizer mutating its working tensor in place cannot alter the best.
        return tensor.detach().clone() if tensor is not None else None

    def update(
        self,
        loss: float,
        trigger_ids: Optional[Int[Tensor, "trigger_seq_len"]] = None,
        trigger_str: Optional[str] = None,
        trigger_emb: Optional[Float[Tensor, "trigger_seq_len embed_dim"]] = None,
    ) -> bool:
        """
        Record a step and update the best if improved. Returns True on new best.

        Args:
            loss: the loss observed at the current step; converted with
                ``float()``, so a single-element tensor is also accepted
            trigger_ids: the token IDs of the trigger at the current step
            trigger_str: the string of the trigger at the current step
            trigger_emb: the embedding of the trigger at the current step
                (for gradient-based optimizers; optional)

        Returns:
            bool: True if this is a new best, False otherwise. Ties with the
            current best return False (the earliest best is kept), and a NaN
            loss never becomes the best.
        """
        # Store a plain Python float so a tensor loss does not keep its
        # autograd graph / device memory alive in the per-step history.
        loss = float(loss)
        self.losses.append(loss)
        self.trigger_strs.append(trigger_str)

        # Strict "<": ties keep the earlier best; NaN compares False and is skipped.
        if not loss < self.loss:
            return False

        self.loss = loss
        self.trigger_ids = self._snapshot(trigger_ids)
        self.trigger_str = trigger_str
        self.trigger_emb = self._snapshot(trigger_emb)
        # Index of the step just appended to ``losses``.
        self.step = len(self.losses) - 1
        return True

    def to_result(self):
        """
        Convert the tracked state to an ``OptimizerResult``.

        Passes the best loss/trigger fields and the full per-step ``losses`` and
        ``trigger_strs`` histories; ``step`` is not forwarded.
        """
        # Deferred import; presumably avoids a circular import with
        # tropt.optimizer.base — NOTE(review): confirm.
        from tropt.optimizer.base import OptimizerResult

        return OptimizerResult(
            best_loss=self.loss,
            best_trigger_str=self.trigger_str,
            best_trigger_ids=self.trigger_ids,
            best_trigger_emb=self.trigger_emb,
            losses=self.losses,
            trigger_strs=self.trigger_strs,
        )