Source code for tropt.optimizer.utils.token_constraints
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import List, Optional
import torch
from jaxtyping import Int
from torch import Tensor
logger = logging.getLogger(__name__)

# Matches placeholder tokens written as ``[unusedN]`` or ``<unusedN>`` (N = one or more digits).
UNUSED_TOKEN_REGEX = r"(\[unused\d+\]|<unused\d+>)"


@dataclass
class TokenConstraints:
    """Describes which vocabulary tokens must be excluded from use.

    Each ``disallow_*`` option contributes token ids to the blacklist built by
    :meth:`get_blacklist_ids`; :meth:`get_whitelist_ids` returns the complement
    within ``range(vocab_size)``. Computed blacklists are memoised per instance.
    """

    disallow_non_ascii: bool = True
    """
    Disallow non-ASCII tokens, which may be escaped by a defender when the trigger is used.
    The same check also rejects tokens whose decoded text contains non-printable characters.
    """
    disallow_special_tokens: bool = True
    """
    Disallow special tokens (e.g., bos, eos, unk), which may be escaped by a defender when the
    trigger is used.
    """
    disallow_unused_tokens: bool = True
    """
    Disallow `<unused*>` / `[unused*]` tokens, which can be filtered by a defender.
    In many cases they are not part of the special tokens, and thus require special care.
    """
    disallow_custom_token_ids: List[int] = field(default_factory=list)
    """
    Disallow any additional custom token ids.
    """
    # Memo of computed blacklists; excluded from __init__, repr and comparisons.
    _cache: dict = field(
        default_factory=dict, init=False, repr=False, hash=False, compare=False
    )

    def get_blacklist_ids(self, tokenizer, vocab_size: Optional[int] = None) -> List[int]:
        """Return the sorted token ids that violate the configured constraints.

        Args:
            tokenizer: Object exposing ``name_or_path``, ``vocab_size``,
                ``all_special_ids`` and ``decode(list_of_ids)``.
            vocab_size: Number of ids to consider. Falls back to
                ``tokenizer.vocab_size`` when ``None`` (or ``0``).

        Returns:
            A new sorted list of blacklisted ids, each within
            ``[0, vocab_size)``. The caller may freely mutate it.
        """
        # Resolve the effective size *before* building the key: it bounds both the
        # scan and the final range filter, so results for different sizes differ.
        vocab_size = vocab_size or tokenizer.vocab_size
        cache_key = (
            tokenizer.name_or_path,
            vocab_size,
            self.disallow_non_ascii,
            self.disallow_special_tokens,
            self.disallow_unused_tokens,
            tuple(self.disallow_custom_token_ids),
        )
        cached = self._cache.get(cache_key)
        if cached is not None:
            # Hand out a copy so callers cannot corrupt the memoised result.
            return list(cached)

        # Seed with the explicitly configured ids.
        blacklist = set(self.disallow_custom_token_ids)
        if self.disallow_special_tokens:
            # Including tokens from tokenizer.special_tokens_map (e.g., bos, eos, unk)
            blacklist.update(tokenizer.all_special_ids)

        if self.disallow_non_ascii or self.disallow_unused_tokens:
            # Single pass: decode each token once and apply all text-based checks together.
            unused_pattern = (
                re.compile(UNUSED_TOKEN_REGEX) if self.disallow_unused_tokens else None
            )
            for token_id in range(vocab_size):
                if token_id in blacklist:
                    continue  # already excluded; skip the decode
                try:
                    token_str = tokenizer.decode([token_id])
                except Exception as e:  # deliberate best-effort: tokenizer errors vary
                    logger.debug(
                        "While perfoming listing token-blacklist: "
                        "failed to decode token %s: %s",
                        token_id,
                        e,
                    )
                    # A token that cannot be decoded cannot be used, so exclude it.
                    blacklist.add(token_id)
                    continue
                if not token_str:
                    # Empty decodes are never flagged by the text-based checks.
                    continue
                if self.disallow_non_ascii and not (
                    token_str.isascii() and token_str.isprintable()
                ):
                    blacklist.add(token_id)
                elif unused_pattern is not None and unused_pattern.match(token_str):
                    blacklist.add(token_id)

        # Drop negative / out-of-vocab ids (custom or special ids may fall outside the range).
        blacklist_ids = sorted(tid for tid in blacklist if 0 <= tid < vocab_size)
        self._cache[cache_key] = tuple(blacklist_ids)

        # Guard the ratio so an empty vocabulary does not raise ZeroDivisionError.
        percentage = round(100 * len(blacklist_ids) / vocab_size, 2) if vocab_size else 0.0
        logger.info(
            "Black-lising {}% of the vocabulary ({} tokens / {} vocab)".format(
                percentage,
                len(blacklist_ids),
                vocab_size,
            )
        )
        return blacklist_ids

    def get_whitelist_ids(
        self,
        tokenizer,
        vocab_size: int,
        device=None,
        return_tensor: bool = False,
    ) -> List[int] | Int[Tensor, "n_valid"]:
        """Return the valid (non-blacklisted) token ids in ascending order.

        Reuses the memoised blacklist from :meth:`get_blacklist_ids`.

        Args:
            tokenizer: Forwarded to :meth:`get_blacklist_ids`.
            vocab_size: Ids are drawn from ``range(vocab_size)``.
            device: Passed to ``torch.tensor`` when ``return_tensor=True``;
                ignored otherwise. ``None`` leaves device selection to torch.
            return_tensor: If True, return a 1-D ``torch.long`` tensor instead of a list.
        """
        blacklisted = set(self.get_blacklist_ids(tokenizer, vocab_size))
        whitelist_ids = [tid for tid in range(vocab_size) if tid not in blacklisted]
        if return_tensor:
            return torch.tensor(whitelist_ids, dtype=torch.long, device=device)
        return whitelist_ids