Source code for borgllm.borgllm

import os
import yaml
import time
from dotenv import load_dotenv
from typing import Dict, List, Optional, Union
import threading
import logging
from pydantic import (
    BaseModel,
    Field,
    HttpUrl,
    ValidationError,
    model_validator,
    ConfigDict,
)


# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


# Load environment variables from a .env file. This runs as an import-time
# side effect, so it happens before any code below reads os.getenv().
load_dotenv()


# Helper function to parse API keys from various formats
def _parse_api_keys(api_key_input: Union[str, List[str]]) -> List[str]:
    """Parse API keys from various input formats."""
    if isinstance(api_key_input, list):
        return [key.strip() for key in api_key_input if key.strip()]
    elif isinstance(api_key_input, str):
        if "," in api_key_input:
            return [key.strip() for key in api_key_input.split(",") if key.strip()]
        else:
            return [api_key_input.strip()] if api_key_input.strip() else []
    else:
        return []


# Built-in providers, keyed by provider name. Each entry holds:
#   base_url      - root URL of the provider's API endpoint
#   api_key_env   - name of the environment variable that holds the API key;
#                   the same name with an "S" appended (e.g. OPENAI_API_KEYS)
#                   is checked first and may hold several comma-separated keys
#   default_model - model used when the provider is registered from the
#                   environment or requested without an explicit model
#   max_tokens    - value stored as max_tokens on the resulting provider config
# Insertion order matters: when no default provider is configured, the first
# registered provider is used.
BUILTIN_PROVIDERS = {
    "openai": {
        "base_url": "https://api.openai.com/v1",
        "api_key_env": "OPENAI_API_KEY",
        "default_model": "gpt-4o",
        "max_tokens": 4096,
    },
    "anthropic": {
        "base_url": "https://api.anthropic.com/v1",
        "api_key_env": "ANTHROPIC_API_KEY",
        "default_model": "claude-3-5-sonnet-20240620",
        "max_tokens": 4096,
    },
    "openrouter": {
        "base_url": "https://openrouter.ai/api/v1",
        "api_key_env": "OPENROUTER_API_KEY",
        "default_model": "mistralai/mistral-7b-instruct",
        "max_tokens": 4096,
    },
    "togetherai": {
        "base_url": "https://api.together.xyz/v1",
        "api_key_env": "TOGETHER_API_KEY",
        "default_model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "max_tokens": 4096,
    },
    "perplexity": {
        "base_url": "https://api.perplexity.ai",
        "api_key_env": "PERPLEXITY_API_KEY",
        "default_model": "llama-3-sonar-small-32k-online",
        "max_tokens": 32768,
    },
    "mistralai": {
        "base_url": "https://api.mistral.ai/v1",
        "api_key_env": "MISTRAL_API_KEY",
        "default_model": "mistral-large-latest",
        "max_tokens": 32768,
    },
    "fireworks": {
        "base_url": "https://api.fireworks.ai/inference/v1",  # Note: Fireworks specific endpoint often ends with /v1 or /v1/chat/completions
        "api_key_env": "FIREWORKS_API_KEY",
        "default_model": "accounts/fireworks/models/mixtral-8x7b-instruct",
        "max_tokens": 32768,
    },
    "groq": {
        "base_url": "https://api.groq.com/openai/v1",
        "api_key_env": "GROQ_API_KEY",
        "default_model": "llama3-8b-8192",
        "max_tokens": 32768,
    },
    "deepinfra": {
        "base_url": "https://api.deepinfra.com/v1",
        "api_key_env": "DEEPINFRA_API_KEY",
        "default_model": "mistralai/Mistral-7B-Instruct-v0.2",
        "max_tokens": 32768,
    },
    "anyscale": {
        "base_url": "https://api.endpoints.anyscale.com/v1",
        "api_key_env": "ANYSCALE_API_KEY",
        "default_model": "meta-llama/Llama-2-7b-chat-hf",
        "max_tokens": 4096,
    },
    "novita": {
        "base_url": "https://api.novita.ai/v1",  # Based on common OpenAI-compatible patterns, though documentation might specify /v1/chat/completions
        "api_key_env": "NOVITA_API_KEY",
        "default_model": "llama2-7b-chat",
        "max_tokens": 8192,
    },
    "cerebras": {
        "base_url": "https://api.cerebras.ai/v1",
        "api_key_env": "CEREBRAS_API_KEY",
        "default_model": "llama3.1-8b",
        "max_tokens": 2048,
    },
    "featherless": {
        "base_url": "https://api.featherless.ai/v1",
        "api_key_env": "FEATHERLESS_API_KEY",
        "default_model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "max_tokens": 8192,
    },
    "cohere": {
        "base_url": "https://api.cohere.ai/compatibility/v1",
        "api_key_env": "COHERE_API_KEY",
        "default_model": "command-r-plus",
        "max_tokens": 131072,
    },
    "qwen": {
        "base_url": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        "api_key_env": "DASHSCOPE_API_KEY",
        "default_model": "qwen-plus",
        "max_tokens": 32768,
    },
    "deepseek": {
        "base_url": "https://api.deepseek.com/v1",
        "api_key_env": "DEEPSEEK_API_KEY",
        "default_model": "deepseek-chat",
        "max_tokens": 32768,
    },
    "google": {
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
        "api_key_env": "GEMINI_API_KEY",
        "default_model": "gemini-2.5-flash",
        "max_tokens": 32768,
    },
}


