"""LiteLLM-backed language-model wrapper (module ``tropt.model.litellm_proxy.lm``)."""

from __future__ import annotations
import logging
from typing import Dict, List, Optional

from tropt.common import ModelOutput
from tropt.model import BaseTokenizer, LMBaseModel, LossTextAccessMixin
from tropt.model.openai.encoder import OpenAITokenizer

logger: logging.Logger = logging.getLogger(__name__)

# Default address used when ``using_litellm_proxy=True`` and no ``base_url`` is given.
_LITELLM_PROXY_URL: str = "http://localhost:4000"

# Number of alternatives requested via ``top_logprobs``; 20 is the cap accepted
# by OpenAI (and by most other providers).
_MAX_TOP_LOGPROBS: int = 20

class LiteLLMModel(LMBaseModel, LossTextAccessMixin):
    """Model wrapper using LiteLLM as a unified interface to LLM providers.

    By default operates in **library mode**: LiteLLM routes requests directly
    to the provider based on the model name prefix (e.g. ``"openai/gpt-4o"``
    -> OpenAI, ``"anthropic/claude-3.5-sonnet"`` -> Anthropic). API keys are
    read from environment variables (``OPENAI_API_KEY``, etc.).

    Alternatively, set ``using_litellm_proxy=True`` and ``base_url`` to connect
    via a running LiteLLM proxy server (``litellm --port 4000``), which is
    mostly used for centralized key management or shared server setups.
    """

    # Retry policy forwarded to every litellm completion call (callers may
    # override either via ``client_kwargs`` or per-call ``kwargs``).
    _N_RETRIES: int = 5
    _RETRY_STRATEGY: str = "exponential_backoff_retry"

    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_concurrent_requests: int = 20,
        using_litellm_proxy: bool = False,
        **client_kwargs,
    ):
        """
        Args:
            model_name: LiteLLM model identifier, e.g. ``"openai/gpt-4o-mini"``.
            base_url: Override the provider URL. Only needed when using a
                LiteLLM proxy or a custom endpoint.
            api_key: API key override (by default read from env vars).
            system_prompt: Optional system prompt to prepend.
            max_concurrent_requests: Maximum number of requests run in
                parallel by ``litellm.batch_completion`` (forwarded as its
                ``max_workers`` argument). Must be >= 1.
            using_litellm_proxy: Set True when connecting to a LiteLLM proxy
                server. This sets ``base_url`` to ``localhost:4000`` (if not
                already provided) and tells LiteLLM the endpoint speaks the
                OpenAI protocol.
            **client_kwargs: Additional arguments for the litellm completion
                call, applied to every request.

        Raises:
            ValueError: If ``max_concurrent_requests`` is smaller than 1.
        """
        if max_concurrent_requests < 1:
            # A thread pool cannot be created with fewer than one worker; fail
            # at construction instead of on the first request.
            raise ValueError(
                f"max_concurrent_requests must be >= 1, got {max_concurrent_requests}"
            )

        self.model_name = model_name
        self._system_prompt = system_prompt
        self._max_concurrent_requests = max_concurrent_requests

        # Expose an OpenAI-compatible tokenizer (via tiktoken) for optimizers
        # that need token-level mutations (RS, GCGPlus, etc.).
        # OpenAI models get their exact tokenizer; others fall back to cl100k_base.
        bare_name = model_name.split("/", 1)[-1]  # "openai/gpt-4o-mini" -> "gpt-4o-mini"
        self._tokenizer = OpenAITokenizer(bare_name)

        self._client_kwargs = client_kwargs
        if using_litellm_proxy:
            self._client_kwargs.setdefault("base_url", base_url or _LITELLM_PROXY_URL)
            # The proxy speaks the OpenAI protocol regardless of the backend model.
            self._client_kwargs["custom_llm_provider"] = "openai"
        elif base_url:
            self._client_kwargs["base_url"] = base_url
        if api_key:
            self._client_kwargs["api_key"] = api_key

    @property
    def tokenizer(self) -> BaseTokenizer:
        """The tiktoken-backed tokenizer built from the bare model name."""
        return self._tokenizer

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by :attr:`tokenizer`."""
        return self._tokenizer.vocab_size

    def invoke_from_texts(
        self,
        input_texts: List[str],
        max_new_tokens: int = 128,
        temperature: float = 0.0,
        message_targets=None,
        require_generation: bool = False,
        require_target_prefill: bool = False,
        require_first_token_logprobs: bool = False,
        **kwargs,
    ) -> ModelOutput:
        """Generate completions for ``input_texts`` via ``litellm.batch_completion``.

        Args:
            input_texts: List of input strings, each sent as one user message.
            max_new_tokens: Maximum number of tokens to generate. Relevant when
                require_generation=True.
            temperature: Sampling temperature.
            message_targets: Accepted for interface compatibility; unused here.
            require_generation: Whether to perform generation. When False,
                only one token is requested to save cost.
            require_target_prefill: Whether to prefill the target response.
                Currently *not supported* in this class and will raise an error.
            require_first_token_logprobs: Whether to return log-probabilities
                for the first generated token. Default is False.
            **kwargs: Additional arguments for the litellm completion call.
                These take precedence over constructor ``client_kwargs``.

        Returns:
            ModelOutput containing the generated response strings and
            optionally the first-token logprobs.

        Raises:
            ValueError: If prefill is requested, or if neither generation nor
                first-token logprobs are requested.
            Exception: The first per-request error returned by litellm.
        """
        if require_target_prefill:
            raise ValueError("Prefill target response is not supported in LiteLLMModel.")
        if not (require_generation or require_first_token_logprobs):
            raise ValueError(
                "At least one of require_generation or require_first_token_logprobs must be True."
            )

        # Build prompts: one chat transcript per input, with a fresh message
        # dict each time so that no dict is shared across requests.
        prompts = [[{"role": "user", "content": text}] for text in input_texts]
        if self._system_prompt:
            for prompt in prompts:
                prompt.insert(0, {"role": "system", "content": self._system_prompt})

        generation_kwargs = {
            # If generation is not needed, a single token is enough (saves tokens).
            "max_tokens": max_new_tokens if require_generation else 1,
            "temperature": temperature,
        }
        if require_first_token_logprobs:
            generation_kwargs["logprobs"] = True
            generation_kwargs["top_logprobs"] = _MAX_TOP_LOGPROBS

        # Merge all call options into one dict. Later sources win: defaults <
        # constructor client_kwargs < generation params < per-call kwargs.
        # This avoids a duplicate-keyword TypeError when two sources set the
        # same key.
        call_kwargs = {
            "num_retries": self._N_RETRIES,
            "retry_strategy": self._RETRY_STRATEGY,
            "max_workers": self._max_concurrent_requests,
            **self._client_kwargs,
            **generation_kwargs,
            **kwargs,
        }

        # Import litellm lazily: optional dependency (`tropt[litellm]`).
        import litellm

        outputs = litellm.batch_completion(
            model=self.model_name,
            messages=prompts,
            **call_kwargs,
        )

        # batch_completion places per-request failures in the result list.
        errors = [o for o in outputs if isinstance(o, Exception)]
        if errors:
            logger.error(
                "%d of %d LiteLLM requests failed; raising the first error.",
                len(errors),
                len(outputs),
            )
            raise errors[0]

        responses: List[str] = _parse_responses_from_outputs(outputs)

        first_token_logprobs: Optional[List[Dict[str, float]]] = None
        if require_first_token_logprobs:
            first_token_logprobs = _parse_first_token_logprobs_from_outputs(outputs)

        # Usage tracking is best-effort: some providers omit usage data.
        try:
            total_tokens = sum(output.usage.total_tokens for output in outputs)
        except Exception as e:
            logger.warning("Failed to compute token usage stats: %s", e)
            total_tokens = 0
        self._update_invoke_stats(
            n_tokens=total_tokens,
            n_samples=len(input_texts),
        )

        return ModelOutput(
            generated_response_strs=responses,
            response_first_token_logprobs=first_token_logprobs,
        )
## Parsing helpers: ##


def _parse_responses_from_outputs(outputs) -> List[str]:
    """Extract generated response strings from LiteLLM/OpenAI response objects.

    The OpenAI-style schema allows ``message.content`` to be ``None`` (e.g. a
    refusal or a response with no text). Such entries become ``""`` instead
    of raising ``AttributeError`` on ``.strip()``.
    """
    return [(output.choices[0].message.content or "").strip() for output in outputs]


def _parse_first_token_logprobs_from_outputs(outputs) -> List[Dict[str, float]]:
    """Extract first-token logprobs from LiteLLM/OpenAI response objects.

    Returns a list of dicts (one per sample) mapping token strings to their
    log-probabilities. Falls back to an empty dict when logprobs (or their
    ``content`` / ``top_logprobs`` fields) are unavailable for a given sample.
    If the same token string appears more than once among the alternatives,
    the highest log-probability is kept.
    """
    result: List[Dict[str, float]] = []
    for output in outputs:
        logprobs_dict: Dict[str, float] = {}
        logprobs_obj = getattr(output.choices[0], "logprobs", None)
        content = getattr(logprobs_obj, "content", None)
        if content:
            # Some providers return top_logprobs=None; treat it as "no alternatives".
            for top_lp in content[0].top_logprobs or ():
                previous = logprobs_dict.get(top_lp.token)
                # Distinct tokens can decode to the same string; a later, lower
                # entry must not overwrite a higher one.
                if previous is None or top_lp.logprob > previous:
                    logprobs_dict[top_lp.token] = top_lp.logprob
        result.append(logprobs_dict)
    return result