# Source code for tropt.recipe_hub.PromptRecovery__wen2023

from __future__ import annotations
"""Prompt Recovery for Image Generation Models.

Based on: "Hard Prompts Made Easy: Gradient-Based Discrete Optimization for
Prompt Tuning and Discovery" (Wen et al., 2023) and "Prompt Recovery for Image
Generation Models: A Comparative Study of Discrete Optimizers" (Williams et al., 2025).

Uses CLIP-like models as a proxy: optimize discrete text tokens to maximize
cosine similarity between the text embedding and a target image embedding.
Evaluation follows the paper's protocol: CLIP similarity (text vs. image)
and Text Embedding Similarity (inverted prompt vs. ground-truth prompt).
"""

from dataclasses import dataclass
from typing import Any, Optional

import torch
from jaxtyping import Float
from torch import Tensor

from tropt.common import Targets
from tropt.loss import SimilarityLoss
from tropt.model.huggingface.clip_encoder import CLIPTextEncoderHFModel
from tropt.optimizer import BeamSearchOptimizer, OptimizerResult
from tropt.optimizer.gcg_optimizer import GCGOptimizer
from tropt.optimizer.gcgplus_optimizer import GCGPlusOptimizer
from tropt.optimizer.pez_optimizer import PEZOptimizer
from tropt.optimizer.utils.token_constraints import TokenConstraints
from tropt.optimizer.utils.token_initializers import get_printable_random_trigger
from tropt.tracker import BaseTracker

# Number of free trigger tokens. The paper sweeps 8-20 free tokens, and
# Wen 2023 Fig. 5 reports 16 as the most generalizable length.
_DEFAULT_TRIGGER_LEN: int = 16

# Vanilla GCG step budget (paper §3.2 / §5.1: 3000 steps at batch size 512).
_GCG_NUM_STEPS: int = 3000

# Step budget for MAC, the momentum-accelerated GCG+ variant.
_MAC_NUM_STEPS: int = 500

# PEZ step budget (Wen et al., 2023): 3000 steps, AdamW with lr=0.1, wd=0.1.
_PEZ_NUM_STEPS: int = 3000


def get_image_embedding_for_clip_model(
    image_path: Optional[str] = None,
    image=None,
    model_name: str = "openai/clip-vit-large-patch14",
) -> Float[Tensor, "1 d_model"]:
    """Encode an image into CLIP's shared embedding space using the vision encoder.

    Loads the full CLIP checkpoint named by `model_name`, runs only its image
    path (`get_image_features`), and returns the projected image embedding.
    The model is moved to CUDA when available, so the returned tensor lives on
    that device.

    Args:
        image_path: Path to an image file (used if `image` is None).
        image: A PIL Image. If None, loads from `image_path`.
        model_name: CLIP model whose vision encoder to use.

    Returns:
        Image embedding tensor of shape (1, d_model).

    Raises:
        ValueError: If neither `image` nor `image_path` is provided.
    """
    # Validate before importing/loading anything heavy so bad calls fail fast.
    if image is None and image_path is None:
        raise ValueError("Either `image` or `image_path` must be provided.")

    from PIL import Image as PILImage
    from transformers import CLIPModel, CLIPProcessor

    if image is None:
        # `convert` decodes the pixels into a new image, so the file handle
        # can be closed as soon as the `with` block exits.
        with PILImage.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

    processor = CLIPProcessor.from_pretrained(model_name)
    clip_model = CLIPModel.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model = clip_model.to(device)

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)

    # NOTE(review): depending on the installed transformers version,
    # `get_image_features` returns either the projected embedding tensor
    # itself or a model-output object carrying it in `pooler_output`.
    # Accept both instead of assuming one of them.
    image_emb = features if isinstance(features, Tensor) else features.pooler_output

    # Clean up: only the embedding is needed from here on.
    del clip_model, processor, features
    return image_emb  # (1, d_model)
