Source code for tropt.optimizer.utils.buffer
from __future__ import annotations
from typing import Optional
import torch
[docs]
class TriggerBuffer:
    """
    Enables maintaining a buffer of the best triggers found during optimization.

    Triggers and their losses are stored in two parallel lists, ``triggers`` and
    ``losses``; entry ``i`` of one always corresponds to entry ``i`` of the other.
    Lower loss is better.

    https://www.haizelabs.com/blog/making-a-sota-adversarial-attack-on-llms-38x-faster
    https://arxiv.org/pdf/2402.12329
    """

    def __init__(
        self,
        triggers: Optional[list[torch.Tensor]] = None,
        losses: Optional[list[float]] = None,
    ):
        """
        Args:
            triggers: Initial trigger token-ID tensors. A provided list is stored
                as-is (not copied), so later buffer updates are visible through it.
            losses: Losses matching ``triggers`` one-to-one.

        Raises:
            ValueError: If ``triggers`` and ``losses`` have different lengths.
        """
        self.triggers = triggers if triggers is not None else []  # List of trigger token ID tensors
        self.losses = losses if losses is not None else []  # Corresponding list of losses
        # Every method indexes both lists with the same position, so they must stay aligned.
        if len(self.triggers) != len(self.losses):
            raise ValueError(
                "triggers and losses must have the same length, "
                f"got {len(self.triggers)} and {len(self.losses)}"
            )

    @property
    def size(self) -> int:
        """Number of triggers currently held in the buffer."""
        return len(self.triggers)

    def add(self, trigger_ids: torch.Tensor, loss: float) -> None:
        """
        Adds a new trigger and its loss to the buffer.
        Increases the buffer size by one.
        """
        self.triggers.append(trigger_ids)
        self.losses.append(loss)

    def add_if_better(self, trigger_ids: torch.Tensor, loss: float) -> bool:
        """
        Replace the worst (highest-loss) trigger if ``loss`` is strictly lower.

        The buffer size is retained. An empty buffer has no worst entry to
        replace, so it is left unchanged. When several entries share the highest
        loss, the first of them is replaced.

        Returns:
            ``True`` if the trigger was inserted, ``False`` otherwise.
        """
        # Nothing to overwrite; without this guard the "worst" loss would be the
        # sentinel inf, which is not actually present in ``self.losses``.
        if not self.losses:
            return False
        # First index holding the maximum loss, found in a single pass.
        worst_idx = max(range(len(self.losses)), key=self.losses.__getitem__)
        if loss < self.losses[worst_idx]:
            self.triggers[worst_idx] = trigger_ids
            self.losses[worst_idx] = loss
            return True
        return False

    def get_best_trigger(self, top_k: int = 1) -> torch.Tensor:
        """
        Return the lowest-loss trigger.

        If ``top_k > 1``, sample uniformly (with ``torch.randint``) from the
        ``top_k`` lowest-loss triggers for exploration; a ``top_k`` larger than
        the buffer samples from the whole buffer.

        Raises:
            ValueError: If the buffer is empty.
        """
        # Note: this can be optimized by maintaining a heap; but large buffers are currently rare in the package, so we keep it simple for now.
        if not self.losses:
            raise ValueError("Cannot get a trigger from an empty TriggerBuffer.")
        if top_k <= 1:
            # First index holding the minimum loss.
            best_idx = min(range(len(self.losses)), key=self.losses.__getitem__)
            return self.triggers[best_idx]
        candidates = sorted(range(len(self.losses)), key=self.losses.__getitem__)[:top_k]
        choice = int(torch.randint(len(candidates), (1,)).item())
        return self.triggers[candidates[choice]]

    def get_highest_loss(self) -> float:
        """Return the worst (highest) loss in the buffer, or ``inf`` if it is empty."""
        return max(self.losses, default=float('inf'))

    def get_lowest_loss(self) -> float:
        """Return the best (lowest) loss in the buffer, or ``inf`` if it is empty."""
        return min(self.losses, default=float('inf'))