from __future__ import annotations
import functools
import inspect
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional
import pydantic
import torch
from jaxtyping import Float
from tqdm import tqdm
from tropt.common import Targets, TextTemplates, TokenTrigger
from tropt.loss import BaseLoss
from tropt.model import BaseModel
from tropt.tracker import BaseTracker, DummyTracker
logger = logging.getLogger(__name__)
## ------- Optimizer result ------- ##
@dataclass
class OptimizerResult:
    """Container for what an optimizer run returns.

    Only ``best_loss`` is mandatory; every other field defaults to ``None`` and is
    filled in by optimizers that produce it.
    """

    # Best loss value reported by the optimizer.
    best_loss: float

    # Best trigger, in whichever representations the optimizer provides:
    best_trigger_ids: Optional[TokenTrigger] = None
    best_trigger_str: Optional[str] = None
    best_trigger_emb: Optional[Float[torch.Tensor, "trigger_seq_len embed_dim"]] = None
    best_trigger_probs: Optional[Float[torch.Tensor, "trigger_seq_len vocab_size"]] = None

    # Optimization records:
    losses: Optional[List[float]] = None
    trigger_strs: Optional[List[str]] = None

    def to_dict(self) -> dict:
        """Return a small summary dict suitable for final logging.

        The summary always holds ``best_loss`` and additionally holds
        ``best_trigger_str`` when that field is set. Tensor fields and the
        per-step record lists are deliberately left out.
        """
        summary: dict = {"best_loss": self.best_loss}
        if self.best_trigger_str is None:
            return summary
        summary["best_trigger_str"] = self.best_trigger_str
        return summary