def prompt_recovery__wen2023(
    image=None,
    model_name: str = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    template: str = "{{OPTIMIZED_TRIGGER}}",
    initial_trigger: Optional[str] = None,
    optimizer_type: str = "pez",
    trigger_len: int = _DEFAULT_TRIGGER_LEN,
    util_lm_model_name: str = "google/gemma-2-2b-it",
    tracker: Optional[BaseTracker] = None,
    target_image_path: Optional[str] = None,
    target_image_emb: Optional[Float[Tensor, "... d_model"]] = None,
    seed: Optional[int] = None,
) -> OptimizerResult:
    """Recover the prompt that generated a given image using CLIP + a discrete optimizer.

    Args:
        image: A PIL Image to invert. If None, loads from `target_image_path`.
        model_name: CLIP-like model to use as proxy. Default is OpenCLIP
            ViT-H/14 on LAION-2B, matching Wen 2023 §4.1.
        template: Text template with trigger placeholder.
        initial_trigger: Starting trigger string. If None (default), a random
            vocab-embedded trigger of length `trigger_len` is sampled.
        optimizer_type: which discrete optimizer to drive the inversion:
            - `"pez"` (default): PEZ (Wen et al., 2023).
            - `"mac"`: MAC = momentum-accelerated GCG+ (Wang 2024).
            - `"gcg"`: vanilla GCG.
            - `"adv_decoding"`: beam-search decoding with a utility LM.
        trigger_len: Number of trigger tokens.
        util_lm_model_name: HF model id for the utility LM (only for adv_decoding).
        tracker: Optional experiment tracker.
        target_image_path: Path to an image file (used if `image` is None).
        target_image_emb: Pre-computed image embedding (skips encoding), e.g.
            the (1, d_model) output of `get_image_embedding_for_clip_model`.
            Its last dimension must equal the text tower's `d_model`.
        seed: Optional seed forwarded to the optimizer. When `initial_trigger`
            is None, it also seeds Python's global `random` module right before
            the random initial trigger is drawn.

    Returns:
        OptimizerResult with `best_trigger_str` as the recovered prompt.

    Raises:
        ValueError: If `optimizer_type` is unknown, or if `target_image_emb`
            does not match the text tower's embedding dimension.
    """
    # Reject unknown optimizers before any model is loaded or the target
    # image is encoded, so a typo does not cost a full model download.
    if optimizer_type not in ("mac", "gcg", "pez", "adv_decoding"):
        raise ValueError(
            f"Unknown optimizer_type={optimizer_type!r}; "
            f"expected 'mac', 'gcg', 'pez', or 'adv_decoding'."
        )

    model_obj = CLIPTextEncoderHFModel(
        model_name=model_name,
    )
    if target_image_emb is None:
        target_image_emb = get_image_embedding_for_clip_model(
            image_path=target_image_path,
            image=image,
            model_name=model_name,
        )
    elif target_image_emb.shape[-1] != model_obj.d_model:
        # Caller supplied a pre-computed embedding from a different CLIP than
        # the text tower we just built. Without this check the failure surfaces
        # deep inside the PEZ loss as a cryptic tensor-size error.
        raise ValueError(
            f"target_image_emb dim ({target_image_emb.shape[-1]}) does not "
            f"match the text-tower projection dim ({model_obj.d_model}) for "
            f"model_name={model_name!r}. Re-encode the image with the same "
            f"model_name (e.g. pass `model_name={model_name!r}` to "
            f"`get_image_embedding_for_clip_model`)."
        )

    token_constraints = TokenConstraints()

    if initial_trigger is None:
        if seed is not None:
            import random

            # `recover_prompt_end_to_end` re-seeds `random` right before this
            # same initializer because it consumes global `random` state; do
            # the same here so `seed` also makes the random init reproducible.
            random.seed(seed)
        # Wen 2023 Algorithm 1: P ~ E^{|V|} (random vocab-embedded init).
        random_trigger = get_printable_random_trigger(
            trigger_len=trigger_len,
            tokenizer=model_obj.tokenizer,
        )
        assert isinstance(random_trigger, str)
        initial_trigger = random_trigger

    if optimizer_type == "mac":
        # MAC (momentum-accelerated GCG+, Wang 2024) with paper params.
        optimizer = GCGPlusOptimizer(
            model=model_obj,
            loss=SimilarityLoss(),
            proxy_model=model_obj,
            tracker=tracker,
            candidate_selection="gradient",
            num_steps=_MAC_NUM_STEPS,
            sample_topk=256,
            n_candidates=256,
            sample_n_replace=(1, 1),
            momentum=0.6,
            candidate_oversample_factor=1.1,
            token_constraints=token_constraints,
            use_retokenize=True,
            seed=seed,
        )
    elif optimizer_type == "gcg":
        # Vanilla GCG with Williams et al. §3.2 / §5.1 hparams:
        # 3000 steps, batch 512, top-k 256.
        optimizer = GCGOptimizer(
            model=model_obj,
            loss=SimilarityLoss(),
            tracker=tracker,
            num_steps=_GCG_NUM_STEPS,
            n_candidates=512,
            sample_topk=256,
            sample_n_replace=1,
            token_constraints=token_constraints,
            use_retokenize=True,
            seed=seed,
        )
    elif optimizer_type == "pez":
        optimizer = PEZOptimizer(
            model=model_obj,
            loss=SimilarityLoss(),
            tracker=tracker,
            num_steps=_PEZ_NUM_STEPS,
            learning_rate=0.1,
            weight_decay=0.1,
            gd_optimizer=torch.optim.AdamW,
            seed=seed,
        )
    else:  # optimizer_type == "adv_decoding" (validated at the top)
        from tropt.model.huggingface.lm import LMHFModel

        device = "cuda" if torch.cuda.is_available() else "cpu"
        util_lm = LMHFModel(
            model_name=util_lm_model_name,
            device=device,
            use_prefix_cache=False,
            dtype="bfloat16",
        )
        optimizer = BeamSearchOptimizer(
            model=model_obj,
            loss=SimilarityLoss(),
            tracker=tracker,
            seed=seed,
            util_lm=util_lm,
            util_lm_prefix="Write a sentence with a lot of triggers. {{OPTIMIZED_TRIGGER}}",
            num_steps=trigger_len,
            beam_size=96,
            branching_factor=10,
            top_k=10,
            temperature=1.0,
            token_constraints=token_constraints,
        )

    result = optimizer.optimize_trigger(
        templates=[template],
        targets=Targets(target_vectors=target_image_emb),
        initial_trigger=initial_trigger,
    )
    return result
