Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/check_missing_api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import sys
from pathlib import Path

from any_llm.constants import ProviderName
from any_llm.exceptions import MissingApiKeyError
from any_llm.provider import ProviderFactory, ProviderName
from any_llm.factory import ProviderFactory

src_path = Path(__file__).parent.parent / "src"
sys.path.insert(0, str(src_path))
Expand Down
2 changes: 1 addition & 1 deletion scripts/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import httpx

from any_llm.provider import ProviderFactory
from any_llm.factory import ProviderFactory
from any_llm.types.provider import ProviderMetadata


Expand Down
2 changes: 1 addition & 1 deletion src/any_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
list_models_async,
responses,
)
from any_llm.constants import ProviderName
from any_llm.exceptions import MissingApiKeyError
from any_llm.provider import ProviderName
from any_llm.tools import callable_to_tool, prepare_tools

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion src/any_llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from pydantic import BaseModel

from any_llm.provider import ClientConfig, ProviderFactory, ProviderName
from any_llm.config import ClientConfig
from any_llm.constants import ProviderName
from any_llm.factory import ProviderFactory
from any_llm.tools import prepare_tools
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, CreateEmbeddingResponse
from any_llm.types.model import Model
Expand Down
11 changes: 11 additions & 0 deletions src/any_llm/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any

from pydantic import BaseModel


class ClientConfig(BaseModel):
"""Configuration for the underlying client used by the provider."""

api_key: str | None = None
api_base: str | None = None
client_args: dict[str, Any] | None = None
56 changes: 56 additions & 0 deletions src/any_llm/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import builtins
from enum import StrEnum

from any_llm.exceptions import UnsupportedProviderError

INSIDE_NOTEBOOK = hasattr(builtins, "__IPYTHON__")


class ProviderName(StrEnum):
"""String enum for supported providers."""

ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
AZURE = "azure"
AZUREOPENAI = "azureopenai"
CEREBRAS = "cerebras"
COHERE = "cohere"
DATABRICKS = "databricks"
DEEPSEEK = "deepseek"
FIREWORKS = "fireworks"
GEMINI = "gemini"
GROQ = "groq"
HUGGINGFACE = "huggingface"
INCEPTION = "inception"
LLAMA = "llama"
LMSTUDIO = "lmstudio"
LLAMAFILE = "llamafile"
LLAMACPP = "llamacpp"
MISTRAL = "mistral"
MOONSHOT = "moonshot"
NEBIUS = "nebius"
OLLAMA = "ollama"
OPENAI = "openai"
OPENROUTER = "openrouter"
PORTKEY = "portkey"
SAMBANOVA = "sambanova"
SAGEMAKER = "sagemaker"
TOGETHER = "together"
VERTEXAI = "vertexai"
VOYAGE = "voyage"
WATSONX = "watsonx"
XAI = "xai"
PERPLEXITY = "perplexity"

@classmethod
def from_string(cls, value: "str | ProviderName") -> "ProviderName":
"""Convert a string to a ProviderName enum."""
if isinstance(value, cls):
return value

formatted_value = value.strip().lower()
try:
return cls(formatted_value)
except ValueError as exc:
supported = [provider.value for provider in cls]
raise UnsupportedProviderError(value, supported) from exc
127 changes: 127 additions & 0 deletions src/any_llm/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import importlib
import warnings
from pathlib import Path

from any_llm.config import ClientConfig
from any_llm.constants import ProviderName
from any_llm.exceptions import UnsupportedProviderError
from any_llm.provider import Provider
from any_llm.types.provider import ProviderMetadata


class ProviderFactory:
"""Factory to dynamically load provider instances based on the naming conventions."""

PROVIDERS_DIR = Path(__file__).parent / "providers"

@classmethod
def create_provider(cls, provider_key: str | ProviderName, config: ClientConfig) -> Provider:
"""Dynamically load and create an instance of a provider based on the naming convention."""
provider_key = ProviderName.from_string(provider_key).value

provider_class_name = f"{provider_key.capitalize()}Provider"
provider_module_name = f"{provider_key}"

module_path = f"any_llm.providers.{provider_module_name}"

try:
module = importlib.import_module(module_path)
except ImportError as e:
msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()"
raise ImportError(msg) from e

provider_class: type[Provider] = getattr(module, provider_class_name)
return provider_class(config=config)

@classmethod
def get_provider_class(cls, provider_key: str | ProviderName) -> type[Provider]:
"""Get the provider class without instantiating it.

Args:
provider_key: The provider key (e.g., 'anthropic', 'openai')

Returns:
The provider class

"""
provider_key = ProviderName.from_string(provider_key).value

provider_class_name = f"{provider_key.capitalize()}Provider"
provider_module_name = f"{provider_key}"

module_path = f"any_llm.providers.{provider_module_name}"

try:
module = importlib.import_module(module_path)
except ImportError as e:
msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()"
raise ImportError(msg) from e

provider_class: type[Provider] = getattr(module, provider_class_name)
return provider_class

@classmethod
def get_supported_providers(cls) -> list[str]:
"""Get a list of supported provider keys."""
return [provider.value for provider in ProviderName]

@classmethod
def get_all_provider_metadata(cls) -> list[ProviderMetadata]:
"""Get metadata for all supported providers.

Returns:
List of dictionaries containing provider metadata

"""
providers: list[ProviderMetadata] = []
for provider_key in cls.get_supported_providers():
provider_class = cls.get_provider_class(provider_key)
metadata = provider_class.get_provider_metadata()
providers.append(metadata)

# Sort providers by name
providers.sort(key=lambda x: x.name)
return providers

@classmethod
def get_provider_enum(cls, provider_key: str) -> ProviderName:
"""Convert a string provider key to a ProviderName enum."""
try:
return ProviderName(provider_key)
except ValueError as e:
supported = [provider.value for provider in ProviderName]
raise UnsupportedProviderError(provider_key, supported) from e

@classmethod
def split_model_provider(cls, model: str) -> tuple[ProviderName, str]:
"""Extract the provider key from the model identifier.

Supports both new format 'provider:model' (e.g., 'mistral:mistral-small')
and legacy format 'provider/model' (e.g., 'mistral/mistral-small').

The legacy format will be deprecated in version 1.0.
"""
colon_index = model.find(":")
slash_index = model.find("/")

# Determine which delimiter comes first
if colon_index != -1 and (slash_index == -1 or colon_index < slash_index):
# The colon came first, so it's using the new syntax.
provider, model_name = model.split(":", 1)
elif slash_index != -1:
# Slash comes first, so it's the legacy syntax
warnings.warn(
f"Model format 'provider/model' is deprecated and will be removed in version 1.0. "
f"Please use 'provider:model' format instead. Got: '{model}'",
DeprecationWarning,
stacklevel=3,
)
provider, model_name = model.split("/", 1)
else:
msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'"
raise ValueError(msg)

if not provider or not model_name:
msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'"
raise ValueError(msg)
return cls.get_provider_enum(provider), model_name
Loading