Source code for tropt.optimizer.utils.token_initializers

from __future__ import annotations
import random
import string
from typing import List, Optional

import torch
from jaxtyping import Float
from torch import Tensor

from tropt.model.model_base import BaseTokenizer


def get_printable_random_trigger(
    trigger_len: int,
    return_ids: bool = False,
    blacklist_ids: Optional[List[int]] = None,
    tokenizer: Optional[BaseTokenizer] = None,
    token_constraints: Optional["TokenConstraints"] = None,  # ty:ignore[unresolved-reference] (to avoid circular imports)
) -> str | Float[Tensor, "trigger_seq_len"]:
    """
    Generate a random initial trigger made of ASCII letters, digits and spaces.

    Without a tokenizer, `trigger_len` is a number of characters and the
    result is a string of exactly that length. With a tokenizer, `trigger_len`
    is a number of tokens: random text is encoded, IDs found in the blacklist
    are dropped, and sampling repeats until `trigger_len` IDs are collected.

    Args:
        trigger_len: Target length, in tokens if `tokenizer` is given,
            otherwise in characters.
        return_ids: If True, return the sampled token IDs as a 1-D long
            tensor instead of a string. Requires `tokenizer`.
        blacklist_ids: Token IDs that must not appear in the trigger. Only
            used when `tokenizer` is given. Mutually exclusive with
            `token_constraints`.
        tokenizer: Object exposing `encode(text, add_special_tokens=False)`
            and `decode(ids)`.
        token_constraints: If given together with `tokenizer`, the blacklist
            is taken from `token_constraints.get_blacklist_ids(tokenizer)`.
            Ignored when `tokenizer` is None. Mutually exclusive with
            `blacklist_ids`.

    Returns:
        The trigger string, or (when `return_ids` is True) a tensor of shape
        `(trigger_len,)` holding exactly the sampled, blacklist-free IDs.
        NOTE: the `Float[...]` return annotation is kept for interface
        compatibility; the tensor holds integer IDs.

    Raises:
        AssertionError: If both `token_constraints` and `blacklist_ids` are
            passed, or if `return_ids` is True without a tokenizer.
        RuntimeError: If sampling repeatedly yields no usable token (for
            example, the blacklist covers everything the tokenizer emits).
    """
    # Raised explicitly (rather than via `assert`) so the checks survive
    # `python -O`; the exception type and messages are unchanged for callers.
    if token_constraints is not None and blacklist_ids is not None:
        raise AssertionError("Pass either `token_constraints` or `blacklist_ids`, not both.")
    if return_ids and tokenizer is None:
        raise AssertionError("Tokenizer must be provided to return token IDs.")

    if token_constraints is not None and tokenizer is not None:
        blacklist_ids = token_constraints.get_blacklist_ids(tokenizer)

    # Letters and digits plus 11 spaces: spaces are deliberately
    # over-represented to increase their appearance in the sampled text.
    alphabet = string.ascii_letters + string.digits + ' ' * 11

    if tokenizer is None:
        # Oversample-then-truncate is kept as-is so the global RNG is consumed
        # exactly as before (matters for seeded, reproducible runs).
        sampled_text = ''.join(random.choices(alphabet, k=trigger_len * 4))
        return sampled_text[:trigger_len]

    blocked: set[int] = set(blacklist_ids) if blacklist_ids else set()
    token_ids: list[int] = []
    max_stalled_rounds = 100  # consecutive rounds without a single usable token
    stalled_rounds = 0
    while len(token_ids) < trigger_len:
        candidate = ''.join(random.choices(alphabet, k=trigger_len * 4))
        candidate_ids = tokenizer.encode(candidate, add_special_tokens=False)
        clean_ids = [t for t in candidate_ids if t not in blocked]
        if clean_ids:
            stalled_rounds = 0
        else:
            stalled_rounds += 1
            if stalled_rounds >= max_stalled_rounds:
                raise RuntimeError(
                    "Could not sample any non-blacklisted token after "
                    f"{max_stalled_rounds} consecutive attempts."
                )
        token_ids.extend(clean_ids[: trigger_len - len(token_ids)])

    if return_ids:
        # Return the sampled IDs directly: decoding and re-encoding is not
        # guaranteed to round-trip, which could change the length or bring
        # blacklisted IDs back. An explicit dtype keeps an empty result integer.
        return torch.tensor(token_ids, dtype=torch.long)  # shape: (trigger_seq_len,)
    return tokenizer.decode(token_ids)