from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Annotated, Any, List, Optional
import torch
from jaxtyping import Int
from torch import Tensor
from tropt.common import (
OPTIMIZED_TRIGGER_PLACEHOLDER,
ModelInput,
Targets,
TextTemplates,
)
# ======================= Triggered Input Managers =======================
class InputsManager(ABC):
    """
    Abstract base for objects that hold input templates, their targets, and the
    logic that splices optimized triggers into those templates.

    A concrete manager wraps `n_templates` templates, each containing the
    substring `OPTIMIZED_TRIGGER_PLACEHOLDER` to mark the trigger position,
    together with the matching targets. It gives adversarial trigger
    optimization one interface regardless of the input representation
    (e.g., text-level or token-level).
    """

    @abstractmethod
    def get_triggered_inputs(self, chosen_template_idx: int, *args, **kwargs) -> ModelInput:
        """
        Build the model inputs for one template with trigger candidates inserted.

        Args:
            chosen_template_idx: Index of the template to fill in.
            *args, **kwargs: Subclass-specific arguments carrying the trigger
                candidates.

        Returns:
            A ModelInput holding the trigger-combined inputs together with the
            targets that belong to the selected template.
        """
        raise NotImplementedError()

    def __init__(
        self,
        templates: TextTemplates,
        targets: Targets,  # n_templates elements per target entry
    ) -> None:
        # The base constructor always raises: concrete managers are expected to
        # define their own `__init__` instead of inheriting this one.
        raise NotImplementedError()
## Text inputs manager ##
class TextInputManager(InputsManager):
    """
    Text-level manager of trigger-combined inputs (fits black-box, text-level
    query access).

    Each of the `n_templates` templates is split once, at its first
    `OPTIMIZED_TRIGGER_PLACEHOLDER`, into a prefix (`before_texts`) and a suffix
    (`after_texts`). `get_triggered_inputs` then places each candidate trigger
    string between the prefix and suffix of a chosen template.
    """

    # Text preceding the trigger slot, one entry per template.
    before_texts: Annotated[List[str], "n_templates"]
    # Text following the trigger slot, one entry per template.
    after_texts: Annotated[List[str], "n_templates"]
    # Targets for all templates, moved to `device` during construction.
    targets: Targets

    def __init__(
        self,
        templates: TextTemplates,
        targets: Optional[Targets] = None,
        *,
        device: Optional[str] = None,
    ):
        """
        Args:
            templates: List of template strings. Each must contain
                `OPTIMIZED_TRIGGER_PLACEHOLDER`; only the first occurrence is
                used as the trigger slot, and any later occurrences are kept
                verbatim in the suffix.
            targets: Targets for the templates. Defaults to an empty `Targets()`.
            device: Device passed to `Targets.to_device`. Defaults to "cuda"
                when `torch.cuda.is_available()` and "cpu" otherwise.

        Raises:
            AssertionError: If `templates` is not a list.
            ValueError: If a template does not contain the placeholder.
        """
        assert isinstance(templates, list), "templates must be a list of strings."
        if targets is None:
            targets = Targets()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        targets = targets.to_device(device)

        before_texts: List[str] = []
        after_texts: List[str] = []
        for template_idx, template in enumerate(templates):
            # `partition` splits at the first occurrence, like `split(sep, 1)`,
            # but also reports whether the separator was found at all.
            before, placeholder, after = template.partition(OPTIMIZED_TRIGGER_PLACEHOLDER)
            if not placeholder:
                # A template without a slot cannot receive a trigger; fail with
                # a descriptive message rather than an opaque unpacking error.
                raise ValueError(
                    f"Template at index {template_idx} does not contain the "
                    f"trigger placeholder {OPTIMIZED_TRIGGER_PLACEHOLDER!r}."
                )
            before_texts.append(before)
            after_texts.append(after)

        self.before_texts = before_texts
        self.after_texts = after_texts
        self.targets = targets

    @property
    def n_templates(self) -> int:
        """Number of templates held by this manager."""
        return len(self.before_texts)

    def get_triggered_inputs(
        self,
        chosen_template_idx: int,
        trigger_strs: Annotated[List[str], "n_candidates"],
    ) -> ModelInput:
        """
        Insert each trigger string into one template and bundle the results.

        Args:
            chosen_template_idx: Index of the template to use.
            trigger_strs: Candidate trigger strings. One input text is produced
                per candidate, in the same order.

        Returns:
            A ModelInput with `input_texts` (prefix + trigger + suffix for each
            candidate), `input_trigger_strs` (the candidates as given), and
            `message_targets` (the targets selected for this template via
            `Targets.select_message`).

        Raises:
            AssertionError: If `trigger_strs` is not a list of strings.
        """
        assert isinstance(trigger_strs, list) and all(
            isinstance(s, str) for s in trigger_strs
        ), "trigger_strs must be a list of strings."
        prefix = self.before_texts[chosen_template_idx]
        suffix = self.after_texts[chosen_template_idx]
        input_texts: List[str] = [prefix + trigger_str + suffix for trigger_str in trigger_strs]
        # select only the chosen template's targets
        message_targets = self.targets.select_message(chosen_template_idx)
        return ModelInput(
            input_texts=input_texts,
            input_trigger_strs=trigger_strs,
            message_targets=message_targets,
        )
## Token inputs manager ##