class LLMProviderConfig(BaseModel):
    """Settings for a single concrete LLM provider endpoint.

    Besides the declared fields, an instance carries two underscore-prefixed
    attributes that ``__init__`` makes sure exist: ``_api_keys`` (the list of
    keys to rotate through) and ``_current_key_index`` (the position of the
    next key handed out by ``get_next_api_key``).
    """

    name: str
    base_url: HttpUrl
    model: str
    api_key: str
    temperature: float = 0.7
    # Required (no default) and must be strictly positive.
    max_tokens: int = Field(..., gt=0)

    # extra="allow" keeps undeclared input keys instead of rejecting them;
    # validate_api_key below adds an undeclared "_api_keys" key to the input.
    # NOTE(review): how such underscore-prefixed extras surface as attributes
    # depends on pydantic's internals - confirm when upgrading pydantic.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    def __init__(self, **data):
        """Validate the fields, then make sure the key-rotation state exists."""
        super().__init__(**data)
        # Fall back to a one-element list built from api_key when validation
        # did not leave a non-empty _api_keys behind.
        if not hasattr(self, "_api_keys") or not self._api_keys:
            self._api_keys: List[str] = [self.api_key] if self.api_key else []
        if not hasattr(self, "_current_key_index"):
            self._current_key_index: int = 0

    @model_validator(mode="before")
    @classmethod
    def validate_api_key(cls, data):
        """Normalize the ``api_keys`` / ``api_key`` inputs before field validation.

        ``api_keys`` wins when it is present and not ``None``; otherwise
        ``api_key`` is used. The chosen value is parsed with
        ``_parse_api_keys``; the first parsed key becomes ``api_key`` and the
        full list is stored under ``_api_keys``. The ``api_keys`` key is
        always removed afterwards.

        Note that a dict passed in is modified in place. Non-dict input is
        returned untouched.
        """
        if isinstance(data, dict):
            api_keys_value = data.get("api_keys")
            api_key_value = data.get("api_key")

            if api_keys_value is not None:
                api_keys = _parse_api_keys(api_keys_value)
                # If nothing usable was parsed, api_key is left as provided.
                if api_keys:
                    data["api_key"] = api_keys[0]
                    data["_api_keys"] = api_keys
            elif api_key_value is not None:
                api_keys = _parse_api_keys(api_key_value)
                # Keep the raw value when parsing yields nothing, so field
                # validation sees what the caller actually passed.
                data["api_key"] = api_keys[0] if api_keys else api_key_value
                data["_api_keys"] = api_keys

            # "api_keys" is not a declared field; drop it once consumed.
            if "api_keys" in data:
                del data["api_keys"]
        return data

    def get_next_api_key(self) -> str:
        """Return the key at the current position and advance the rotation.

        The returned key is also stored in ``api_key``. The position wraps
        around at the end of the list. With no key list, ``api_key`` is
        returned unchanged.
        """
        if not self._api_keys:
            return self.api_key

        current_key = self._api_keys[self._current_key_index]
        self._current_key_index = (self._current_key_index + 1) % len(self._api_keys)
        self.api_key = current_key
        return current_key

    def set_api_keys(self, api_keys: List[str]):
        """Replace the key list and restart the rotation at its first key.

        An empty (or otherwise falsy) ``api_keys`` is ignored and leaves the
        current keys in place.
        """
        if api_keys:
            self._api_keys = api_keys
            self._current_key_index = 0
            self.api_key = api_keys[0]

    def has_multiple_keys(self) -> bool:
        """Return True when more than one API key is available for rotation."""
        return len(self._api_keys) > 1