# ======================= Evaluation =======================


@dataclass
class PromptRecoveryEvaluation:
    """Scores produced by `evaluate_prompt_recovery`.

    Attributes:
        clip_similarity: Cosine similarity between the recovered prompt's CLIP
            text embedding and the target image's CLIP image embedding.
        text_embedding_similarity: Cosine similarity between sentence
            embeddings of the recovered and ground-truth prompts; stays None
            when no ground-truth prompt was available.
    """

    clip_similarity: float
    text_embedding_similarity: Optional[float] = None
def evaluate_prompt_recovery(
    inverted_prompt: str,
    image=None,
    image_path: Optional[str] = None,
    original_prompt: Optional[str] = None,
    # Wen 2023 §4.1 evaluates with a held-out OpenCLIP ViT-G to avoid scoring
    # against the same backbone used for optimization.
    clip_model_name: str = "laion/CLIP-ViT-g-14-laion2B-s12B-b42K",
    text_sim_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    clip_text_model_obj: Optional[CLIPTextEncoderHFModel] = None,
) -> PromptRecoveryEvaluation:
    """Evaluate a recovered prompt following the paper's protocol.

    Metrics:
        1. CLIP similarity: cosine similarity between the inverted prompt's
           text embedding and the original image embedding in CLIP space.
        2. Text Embedding Similarity (optional, requires `original_prompt`):
           cosine similarity between sentence embeddings of the inverted and
           original prompts (all-MiniLM-L6-v2 by default).

    Note: The paper also uses FID/KID (image-to-image), which requires a
    text-to-image generation pipeline and is not included here.

    Args:
        inverted_prompt: The recovered prompt to score.
        image: A PIL Image. If None, loads from `image_path`.
        image_path: Path to an image file (used if `image` is None).
        original_prompt: Ground-truth prompt; enables metric 2 when given.
        clip_model_name: CLIP checkpoint for the image tower, and for the text
            tower when `clip_text_model_obj` is None.
        text_sim_model_name: sentence-transformers model used for metric 2.
        clip_text_model_obj: Optional pre-built CLIP text encoder to reuse. It
            must share an embedding space with `clip_model_name`'s image tower.

    Returns:
        PromptRecoveryEvaluation with `clip_similarity` always set and
        `text_embedding_similarity` set only when `original_prompt` is given.

    Raises:
        ValueError: If neither `image` nor `image_path` is provided, or if the
            text and image embeddings have different dimensions.
    """
    if image is None:
        if image_path is None:
            raise ValueError("Either `image` or `image_path` must be provided.")
        from PIL import Image

        # `convert` decodes into a new image, so the file can be closed here.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

    # --- Metric 1: CLIP Similarity (text vs. image) ---
    if clip_text_model_obj is None:
        clip_text_model_obj = CLIPTextEncoderHFModel(
            model_name=clip_model_name,
        )
    image_emb = get_image_embedding_for_clip_model(
        image=image,
        model_name=clip_model_name,
    )
    with torch.no_grad():
        text_emb = clip_text_model_obj.invoke_from_texts(
            [inverted_prompt]
        ).output_embeddings  # (1, d)

    if text_emb.shape[-1] != image_emb.shape[-1]:
        # A reused text encoder from a different CLIP backbone would otherwise
        # fail with an opaque broadcasting error (or silently mis-score).
        raise ValueError(
            f"CLIP text embedding dim ({text_emb.shape[-1]}) does not match "
            f"the image embedding dim ({image_emb.shape[-1]}) from "
            f"clip_model_name={clip_model_name!r}; `clip_text_model_obj` must "
            f"come from the same CLIP model."
        )

    # The image embedding is produced on CUDA whenever it is available, while
    # the text encoder's device/dtype are not controlled here. Compare both in
    # float32 on the text embedding's device to avoid device/dtype mismatches.
    text_emb = text_emb.float()
    image_emb = image_emb.to(device=text_emb.device, dtype=torch.float32)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
    clip_sim = (text_emb * image_emb).sum(dim=-1).item()

    # --- Metric 2: Text Embedding Similarity (optional) ---
    text_sim = None
    if original_prompt is not None:
        from sentence_transformers import SentenceTransformer

        st_model = SentenceTransformer(text_sim_model_name)
        embeddings = st_model.encode(
            [inverted_prompt, original_prompt], convert_to_tensor=True
        )
        inv_emb = embeddings[0]
        orig_emb = embeddings[1]
        inv_emb = inv_emb / inv_emb.norm()
        orig_emb = orig_emb / orig_emb.norm()
        text_sim = (inv_emb * orig_emb).sum().item()

    return PromptRecoveryEvaluation(
        clip_similarity=clip_sim,
        text_embedding_similarity=text_sim,
    )
