Source code for tropt.recipe_hub.utils
from __future__ import annotations
"""Shared helpers for Recipe Hub recipes."""
import gc
import logging
import torch
from tropt.model.huggingface.lm import LMHFModel
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
[docs]
def generate_from_model(
    model_name: str,
    prompt: str,
    max_new_tokens: int = 20,
    greedy_decode: bool = False,  # sample the response by default
) -> str:
    """Generate a single response to `prompt` from a freshly loaded model.

    Loads → generates → unloads the model, so it never co-resides with another
    model the caller loads afterwards. The semantics of the output (e.g. using a
    jailbroken model's response as an optimization target) are the caller's.

    The unload step runs in a ``finally`` clause, so the model's memory is
    released even when generation raises.

    Args:
        model_name: Identifier passed to ``LMHFModel`` as ``model_name``.
        prompt: The single input text to generate a response for.
        max_new_tokens: Maximum number of new tokens to generate.
        greedy_decode: If True, decode greedily; otherwise sample (the default).

    Returns:
        The first generated response string.

    Raises:
        RuntimeError: If generation returns no response strings.
    """
    model = LMHFModel(model_name=model_name, use_prefix_cache=False, dtype="bfloat16")
    # Pre-bind so the `finally` clause can always drop it, even if
    # `invoke_from_texts` raises before assigning.
    out = None
    try:
        out = model.invoke_from_texts(
            input_texts=[prompt],
            max_new_tokens=max_new_tokens,
            greedy_decode=greedy_decode,
            require_generation=True,
        )
        responses = out.generated_response_strs
        # Explicit check rather than `assert`, which is stripped under `python -O`;
        # an empty list is rejected too instead of surfacing as an IndexError.
        if not responses:
            raise RuntimeError("Generation must return response strs.")
        response = responses[0]
        logger.info("Generated response from %r: %r", model_name, response)
    finally:
        # Release the wrapper's handle on the underlying model and the wrapper
        # itself, then force collection and return cached CUDA blocks, so a
        # model loaded afterwards does not compete for the same memory. Doing
        # this in `finally` keeps the guarantee on the error path as well.
        del out
        del model._model, model
        gc.collect()
        torch.cuda.empty_cache()
    return response