## ------- Base Optimizer ------- ##
class BaseOptimizer(ABC):
    """
    Base class for all trigger optimizers.

    - Implements common functionality and interface for optimizers, including tracking.
    - Subclasses must implement the ``optimize_trigger`` method, which contains the core
      optimization loop and returns an ``OptimizerResult``; this method is automatically
      wrapped to handle logging, model state resets, and tracker finalization.
    """

    model_requirements = ()
    """Tuple of model mixin classes the primary model must satisfy; validated in ``__init__``.

    Convention: declare the least-restrictive configuration the optimizer supports.

    - Black-box only -> ``(LossTextAccessMixin,)``, even if gradient modes exist.
    - Always needs token-level loss -> include ``LossTokenAccessMixin``.
    - Always needs gradients -> include ``GradientTokenAccessMixin`` or ``GradientEmbedAccessMixin``.

    Notes:

    - Requirements that only occur in an optional flow (e.g.,
      ``candidate_selection="gradient"`` requiring ``GradientTokenAccessMixin``) must be
      validated explicitly in ``__init__`` after ``super().__init__()``, with an assert/error.
    - Requirements on auxiliary models (proxy_model, util_model, etc.) are not covered here,
      and should also be validated explicitly in ``__init__``.
    """

    def __init_subclass__(cls, **kwargs):
        """
        Wraps the ``optimize_trigger`` method to handle the full tracker lifecycle,
        model state resets, and other logging-related / bookkeeping logic.
        """
        super().__init_subclass__(**kwargs)

        # Only wrap when this subclass defines its own ``optimize_trigger``; an inherited
        # concrete implementation was already wrapped when the class defining it was created.
        # This saves some boilerplate in optimizer implementations.
        if "optimize_trigger" not in cls.__dict__:
            return

        original = cls.optimize_trigger
        # Capture once at class-definition time (not per call).
        # _sig is needed to resolve subclass-specific defaults (e.g. DEFAULT_INIT_TRIGGER)
        # so the run config receives the actual values used, not None.
        _sig = inspect.signature(original)
        validated = pydantic.validate_call(
            config=pydantic.ConfigDict(arbitrary_types_allowed=True)
        )(original)

        @functools.wraps(original)
        def _wrapper(self, *args, **kwargs):
            # Pre-run setup: restart best-loss tracking and reset per-model usage stats.
            self._best_loss = float("inf")
            for _, mdl in self._distinct_models():
                mdl.reset_usage_stats()

            # Resolve full arg values (including defaults) for the run config.
            ba = _sig.bind(self, *args, **kwargs)
            ba.apply_defaults()
            run_config = self._build_run_config(
                ba.arguments.get("templates"),
                ba.arguments.get("initial_trigger"),
                ba.arguments.get("targets"),
            )

            # Initialize tracker and log the config.
            self.tracker.init(run_config)

            # Run the actual (argument-validated) optimization method.
            result: OptimizerResult = validated(self, *args, **kwargs)

            # Post-run teardown: log summary and close the tracker.
            summary = result.to_dict()
            summary.update(self._collect_model_stats())
            self.tracker.finish(summary)

            # Reset model input state to avoid accidental state leakage across runs
            # (usage stats are NOT reset here, as the user may want to inspect them
            # after optimization). Models lacking a given reset hook are skipped.
            for _, mdl in self._distinct_models():
                for reset_name in ("reset_inputs_from_tokens", "reset_inputs_from_texts"):
                    reset_fn = getattr(mdl, reset_name, None)
                    if reset_fn is not None:
                        reset_fn()

            return result  # return the result object as usual

        setattr(cls, "optimize_trigger", _wrapper)

    def __init__(
        self,
        model: BaseModel,
        loss: Optional[BaseLoss] = None,
        tracker: Optional[BaseTracker] = None,
        seed: Optional[int] = None,
    ):
        """
        Args:
            model: The target model to attack.
            loss: The loss function to be optimized (optional, but most optimizers will require one).
            tracker: An optional tracker for logging optimization progress; a ``DummyTracker``
                is used when omitted.
            seed: Random seed for reproducibility; set in initialization.

        Raises:
            AssertionError: If ``model_requirements`` is malformed, ``model`` does not satisfy
                it, or ``loss`` / ``tracker`` have the wrong type.
        """
        # Model requirements validation
        assert isinstance(self.model_requirements, tuple), "model_requirements must be a tuple"
        assert all(
            isinstance(m, type) for m in self.model_requirements
        ), "model_requirements must contain only classes/mixins of models."
        assert all(
            isinstance(model, m) for m in self.model_requirements
        ), f"Model {type(model)} not supported by {type(self)}"
        self.model = model

        # Loss function validation. ``None`` is allowed, matching the signature default
        # and the documented "optional" contract.
        assert loss is None or isinstance(
            loss, BaseLoss
        ), "loss must be an instance of BaseLoss (or None)"
        self.loss_func = loss

        tracker = tracker if tracker is not None else DummyTracker()
        assert isinstance(tracker, BaseTracker), "tracker must be an instance of BaseTracker"
        self.tracker = tracker

        self._pbar = None
        self._best_loss = float("inf")
        # Single active budget (limit, metric, scope), or None. See set_budget().
        self._budget: Optional[tuple[int, str, str]] = None

        if seed is not None:
            # Imported lazily so ``transformers`` is only needed when seeding is requested.
            from transformers import set_seed

            set_seed(seed)
            torch.use_deterministic_algorithms(True, warn_only=True)

    @abstractmethod
    def optimize_trigger(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = None,
        targets: Optional[Targets] = None,
    ) -> OptimizerResult:
        """Optimize the trigger to minimize the loss on the given inputs.

        Subclasses only implement the search loop and return an
        ``OptimizerResult``.

        Args:
            templates: Can be a single string or a list of (n_templates) strings.
            initial_trigger: Initial trigger to start optimization from, if used by the optimizer.
            targets: Target outputs for the given inputs, if applicable.

        Returns:
            Optimized trigger.

        Note:
            This method is wrapped by the base class (via ``__init_subclass__``) to handle the
            full tracker lifecycle: ``tracker.init(config)`` before, ``tracker.finish(summary)``
            after, model state resets, and other bookkeeping.
            The optimization loop should iterate via :meth:`track_steps`, which
            handles the ``tqdm`` progress bar and enforces any budget upper-bound configured
            via :meth:`set_budget`.
        """
        ...

    def _distinct_models(self) -> List[tuple[str, BaseModel]]:
        """Return ``(attribute_name, model)`` for each distinct ``BaseModel`` stored on ``self``.

        Instances are de-duplicated by identity, so a model referenced by several
        attributes is reported once, under the first attribute name encountered.
        """
        seen_ids: set[int] = set()
        found: List[tuple[str, BaseModel]] = []
        for attr, val in self.__dict__.items():
            if isinstance(val, BaseModel) and id(val) not in seen_ids:
                seen_ids.add(id(val))
                found.append((attr, val))
        return found

    @staticmethod
    def _normalize_templates(templates) -> Optional[List[str]]:
        """Normalize ``templates`` to a list for the run config.

        A single string becomes a one-element list (rather than being split into
        characters), an existing list is returned as-is, any other iterable is
        materialized with ``list``, and ``None`` is passed through.
        """
        if templates is None:
            return None
        if isinstance(templates, str):
            return [templates]
        return templates if isinstance(templates, list) else list(templates)

    def _build_run_config(
        self,
        templates: TextTemplates,
        initial_trigger: Optional[str] = None,
        targets: Optional[Targets] = None,
    ) -> dict:
        """Build run config dict from optimizer state and call arguments.

        The config contains the optimizer/model/loss identification, the call arguments,
        and one ``hparam/<name>`` entry per instance attribute (except the model, loss and
        tracker). Non-primitive attribute values are stringified. The config is also
        emitted to the module logger at INFO level.
        """
        _skip = {"model", "loss_func", "tracker"}
        hparams = {
            f"hparam/{k}": v if isinstance(v, (str, int, float, bool, type(None))) else str(v)
            for k, v in self.__dict__.items()
            if k not in _skip
        }

        targets_repr: Optional[dict[str, Any]] = None
        if targets is not None:
            # Drop unset fields and convert tensors to plain lists.
            targets_repr = {
                k: v.tolist() if isinstance(v, torch.Tensor) else v
                for k, v in targets.model_dump().items()
                if v is not None
            }

        config = {
            "optimizer": type(self).__name__,
            "model_name": self.model.get_model_name(),
            "loss": repr(self.loss_func),
            "templates": self._normalize_templates(templates),
            "initial_trigger": str(initial_trigger) if initial_trigger is not None else None,
            "targets": targets_repr,
            **hparams,
        }

        lines = ["\n=== Optimizer Run Config ==="]
        lines.extend(f" {k}: {v}" for k, v in config.items())
        lines.append("===========================")
        logger.info("\n".join(lines))
        return config

    def log(self, loss: float, trigger_str: Optional[str] = None, **extra) -> None:
        """Log per-step metrics to the tracker.

        Automatically enriches with:

        - ``best_loss``: running best loss across steps.
        - ``loss/*``: loss function component stats (skipped when no loss function is set).
        - ``target_model_stats/*``, ``total_models_stats/*``: model usage stats (by inspecting
          all optimizer attributes that subclass BaseModel).

        Args:
            loss: Per-step loss value.
            trigger_str: Current trigger string (omitted from log dict if None).
            **extra: Any additional key-value pairs to include in the log dict.
        """
        # Track the best loss for per-step logging:
        if loss < self._best_loss:
            self._best_loss = loss

        log_dict: dict = {"loss": loss, "best_loss": self._best_loss}
        if trigger_str is not None:
            log_dict["trigger_str"] = trigger_str
        log_dict.update(extra)

        # Enrich w/ loss function stats (from the last step); the loss is optional.
        if self.loss_func is not None:
            for k, v in self.loss_func.get_loss_log_dict().items():
                log_dict[f"loss/{k}"] = v

        # Enrich w/ model stats:
        log_dict.update(self._collect_model_stats())

        # Log to tracker:
        self.tracker.log(log_dict)

        # Update tqdm progress bar (only present while track_steps is iterating):
        if self._pbar is not None:
            desc = f"{loss=:.4f}"
            if trigger_str is not None:
                desc += f" {trigger_str=}"
            self._pbar.set_description(desc)

    def _collect_model_stats(self) -> dict:
        """Collect usage stats from all distinct model instances (i.e., that subclass BaseModel) on ``self``.

        Returns a flat dict with prefixed keys: ``target_model_stats/``, ``{attr}_stats/``,
        and ``total_models_stats/`` (sum across all models).
        Used for resource monitoring.
        """
        model_stats: dict[str, dict] = {
            ("target_model_stats" if attr == "model" else f"{attr}_stats"): mdl.get_usage_stats()
            for attr, mdl in self._distinct_models()
        }

        result: dict = {}
        total_model_stats: dict = {}
        for prefix, stats in model_stats.items():
            for k, v in stats.items():
                result[f"{prefix}/{k}"] = v
                # Sum across all models:
                total_model_stats[k] = total_model_stats.get(k, 0) + v
        # Totals are appended after all per-model entries.
        for k, v in total_model_stats.items():
            result[f"total_models_stats/{k}"] = v
        return result

    #### -- Budget tracking and step iteration utilities -- ####

    def reset_budget(self):
        """Clears any budget set by :meth:`set_budget`."""
        self._budget = None

    def set_budget(self, limit: int, metric: str = "total_tokens", scope: str = "all") -> None:
        """Registers an upper-bound resource budget enforced by :meth:`track_steps`.

        The budget is a ceiling, not a quota: if the optimizer terminates naturally before
        reaching it, the budget has no effect.

        Common metrics (keys of :meth:`BaseModel.get_usage_stats`):

        - ``"total_flops"``: Estimated FLOPs consumed. Requires ``model.set_flop_counting("manual")``
          on any model whose FLOPs should count. Best choice for white-box compute-equalised comparisons.
        - ``"total_tokens"``: Total tokens processed (prompt + generation). Best for black-box models
          where FLOPs aren't observable but token usage is.

        Args:
            limit: Upper bound on the ``metric``; stored as ``int(limit)``.
            metric: The metric the budget is set by. Defaults to the token usage count.
            scope: What models to take the metric against.
                In optimizers that accommodate multiple models (e.g., proxy models), this may be
                a critical choice.
                ``"all"`` sums the metric across all models found on ``self``.
                ``"target"`` only considers the primary target model (``self.model``), which is
                useful if we only care about the target model API token usage.

        Raises:
            AssertionError: If ``scope`` or ``metric`` is not one of the supported values.

        Usage:
            ```python
            # Whitebox: cap compute by FLOPs (across target + any proxy LM)
            optimizer = GCGOptimizer(model=model_obj, loss=PrefillCELoss(), num_steps=10_000)
            optimizer.set_budget(1e17, metric="total_flops")
            # Blackbox: cap by target-model tokens (FLOPs aren't observable on API models)
            optimizer = RandomSearchOptimizer(model=model_obj, loss=PrefillCELoss(), num_steps=10_000)
            optimizer.set_budget(1_000_000, metric="total_tokens", scope="target")
            ```
        """
        assert scope in ("all", "target"), "scope must be 'all' or 'target'"
        assert metric in ("total_flops", "total_tokens"), "currently only 'total_flops' and 'total_tokens' metrics are supported"
        self._budget = (int(limit), metric, scope)

    def _budget_usage(self) -> float:
        """Return current consumption of the budgeted metric (``0.0`` when no budget is set)."""
        if self._budget is None:
            return 0.0
        _, metric, scope = self._budget
        # Budget on target model only:
        if scope == "target":
            return float(self.model.get_usage_stats().get(metric, 0))
        # Budget across all available (distinct) models:
        return sum(
            (mdl.get_usage_stats().get(metric, 0) for _, mdl in self._distinct_models()),
            0.0,
        )

    def _budget_exhausted(self) -> bool:
        """Return True when a budget is set and its usage has reached the limit."""
        return self._budget is not None and self._budget_usage() >= self._budget[0]

    def track_steps(self, *args, **kwargs):
        """Iterator for the optimization loop that handles progress bar and budget enforcement.

        This supplements the optimization loop with:

        - a ``tqdm`` progress bar (args/kwargs forwarded) that :meth:`log`
          calls auto-update with the current loss and trigger string, and
        - early termination when the budget set via :meth:`set_budget` is hit.

        The budget is checked at the top of each step, so overshoot is bounded
        by one step's work. Without a budget set, behaves like plain ``tqdm``.

        The progress bar is closed and detached from the optimizer whenever iteration
        ends: natural completion, budget exhaustion, the caller breaking out of the
        loop, or an exception raised inside the loop body.

        Note: If you implement a personal-use custom optimizer for quick check, and don't care for
        fancy progress bar / budget, you may safely ignore this.

        Usage::

            for _ in self.track_steps(range(self.num_steps), desc="MyOpt"):
                ...
        """
        pbar = tqdm(*args, **kwargs)
        self._pbar = pbar
        try:
            for item in pbar:
                if self._budget_exhausted():
                    limit, metric, scope = self._budget  # type: ignore[misc]
                    logger.info(
                        f"[{type(self).__name__}] budget reached "
                        f"({metric}[{scope}]={self._budget_usage():.3g} >= {limit:.3g}); stopping early."
                    )
                    return
                yield item
        finally:
            # Runs on every exit path, including generator close after a caller `break`.
            pbar.close()
            # Detach only our own bar, in case a newer track_steps call replaced it.
            if self._pbar is pbar:
                self._pbar = None