# ======================= Image Generation =======================
def generate_image_from_prompt(
    prompt: str,
    model_name: str = "black-forest-labs/FLUX.1-dev",
    num_inference_steps: int = 28,
    height: int = 512,
    width: int = 512,
    seed: Optional[int] = None,
):
    """Render one image for `prompt` with a diffusers text-to-image pipeline.

    Checkpoints whose name contains "flux" (case-insensitive) are loaded with
    `FluxPipeline` in bfloat16; anything else goes through
    `StableDiffusionPipeline` in float16. Model CPU offload is enabled in both
    cases.

    Args:
        prompt: Text prompt to generate from.
        model_name: Diffusers model to use.
        num_inference_steps: Number of denoising steps.
        height: Output image height.
        width: Output image width.
        seed: Seed for a CPU `torch.Generator`; None leaves sampling unseeded.

    Returns:
        The first generated PIL Image.
    """
    if "flux" in model_name.lower():
        from diffusers import FluxPipeline as pipeline_cls

        weight_dtype = torch.bfloat16
    else:
        # Stable Diffusion (e.g. sd2-community/stable-diffusion-2-1)
        from diffusers import StableDiffusionPipeline as pipeline_cls

        weight_dtype = torch.float16

    pipe = pipeline_cls.from_pretrained(model_name, torch_dtype=weight_dtype)
    pipe.enable_model_cpu_offload()

    generator = None
    if seed is not None:
        generator = torch.Generator(device="cpu").manual_seed(seed)

    output = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        generator=generator,
    )
    image = output.images[0]
    del pipe
    return image