# Module-level cache of built-in provider configs, keyed by provider name (the
# keys of BUILTIN_PROVIDERS), so their state is shared across all BorgLLM
# instances.
_GLOBAL_BUILTIN_PROVIDERS: Dict[str, LLMProviderConfig] = {}
# Guards access to _GLOBAL_BUILTIN_PROVIDERS. threading.Lock is not reentrant,
# so code that holds it must not try to acquire it again.
_GLOBAL_BUILTIN_LOCK = threading.Lock()


class VirtualLLMProviderConfig(BaseModel):
    """A named provider that delegates to a list of upstream providers."""

    name: str
    # Each upstream dict is read through its "name" key, which must refer to a
    # configured provider, another virtual provider, or a built-in provider
    # (optionally in "provider:model" form).
    upstreams: List[Dict[str, str]]


class LLMConfig(BaseModel):
    """Top-level LLM configuration: the content of the ``llm`` config section."""

    providers: List[LLMProviderConfig]
    virtual: Optional[List[VirtualLLMProviderConfig]] = None
    # Despite the name, this value is used as the default *provider* name when
    # no default has been set programmatically.
    default_model: Optional[str] = None


class BorgLLM:
    """Singleton registry that resolves provider names to LLM provider configs.

    Providers come from three sources: ``initial_config_data`` passed to the
    constructor, a YAML configuration file, and the built-in providers in
    ``BUILTIN_PROVIDERS`` whose API-key environment variables are set. A name
    can refer to a configured provider, a virtual provider (an ordered list of
    upstreams), or a built-in provider in ``provider:model`` form.

    Providers reported through ``signal_429`` are put on cooldown; ``get``
    either waits for the cooldown to end or refuses the provider.
    """

    _instance = None
    _config_initialized: bool = False
    # Non-reentrant: methods holding this lock must not call other methods
    # that acquire it.
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created the
                # instance between the first check and acquiring the lock.
                if cls._instance is None:
                    cls._instance = super(BorgLLM, cls).__new__(cls)
        return cls._instance

    def __init__(
        self,
        config_path: str = "borg.yaml",
        initial_config_data: Optional[Dict] = None,
        _force_reinitialize: bool = False,
    ):
        """Load the configuration once; later constructions are no-ops.

        Args:
            config_path: Path of the YAML config file. Without a ``.yaml`` /
                ``.yml`` suffix, both suffixes are tried.
            initial_config_data: Optional config dict with an ``"llm"`` key.
                Providers from the config file are merged on top of it.
            _force_reinitialize: Reset all state and load the configuration
                again even if it was loaded before.

        Raises:
            ValidationError: If ``initial_config_data`` fails validation.
        """
        # Make sure every attribute exists even when initialization is skipped.
        if not hasattr(self, "_default_provider_name") or _force_reinitialize:
            self._default_provider_name: Optional[str] = None
        if not hasattr(self, "_config") or _force_reinitialize:
            self._config: Optional[LLMConfig] = None
        if not hasattr(self, "_real_providers") or _force_reinitialize:
            self._real_providers: Dict[str, LLMProviderConfig] = {}
        if not hasattr(self, "_virtual_providers") or _force_reinitialize:
            self._virtual_providers: Dict[str, VirtualLLMProviderConfig] = {}
        if not hasattr(self, "_unusable_providers") or _force_reinitialize:
            self._unusable_providers: Dict[str, float] = {}
        if not hasattr(self, "_virtual_provider_last_index") or _force_reinitialize:
            self._virtual_provider_last_index: Dict[str, int] = {}

        if BorgLLM._config_initialized and not _force_reinitialize:
            logger.debug("BorgLLM already initialized, skipping re-initialization.")
            return

        with self._lock:
            if not BorgLLM._config_initialized or _force_reinitialize:
                self.config_path = config_path
                self._config_paths = self._get_config_paths(config_path)
                self._config = None
                self._real_providers = {}
                self._virtual_providers = {}
                self._default_provider_name = None

                if initial_config_data:
                    interpolated_data = self._interpolate_env_variables(
                        initial_config_data
                    )
                    try:
                        self._config = LLMConfig(**interpolated_data["llm"])
                    except ValidationError as e:
                        logger.error(f"Initial configuration validation error: {e}")
                        raise

                self._load_config()
                self._populate_providers()
                self._add_builtin_providers()

                logger.info("\n--- BorgLLM Configuration Summary ---")
                if self._real_providers:
                    for provider_name, provider_config in self._real_providers.items():
                        masked_key = self._mask_api_key(provider_config.api_key)
                        num_keys = (
                            len(provider_config._api_keys)
                            if provider_config._api_keys
                            else 1
                        )
                        logger.info(
                            f"Provider '{provider_name}': {num_keys} API key(s) loaded (Current/First: {masked_key})"
                        )
                else:
                    logger.info(
                        "No LLM providers configured or found via environment variables."
                    )
                logger.info("-----------------------------------")

                self._virtual_provider_last_index = {}
                self._unusable_providers = {}
                BorgLLM._config_initialized = True

    @classmethod
    def get_instance(
        cls,
        config_path: str = "borg.yaml",
        initial_config_data: Optional[Dict] = None,
    ):
        """Get the singleton BorgLLM instance, initializing it if necessary."""
        if cls._instance is None or not cls._config_initialized:
            cls(
                _force_reinitialize=True,
                config_path=config_path,
                initial_config_data=initial_config_data,
            )
        return cls._instance

    @property
    def config(self) -> Optional[LLMConfig]:
        """Public property to access the configuration."""
        return self._config

    @property
    def providers(self) -> Dict[str, LLMProviderConfig]:
        """Public property to access the real providers."""
        return self._real_providers

    @staticmethod
    def _mask_api_key(api_key: str) -> str:
        """Return a log-safe rendering of *api_key*.

        Keys of eight characters or fewer are hidden completely: showing the
        first and last four characters would reveal the whole key.
        """
        if not api_key:
            return "[NO KEY]"
        if len(api_key) <= 8:
            return "****"
        return f"{api_key[:4]}...{api_key[-4:]}"

    @staticmethod
    def _is_builtin_reference(name: str) -> bool:
        """Return True if *name* is a built-in provider, bare or as ``provider:model``."""
        # Without a colon, split() returns the whole name as its first part.
        return name.split(":", 1)[0] in BUILTIN_PROVIDERS

    @staticmethod
    def _read_builtin_api_keys(api_key_env: str) -> List[str]:
        """Read a built-in provider's API keys from the environment.

        ``<api_key_env>S`` (e.g. ``OPENAI_API_KEYS``) takes precedence when it
        is set to a non-empty value; otherwise ``<api_key_env>`` is used.
        """
        raw_value = os.getenv(api_key_env + "S") or os.getenv(api_key_env)
        return _parse_api_keys(raw_value) if raw_value else []

    def set_default_provider(self, provider_name: str):
        """Set the default LLM provider name for this BorgLLM instance.

        Raises:
            ValueError: If the name is neither a configured provider, a
                virtual provider, nor a built-in provider reference.
        """
        if (
            provider_name not in self._real_providers
            and provider_name not in self._virtual_providers
            and not self._is_builtin_reference(provider_name)
        ):
            raise ValueError(
                f"Provider '{provider_name}' not found. Cannot set as default."
            )
        with self._lock:
            self._default_provider_name = provider_name
        logger.info(
            f"Instance default LLM provider set to '{provider_name}' (overrides any config file default)."
        )

    def _get_config_paths(self, base_path: str) -> List[str]:
        """Return the candidate config file paths for *base_path*."""
        if base_path.endswith((".yaml", ".yml")):
            return [base_path]
        return [f"{base_path}.yaml", f"{base_path}.yml"]

    def _populate_providers(self):
        """Index the loaded config by provider name and validate virtual upstreams.

        Raises:
            ValueError: If a virtual provider references an upstream that is
                not configured, not virtual, and not a built-in reference.
        """
        if not self._config:
            self._config = LLMConfig(providers=[], virtual=[], default_model=None)

        for provider in self._config.providers:
            self._real_providers[provider.name] = provider

        if self._config.virtual:
            for provider in self._config.virtual:
                self._virtual_providers[provider.name] = provider

        if self._config.default_model and not self._default_provider_name:
            self._default_provider_name = self._config.default_model

        # Static check: verify all upstreams in virtual providers exist.
        if self._config.virtual:
            for virtual_provider_name, virtual_config in self._virtual_providers.items():
                for upstream_info in virtual_config.upstreams:
                    upstream_name = upstream_info["name"]
                    if (
                        upstream_name not in self._real_providers
                        and upstream_name not in self._virtual_providers
                        and not self._is_builtin_reference(upstream_name)
                    ):
                        raise ValueError(
                            f"Virtual provider '{virtual_provider_name}' references non-existent upstream '{upstream_name}'."
                        )

    def _add_builtin_providers(self):
        """Register every built-in provider that has API keys in the environment.

        Providers already defined in the configuration are left untouched.
        Registered built-ins are also stored in the module-level cache.
        """
        with _GLOBAL_BUILTIN_LOCK:
            for provider_name, settings in BUILTIN_PROVIDERS.items():
                api_keys_list = self._read_builtin_api_keys(settings["api_key_env"])
                if not api_keys_list or provider_name in self._real_providers:
                    continue
                try:
                    builtin_config = LLMProviderConfig(
                        name=provider_name,
                        base_url=settings["base_url"],
                        model=settings["default_model"],
                        api_key=api_keys_list[0],
                        api_keys=api_keys_list,
                        temperature=settings.get("temperature", 0.7),
                        max_tokens=settings.get("max_tokens", 4096),
                    )
                except ValidationError as e:
                    logger.warning(
                        f"Error validating built-in provider {provider_name}: {e}"
                    )
                    continue
                _GLOBAL_BUILTIN_PROVIDERS[provider_name] = builtin_config
                self._real_providers[provider_name] = builtin_config

    def _interpolate_env_variables(self, data):
        """Recursively replace ``"${VAR}"`` strings with the value of ``VAR``.

        Only strings that consist entirely of one ``${...}`` placeholder are
        replaced; a placeholder whose variable is unset is left as-is.
        """
        if isinstance(data, dict):
            return {k: self._interpolate_env_variables(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._interpolate_env_variables(elem) for elem in data]
        elif isinstance(data, str) and data.startswith("${") and data.endswith("}"):
            env_var_name = data[2:-1]
            return os.getenv(env_var_name, data)
        return data

    def _load_config(self):
        """Load the first usable config file and merge it into ``self._config``.

        When a configuration already exists (from ``initial_config_data``),
        providers and virtual providers from the file override entries with
        the same name. Invalid files are logged and skipped. If nothing could
        be loaded and no configuration exists, an empty one is installed.
        """
        loaded = False
        for path in self._config_paths:
            if not os.path.exists(path):
                continue
            logger.info(f"Loading configuration from {path}")
            with open(path, "r", encoding="utf-8") as f:
                raw_config = yaml.safe_load(f)
            interpolated_config = self._interpolate_env_variables(raw_config)

            try:
                if self._config:
                    file_config_data = interpolated_config.get("llm", {})

                    current_providers = [p.model_dump() for p in self._config.providers]
                    current_virtual = (
                        [v.model_dump() for v in self._config.virtual]
                        if self._config.virtual
                        else []
                    )

                    new_providers = file_config_data.get("providers", [])
                    all_providers = {p["name"]: p for p in current_providers}
                    for p_data in new_providers:
                        try:
                            p = LLMProviderConfig(**p_data)
                            all_providers[p.name] = p.model_dump()
                        except ValidationError as e:
                            logger.warning(
                                f"Skipping invalid provider '{p_data.get('name', 'UNKNOWN')}' from config file: {e}"
                            )

                    new_virtual = file_config_data.get("virtual", [])
                    all_virtual = {v["name"]: v for v in current_virtual}
                    for v in new_virtual:
                        all_virtual[v["name"]] = v

                    combined_config_data = {
                        "providers": list(all_providers.values()),
                        "virtual": list(all_virtual.values()),
                        "default_model": file_config_data.get(
                            "default_model", self._config.default_model
                        ),
                    }
                    self._config = LLMConfig(**combined_config_data)
                else:
                    self._config = LLMConfig(**interpolated_config["llm"])

                loaded = True
                break
            except KeyError:
                logger.warning(f"Configuration file {path} is missing 'llm' key.")
            except ValidationError as e:
                logger.error(f"Configuration validation error for {path}: {e}")
            except Exception as e:
                # Best effort: a broken file must not prevent falling back to
                # the next candidate path or to environment-only setup.
                logger.error(f"Error loading configuration from {path}: {e}")

        self._config_loaded = loaded
        if not loaded and not self._config:
            logger.info(
                f"No configuration file found at {', '.join(self._config_paths)}. Proceeding with environment variables and defaults only."
            )
            self._config = LLMConfig(providers=[], virtual=[], default_model=None)

    def signal_429(self, provider_name: str, duration: int = 300):
        """Put *provider_name* on cooldown for *duration* seconds from now."""
        with self._lock:
            self._unusable_providers[provider_name] = time.time() + duration

    def _is_provider_unusable(self, provider_name: str) -> bool:
        """Return True while *provider_name* is on cooldown.

        An expired cooldown entry is removed as a side effect.
        """
        with self._lock:
            if provider_name in self._unusable_providers:
                cooldown_end = self._unusable_providers[provider_name]
                current_time = time.time()
                if current_time < cooldown_end:
                    return True
                else:
                    del self._unusable_providers[provider_name]
            return False

    def get(
        self,
        name: Optional[str] = None,
        approximate_tokens: Optional[int] = None,
        timeout: Optional[float] = None,
        allow_await_cooldown: bool = True,
    ) -> LLMProviderConfig:
        """Resolve *name* to a provider configuration.

        Args:
            name: A configured provider name, a virtual provider name, or a
                built-in reference in ``provider:model`` form. ``None`` selects
                the programmatic default, then the config default, then the
                first registered provider.
            approximate_tokens: Expected request size; virtual providers skip
                upstreams whose ``max_tokens`` is smaller.
            timeout: Maximum number of seconds to wait for a cooldown to end.
            allow_await_cooldown: Wait for a provider on cooldown instead of
                failing immediately.

        Returns:
            The provider configuration. For providers with several API keys,
            ``api_key`` has been advanced to the next key in the rotation.

        Raises:
            ValueError: If no provider can be determined, the provider is
                unknown, required environment variables are missing, or the
                provider is on cooldown and waiting is not allowed.
            TimeoutError: If the cooldown outlasts *timeout*.
        """
        # Default provider logic
        if name is None:
            if self._default_provider_name:
                name = self._default_provider_name
                logger.info(f"Using programmatically set default provider: {name}")
            elif self._config and self._config.default_model:
                name = self._config.default_model
                logger.info(f"Using default provider from config file: {name}")
            elif self._real_providers:
                name = next(iter(self._real_providers))
                logger.info(f"Using first available provider: {name}")
            else:
                raise ValueError(
                    "No default LLM provider specified and no configuration file found. "
                    "Please specify a provider name in provider:model format, set a default in borg.yaml, or use set_default_provider()."
                )

        provider_key = None
        model_name_for_request = None
        if ":" in name:
            provider_key, model_name_for_request = name.split(":", 1)

        # Built-in provider request ("provider:model")
        if provider_key and provider_key in BUILTIN_PROVIDERS:
            return self._get_builtin_provider(
                provider_key, model_name_for_request, timeout, allow_await_cooldown
            )

        # Configured providers
        if name in self._real_providers:
            if allow_await_cooldown:
                self._await_cooldown(name, timeout=timeout)
            if self._is_provider_unusable(name):
                raise ValueError(
                    f"Provider '{name}' is on cooldown and await_cooldown is false"
                )
            provider = self._real_providers[name]
            if provider.has_multiple_keys():
                provider.get_next_api_key()
            return provider
        elif name in self._virtual_providers:
            return self._get_from_virtual_provider(
                name, approximate_tokens, timeout, allow_await_cooldown
            )
        else:
            # A built-in provider name without a model and without API keys in
            # the environment ends up here.
            if name in BUILTIN_PROVIDERS:
                raise ValueError(
                    f"Provider '{name}' requires model specification. Use format '{name}:model_name' (e.g., '{name}:{BUILTIN_PROVIDERS[name]['default_model']}')"
                )
            raise ValueError(f"LLM provider '{name}' not found")

    def _get_builtin_provider(
        self,
        provider_key: str,
        model_name_for_request: Optional[str],
        timeout: Optional[float],
        allow_await_cooldown: bool,
    ) -> LLMProviderConfig:
        """Resolve a ``provider:model`` request for a built-in provider.

        A provider already in the module-level cache is reused (its model is
        switched to the requested one); otherwise it is created from the
        environment and cached.
        """
        with _GLOBAL_BUILTIN_LOCK:
            provider_instance = _GLOBAL_BUILTIN_PROVIDERS.get(provider_key)
            if (
                provider_instance is not None
                and model_name_for_request
                and provider_instance.model != model_name_for_request
            ):
                provider_instance.model = model_name_for_request

        if provider_instance is not None:
            # Wait for the cooldown outside the global lock, so a cooldown on
            # one provider does not block lookups of every other built-in.
            if allow_await_cooldown:
                self._await_cooldown(provider_key, timeout=timeout)
            if self._is_provider_unusable(provider_key):
                raise ValueError(
                    f"Provider '{provider_key}' is on cooldown and await_cooldown is false"
                )
            with _GLOBAL_BUILTIN_LOCK:
                if provider_instance.has_multiple_keys():
                    provider_instance.get_next_api_key()
            self._real_providers[provider_key] = provider_instance
            return provider_instance

        settings = BUILTIN_PROVIDERS[provider_key]
        api_key_env = settings["api_key_env"]
        api_keys_env = api_key_env + "S"
        api_keys_list = self._read_builtin_api_keys(api_key_env)
        if not api_keys_list:
            # Name both variables only when the plural one is not set at all.
            env_var_names = (
                [api_keys_env, api_key_env]
                if os.getenv(api_keys_env) is None
                else [api_keys_env]
            )
            raise ValueError(
                f"Built-in provider '{provider_key}' requires "
                f"one of the environment variables {env_var_names} to be set."
            )

        if not model_name_for_request:
            model_name_for_request = settings["default_model"]

        with _GLOBAL_BUILTIN_LOCK:
            # Another thread may have registered the provider since the first
            # lookup; keep a single shared instance per provider.
            existing = _GLOBAL_BUILTIN_PROVIDERS.get(provider_key)
            if existing is not None:
                return existing
            builtin_config = LLMProviderConfig(
                name=provider_key,
                base_url=settings["base_url"],
                model=model_name_for_request,
                api_key=api_keys_list[0],
                api_keys=api_keys_list,
                temperature=settings.get("temperature", 0.7),
                max_tokens=settings.get("max_tokens", 4096),
            )
            _GLOBAL_BUILTIN_PROVIDERS[provider_key] = builtin_config
        return builtin_config

    def _await_cooldown(
        self, provider_name: str, interval: float = 1.0, timeout: Optional[float] = None
    ):
        """Block until *provider_name* is no longer on cooldown.

        Args:
            provider_name: Name the cooldown was registered under.
            interval: Unused; kept for backward compatibility.
            timeout: Maximum number of seconds the caller is willing to wait.

        Raises:
            TimeoutError: If the remaining cooldown is longer than *timeout*.
            ValueError: If the provider is still on cooldown after waiting
                (for example because the cooldown was extended meanwhile).
        """
        with self._lock:
            cooldown_end = self._unusable_providers.get(provider_name)
            if cooldown_end is None:
                return
            time_to_wait = cooldown_end - time.time()
            if time_to_wait <= 0:
                # Cooldown already expired: drop the stale entry.
                self._unusable_providers.pop(provider_name, None)
                return
            if timeout is not None and time_to_wait > timeout:
                raise TimeoutError(
                    f"Timeout waiting for provider {provider_name} to exit cooldown"
                )

        # Sleep and re-check WITHOUT holding self._lock: the lock is not
        # reentrant, so calling _is_provider_unusable() while holding it would
        # deadlock, and sleeping under it would stall every other thread.
        logger.info(
            f"Provider '{provider_name}' in cooldown. Waiting "
            f"{time_to_wait:.2f} seconds..."
        )
        time.sleep(time_to_wait)

        # _is_provider_unusable() removes the entry itself once it has expired.
        if self._is_provider_unusable(provider_name):
            raise ValueError(
                f"Provider '{provider_name}' is still in cooldown after waiting."
            )

    def _get_from_virtual_provider(
        self,
        virtual_provider_name: str,
        approximate_tokens: Optional[int],
        timeout: Optional[float],
        allow_await_cooldown: bool,
    ) -> LLMProviderConfig:
        """Return the first usable upstream of a virtual provider.

        Upstreams are tried in their configured order. An upstream is skipped
        when it is on cooldown, cannot be resolved, or has a ``max_tokens``
        smaller than *approximate_tokens*. When nothing is usable and waiting
        is allowed, the method sleeps until the earliest cooldown ends and
        tries again.

        Raises:
            ValueError: If no upstream is eligible and waiting is not allowed
                or there is no cooldown to wait for.
            TimeoutError: If *timeout* elapses while waiting for a cooldown.
        """
        virtual_config = self._virtual_providers[virtual_provider_name]
        start_time = time.time()

        while True:
            min_cooldown_end_time = float("inf")

            for upstream_info in virtual_config.upstreams:
                upstream_name = upstream_info["name"]

                if self._is_provider_unusable(upstream_name):
                    current_cooldown_end = self._unusable_providers.get(
                        upstream_name, float("inf")
                    )
                    min_cooldown_end_time = min(
                        min_cooldown_end_time, current_cooldown_end
                    )
                    continue

                try:
                    resolved_provider = self.get(
                        upstream_name,
                        approximate_tokens,
                        timeout=None,
                        allow_await_cooldown=False,
                    )
                except ValueError as e:
                    logger.debug(
                        f"Skipping upstream '{upstream_name}' of virtual provider '{virtual_provider_name}': {e}"
                    )
                    continue

                if (
                    approximate_tokens is not None
                    and approximate_tokens > resolved_provider.max_tokens
                ):
                    continue

                # get() has already advanced the key rotation for this
                # provider. Rotating again here would skip every other key,
                # and returning right away avoids rotating the keys of
                # upstreams that are not used.
                return resolved_provider

            if not allow_await_cooldown:
                raise ValueError(
                    f"No eligible upstream providers for virtual provider {virtual_provider_name}. All are on cooldown."
                )

            if min_cooldown_end_time == float("inf"):
                raise ValueError(
                    f"No upstreams found for virtual provider {virtual_provider_name} to await."
                )

            time_to_wait = min_cooldown_end_time - time.time()
            if time_to_wait <= 0:
                # The earliest cooldown ended in the meantime; retry at once.
                continue

            if timeout is not None:
                time_elapsed = time.time() - start_time
                remaining_timeout = timeout - time_elapsed
                if remaining_timeout <= 0:
                    # Name the upstream that would have become usable first.
                    earliest_provider = None
                    earliest_time = float("inf")
                    for upstream_info in virtual_config.upstreams:
                        upstream_name = upstream_info["name"]
                        cooldown_end = self._unusable_providers.get(upstream_name)
                        if cooldown_end is not None and cooldown_end < earliest_time:
                            earliest_time = cooldown_end
                            earliest_provider = upstream_name

                    if earliest_provider:
                        raise TimeoutError(
                            f"Timeout waiting for provider {earliest_provider} to exit cooldown."
                        )
                    else:
                        raise ValueError(
                            f"Timeout of {timeout} seconds reached while waiting for usable upstreams for virtual provider {virtual_provider_name}."
                        )

                sleep_duration = min(time_to_wait, remaining_timeout)
                time.sleep(sleep_duration)
            else:
                time.sleep(time_to_wait)