# ======================= End-to-end =======================


@dataclass
class PromptRecoveryQuadruple:
    """Prompt → image → recovered prompt → regenerated image."""

    original_prompt: str
    original_image: Any  # PIL.Image.Image
    recovered_prompt: str
    recovered_image: Any  # PIL.Image.Image
    best_loss: float  # `best_loss` reported by the optimizer result


def _release_accelerator_memory() -> None:
    """Run the garbage collector and, when CUDA is available, empty its cache."""
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def recover_prompt_end_to_end(
    prompt: str,
    sd_model_name: str = "sd2-community/stable-diffusion-2-1",
    clip_model_name: str = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    optimizer_type: str = "pez",
    trigger_len: int = 20,
    util_lm_model_name: str = "google/gemma-2-2b-it",
    seed: int = 0,
    height: int = 768,
    width: int = 768,
    num_inference_steps: int = 50,
    tracker: Optional[BaseTracker] = None,
    initial_trigger: Optional[str] = None,
) -> PromptRecoveryQuadruple:
    """Generate an image from `prompt`, recover the prompt from the image, regenerate.

    Defaults mirror the exp3-promrec reproduction of Williams et al. (2025):
    SD-2.1 + OpenCLIP H/14 (laion2B), random 20-token init. The recipe owns
    optimizer hyperparameters; callers pick `optimizer_type` ∈
    {"gcg", "mac", "pez", "adv_decoding"}.

    Args:
        prompt: Ground-truth prompt used to generate the original image.
        sd_model_name: Diffusers checkpoint used for both generations.
        clip_model_name: CLIP checkpoint used as the recovery proxy; its
            tokenizer is also used to sample the random initial trigger.
        optimizer_type: Discrete optimizer passed to `prompt_recovery__wen2023`.
        trigger_len: Number of trigger tokens.
        util_lm_model_name: Utility LM id (only used by "adv_decoding").
        seed: Seeds both image generations, the random initial trigger, and
            the optimizer.
        height: Height of both generated images.
        width: Width of both generated images.
        num_inference_steps: Denoising steps for both generations.
        tracker: Optional experiment tracker.
        initial_trigger: Starting trigger; if None, one is sampled randomly.

    Returns:
        PromptRecoveryQuadruple with both prompts, both images, and the best
        loss reached by the optimizer.
    """
    import random as _random

    original_image = generate_image_from_prompt(
        prompt=prompt,
        model_name=sd_model_name,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        seed=seed,
    )
    _release_accelerator_memory()

    if initial_trigger is None:
        from transformers import AutoTokenizer

        # Load the tokenizer BEFORE re-seeding so nothing runs between the
        # seed and the draw. Re-seeding is needed because the SD pipeline init
        # above may have consumed global random state.
        tokenizer = AutoTokenizer.from_pretrained(clip_model_name)
        _random.seed(seed)
        _init = get_printable_random_trigger(
            trigger_len=trigger_len,
            tokenizer=tokenizer,
        )
        assert isinstance(_init, str)
        initial_trigger = _init

    result = prompt_recovery__wen2023(
        image=original_image,
        model_name=clip_model_name,
        initial_trigger=initial_trigger,
        optimizer_type=optimizer_type,
        trigger_len=trigger_len,
        util_lm_model_name=util_lm_model_name,
        tracker=tracker,
        seed=seed,
    )
    recovered_prompt = result.best_trigger_str
    _release_accelerator_memory()

    recovered_image = generate_image_from_prompt(
        prompt=recovered_prompt,
        model_name=sd_model_name,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        seed=seed,
    )

    return PromptRecoveryQuadruple(
        original_prompt=prompt,
        original_image=original_image,
        recovered_prompt=recovered_prompt,
        recovered_image=recovered_image,
        best_loss=float(result.best_loss),
    )