diff --git a/scripts/check_missing_api_keys.py b/scripts/check_missing_api_keys.py index 4fd2b29e..b3717896 100755 --- a/scripts/check_missing_api_keys.py +++ b/scripts/check_missing_api_keys.py @@ -10,8 +10,9 @@ import sys from pathlib import Path +from any_llm.constants import ProviderName from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory, ProviderName +from any_llm.factory import ProviderFactory src_path = Path(__file__).parent.parent / "src" sys.path.insert(0, str(src_path)) diff --git a/scripts/hooks.py b/scripts/hooks.py index 09d10618..7aaa3c09 100644 --- a/scripts/hooks.py +++ b/scripts/hooks.py @@ -11,7 +11,7 @@ import httpx -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.provider import ProviderMetadata diff --git a/src/any_llm/__init__.py b/src/any_llm/__init__.py index 34615b1c..c2ad39e4 100644 --- a/src/any_llm/__init__.py +++ b/src/any_llm/__init__.py @@ -8,8 +8,8 @@ list_models_async, responses, ) +from any_llm.constants import ProviderName from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderName from any_llm.tools import callable_to_tool, prepare_tools __all__ = [ diff --git a/src/any_llm/api.py b/src/any_llm/api.py index afb03867..66c704ce 100644 --- a/src/any_llm/api.py +++ b/src/any_llm/api.py @@ -3,7 +3,9 @@ from pydantic import BaseModel -from any_llm.provider import ClientConfig, ProviderFactory, ProviderName +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName +from any_llm.factory import ProviderFactory from any_llm.tools import prepare_tools from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, CreateEmbeddingResponse from any_llm.types.model import Model diff --git a/src/any_llm/config.py b/src/any_llm/config.py new file mode 100644 index 00000000..8f5523d7 --- /dev/null +++ b/src/any_llm/config.py @@ -0,0 +1,11 @@ +from typing import Any + +from pydantic import BaseModel + + +class ClientConfig(BaseModel): + """Configuration for the underlying client used by the provider.""" + + api_key: str | None = None + api_base: str | None = None + client_args: dict[str, Any] | None = None diff --git a/src/any_llm/constants.py b/src/any_llm/constants.py new file mode 100644 index 00000000..a9a9f2b5 --- /dev/null +++ b/src/any_llm/constants.py @@ -0,0 +1,56 @@ +import builtins +from enum import StrEnum + +from any_llm.exceptions import UnsupportedProviderError + +INSIDE_NOTEBOOK = hasattr(builtins, "__IPYTHON__") + + +class ProviderName(StrEnum): + """String enum for supported providers.""" + + ANTHROPIC = "anthropic" + BEDROCK = "bedrock" + AZURE = "azure" + AZUREOPENAI = "azureopenai" + CEREBRAS = "cerebras" + COHERE = "cohere" + DATABRICKS = "databricks" + DEEPSEEK = "deepseek" + FIREWORKS = "fireworks" + GEMINI = "gemini" + GROQ = "groq" + HUGGINGFACE = "huggingface" + INCEPTION = "inception" + LLAMA = "llama" + LMSTUDIO = "lmstudio" + LLAMAFILE = "llamafile" + LLAMACPP = "llamacpp" + MISTRAL = "mistral" + MOONSHOT = "moonshot" + NEBIUS = "nebius" + OLLAMA = "ollama" + OPENAI = "openai" + OPENROUTER = "openrouter" + PORTKEY = "portkey" + SAMBANOVA = "sambanova" + SAGEMAKER = "sagemaker" + TOGETHER = "together" + VERTEXAI = "vertexai" + VOYAGE = "voyage" + WATSONX = "watsonx" + XAI = "xai" + PERPLEXITY = "perplexity" + + @classmethod + def from_string(cls, value: "str | ProviderName") -> "ProviderName": + """Convert a string to a ProviderName enum.""" + if isinstance(value, cls): + return value + + formatted_value = value.strip().lower() + try: + return cls(formatted_value) + except ValueError as exc: + supported = [provider.value for provider in cls] + raise UnsupportedProviderError(value, supported) from exc diff --git a/src/any_llm/factory.py b/src/any_llm/factory.py new file mode 100644 index 00000000..5cb2c8c7 --- /dev/null +++ b/src/any_llm/factory.py @@ -0,0 +1,127 @@ +import importlib +import warnings +from pathlib import Path + +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName +from any_llm.exceptions import UnsupportedProviderError +from any_llm.provider import Provider +from any_llm.types.provider import ProviderMetadata + + +class ProviderFactory: + """Factory to dynamically load provider instances based on the naming conventions.""" + + PROVIDERS_DIR = Path(__file__).parent / "providers" + + @classmethod + def create_provider(cls, provider_key: str | ProviderName, config: ClientConfig) -> Provider: + """Dynamically load and create an instance of a provider based on the naming convention.""" + provider_key = ProviderName.from_string(provider_key).value + + provider_class_name = f"{provider_key.capitalize()}Provider" + provider_module_name = f"{provider_key}" + + module_path = f"any_llm.providers.{provider_module_name}" + + try: + module = importlib.import_module(module_path) + except ImportError as e: + msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()" + raise ImportError(msg) from e + + provider_class: type[Provider] = getattr(module, provider_class_name) + return provider_class(config=config) + + @classmethod + def get_provider_class(cls, provider_key: str | ProviderName) -> type[Provider]: + """Get the provider class without instantiating it. + + Args: + provider_key: The provider key (e.g., 'anthropic', 'openai') + + Returns: + The provider class + + """ + provider_key = ProviderName.from_string(provider_key).value + + provider_class_name = f"{provider_key.capitalize()}Provider" + provider_module_name = f"{provider_key}" + + module_path = f"any_llm.providers.{provider_module_name}" + + try: + module = importlib.import_module(module_path) + except ImportError as e: + msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()" + raise ImportError(msg) from e + + provider_class: type[Provider] = getattr(module, provider_class_name) + return provider_class + + @classmethod + def get_supported_providers(cls) -> list[str]: + """Get a list of supported provider keys.""" + return [provider.value for provider in ProviderName] + + @classmethod + def get_all_provider_metadata(cls) -> list[ProviderMetadata]: + """Get metadata for all supported providers. + + Returns: + List of dictionaries containing provider metadata + + """ + providers: list[ProviderMetadata] = [] + for provider_key in cls.get_supported_providers(): + provider_class = cls.get_provider_class(provider_key) + metadata = provider_class.get_provider_metadata() + providers.append(metadata) + + # Sort providers by name + providers.sort(key=lambda x: x.name) + return providers + + @classmethod + def get_provider_enum(cls, provider_key: str) -> ProviderName: + """Convert a string provider key to a ProviderName enum.""" + try: + return ProviderName(provider_key) + except ValueError as e: + supported = [provider.value for provider in ProviderName] + raise UnsupportedProviderError(provider_key, supported) from e + + @classmethod + def split_model_provider(cls, model: str) -> tuple[ProviderName, str]: + """Extract the provider key from the model identifier. + + Supports both new format 'provider:model' (e.g., 'mistral:mistral-small') + and legacy format 'provider/model' (e.g., 'mistral/mistral-small'). + + The legacy format will be deprecated in version 1.0. + """ + colon_index = model.find(":") + slash_index = model.find("/") + + # Determine which delimiter comes first + if colon_index != -1 and (slash_index == -1 or colon_index < slash_index): + # The colon came first, so it's using the new syntax. + provider, model_name = model.split(":", 1) + elif slash_index != -1: + # Slash comes first, so it's the legacy syntax + warnings.warn( + f"Model format 'provider/model' is deprecated and will be removed in version 1.0. " + f"Please use 'provider:model' format instead. Got: '{model}'", + DeprecationWarning, + stacklevel=3, + ) + provider, model_name = model.split("/", 1) + else: + msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'" + raise ValueError(msg) + + if not provider or not model_name: + msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'" + raise ValueError(msg) + return cls.get_provider_enum(provider), model_name diff --git a/src/any_llm/provider.py b/src/any_llm/provider.py index 6f755d6b..de775a88 100644 --- a/src/any_llm/provider.py +++ b/src/any_llm/provider.py @@ -1,19 +1,13 @@ # Inspired by https://github.com/andrewyng/aisuite/tree/main/aisuite import asyncio -import builtins -import importlib -import logging import os -import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator, Sequence -from enum import StrEnum -from pathlib import Path from typing import Any -from pydantic import BaseModel - -from any_llm.exceptions import MissingApiKeyError, UnsupportedProviderError +from any_llm.config import ClientConfig +from any_llm.constants import INSIDE_NOTEBOOK +from any_llm.exceptions import MissingApiKeyError from any_llm.types.completion import ( ChatCompletion, ChatCompletionChunk, @@ -25,68 +19,6 @@ from any_llm.types.responses import Response, ResponseInputParam, ResponseStreamEvent from any_llm.utils.aio import async_iter_to_sync_iter, run_async_in_sync -logger = logging.getLogger(__name__) - - -INSIDE_NOTEBOOK = hasattr(builtins, "__IPYTHON__") - - -class ProviderName(StrEnum): - """String enum for supported providers.""" - - ANTHROPIC = "anthropic" - BEDROCK = "bedrock" - AZURE = "azure" - AZUREOPENAI = "azureopenai" - CEREBRAS = "cerebras" - COHERE = "cohere" - DATABRICKS = "databricks" - DEEPSEEK = "deepseek" - FIREWORKS = "fireworks" - GEMINI = "gemini" - GROQ = "groq" - HUGGINGFACE = "huggingface" - INCEPTION = "inception" - LLAMA = "llama" - LMSTUDIO = "lmstudio" - LLAMAFILE = "llamafile" - LLAMACPP = "llamacpp" - MISTRAL = "mistral" - MOONSHOT = "moonshot" - NEBIUS = "nebius" - OLLAMA = "ollama" - OPENAI = "openai" - OPENROUTER = "openrouter" - PORTKEY = "portkey" - SAMBANOVA = "sambanova" - SAGEMAKER = "sagemaker" - TOGETHER = "together" - VERTEXAI = "vertexai" - VOYAGE = "voyage" - WATSONX = "watsonx" - XAI = "xai" - PERPLEXITY = "perplexity" - - @classmethod - def from_string(cls, value: "str | ProviderName") -> "ProviderName": - if isinstance(value, cls): - return value - - formatted_value = value.strip().lower() - try: - return cls(formatted_value) - except ValueError as exc: - supported = [provider.value for provider in cls] - raise UnsupportedProviderError(value, supported) from exc - - -class ClientConfig(BaseModel): - """Configuration for the underlying client used by the provider.""" - - api_key: str | None = None - api_base: str | None = None - client_args: dict[str, Any] | None = None - class Provider(ABC): """Provider for the LLM.""" @@ -292,119 +224,3 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: async def list_models_async(self, **kwargs: Any) -> Sequence[Model]: return await asyncio.to_thread(self.list_models, **kwargs) - - -class ProviderFactory: - """Factory to dynamically load provider instances based on the naming conventions.""" - - PROVIDERS_DIR = Path(__file__).parent / "providers" - - @classmethod - def create_provider(cls, provider_key: str | ProviderName, config: ClientConfig) -> Provider: - """Dynamically load and create an instance of a provider based on the naming convention.""" - provider_key = ProviderName.from_string(provider_key).value - - provider_class_name = f"{provider_key.capitalize()}Provider" - provider_module_name = f"{provider_key}" - - module_path = f"any_llm.providers.{provider_module_name}" - - try: - module = importlib.import_module(module_path) - except ImportError as e: - msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()" - raise ImportError(msg) from e - - provider_class: type[Provider] = getattr(module, provider_class_name) - return provider_class(config=config) - - @classmethod - def get_provider_class(cls, provider_key: str | ProviderName) -> type[Provider]: - """Get the provider class without instantiating it. - - Args: - provider_key: The provider key (e.g., 'anthropic', 'openai') - - Returns: - The provider class - """ - provider_key = ProviderName.from_string(provider_key).value - - provider_class_name = f"{provider_key.capitalize()}Provider" - provider_module_name = f"{provider_key}" - - module_path = f"any_llm.providers.{provider_module_name}" - - try: - module = importlib.import_module(module_path) - except ImportError as e: - msg = f"Could not import module {module_path}: {e!s}. Please ensure the provider is supported by doing ProviderFactory.get_supported_providers()" - raise ImportError(msg) from e - - provider_class: type[Provider] = getattr(module, provider_class_name) - return provider_class - - @classmethod - def get_supported_providers(cls) -> list[str]: - """Get a list of supported provider keys.""" - return [provider.value for provider in ProviderName] - - @classmethod - def get_all_provider_metadata(cls) -> list[ProviderMetadata]: - """Get metadata for all supported providers. - - Returns: - List of dictionaries containing provider metadata - """ - providers: list[ProviderMetadata] = [] - for provider_key in cls.get_supported_providers(): - provider_class = cls.get_provider_class(provider_key) - metadata = provider_class.get_provider_metadata() - providers.append(metadata) - - # Sort providers by name - providers.sort(key=lambda x: x.name) - return providers - - @classmethod - def get_provider_enum(cls, provider_key: str) -> ProviderName: - """Convert a string provider key to a ProviderName enum.""" - try: - return ProviderName(provider_key) - except ValueError as e: - supported = [provider.value for provider in ProviderName] - raise UnsupportedProviderError(provider_key, supported) from e - - @classmethod - def split_model_provider(cls, model: str) -> tuple[ProviderName, str]: - """Extract the provider key from the model identifier. - - Supports both new format 'provider:model' (e.g., 'mistral:mistral-small') - and legacy format 'provider/model' (e.g., 'mistral/mistral-small'). - - The legacy format will be deprecated in version 1.0. - """ - colon_index = model.find(":") - slash_index = model.find("/") - - # Determine which delimiter comes first - if colon_index != -1 and (slash_index == -1 or colon_index < slash_index): - # The colon came first, so it's using the new syntax. - provider, model_name = model.split(":", 1) - elif slash_index != -1: - # Slash comes first, so it's the legacy syntax - warnings.warn( - f"Model format 'provider/model' is deprecated and will be removed in version 1.0. " - f"Please use 'provider:model' format instead. Got: '{model}'", - DeprecationWarning, - stacklevel=3, - ) - provider, model_name = model.split("/", 1) - else: - msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'" - raise ValueError(msg) - - if not provider or not model_name: - msg = f"Invalid model format. Expected 'provider:model' or 'provider/model', got '{model}'" - raise ValueError(msg) - return cls.get_provider_enum(provider), model_name diff --git a/src/any_llm/providers/bedrock/bedrock.py b/src/any_llm/providers/bedrock/bedrock.py index 1210b5b1..7401a009 100644 --- a/src/any_llm/providers/bedrock/bedrock.py +++ b/src/any_llm/providers/bedrock/bedrock.py @@ -7,9 +7,10 @@ from pydantic import BaseModel +from any_llm.config import ClientConfig from any_llm.exceptions import MissingApiKeyError from any_llm.logging import logger -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response diff --git a/src/any_llm/providers/gemini/base.py b/src/any_llm/providers/gemini/base.py index bff0fddd..517c2f12 100644 --- a/src/any_llm/providers/gemini/base.py +++ b/src/any_llm/providers/gemini/base.py @@ -4,8 +4,9 @@ from pydantic import BaseModel +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider from any_llm.types.completion import ( ChatCompletion, ChatCompletionChunk, diff --git a/src/any_llm/providers/gemini/gemini.py b/src/any_llm/providers/gemini/gemini.py index 7c71ab1e..00d90d55 100644 --- a/src/any_llm/providers/gemini/gemini.py +++ b/src/any_llm/providers/gemini/gemini.py @@ -2,8 +2,8 @@ from google import genai +from any_llm.config import ClientConfig from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ClientConfig from .base import GoogleProvider diff --git a/src/any_llm/providers/llamafile/llamafile.py b/src/any_llm/providers/llamafile/llamafile.py index d848a37c..1fb3d6b9 100644 --- a/src/any_llm/providers/llamafile/llamafile.py +++ b/src/any_llm/providers/llamafile/llamafile.py @@ -2,8 +2,8 @@ from collections.abc import AsyncIterator from typing import Any +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.providers.openai.base import BaseOpenAIProvider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams diff --git a/src/any_llm/providers/lmstudio/lmstudio.py b/src/any_llm/providers/lmstudio/lmstudio.py index b6f5ef6c..f457276d 100644 --- a/src/any_llm/providers/lmstudio/lmstudio.py +++ b/src/any_llm/providers/lmstudio/lmstudio.py @@ -1,6 +1,6 @@ import os -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.openai.base import BaseOpenAIProvider # LM Studio has a python sdk, but per their docs they are compliant with OpenAI spec diff --git a/src/any_llm/providers/ollama/ollama.py b/src/any_llm/providers/ollama/ollama.py index 14007a1a..cbcad18b 100644 --- a/src/any_llm/providers/ollama/ollama.py +++ b/src/any_llm/providers/ollama/ollama.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider MISSING_PACKAGES_ERROR = None try: @@ -27,6 +27,7 @@ from ollama import AsyncClient, Client # noqa: TC004 from ollama import ChatResponse as OllamaChatResponse + from any_llm.config import ClientConfig from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model diff --git a/src/any_llm/providers/perplexity/perplexity.py b/src/any_llm/providers/perplexity/perplexity.py index 096d5739..453e6cf9 100644 --- a/src/any_llm/providers/perplexity/perplexity.py +++ b/src/any_llm/providers/perplexity/perplexity.py @@ -1,6 +1,6 @@ import os -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.openai.base import BaseOpenAIProvider diff --git a/src/any_llm/providers/sagemaker/sagemaker.py b/src/any_llm/providers/sagemaker/sagemaker.py index 711ced2e..87b9023b 100644 --- a/src/any_llm/providers/sagemaker/sagemaker.py +++ b/src/any_llm/providers/sagemaker/sagemaker.py @@ -7,9 +7,10 @@ from pydantic import BaseModel +from any_llm.config import ClientConfig from any_llm.exceptions import MissingApiKeyError from any_llm.logging import logger -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response diff --git a/src/any_llm/providers/vertexai/vertexai.py b/src/any_llm/providers/vertexai/vertexai.py index b46e9f82..6b3849a5 100644 --- a/src/any_llm/providers/vertexai/vertexai.py +++ b/src/any_llm/providers/vertexai/vertexai.py @@ -1,8 +1,8 @@ import os from typing import TYPE_CHECKING +from any_llm.config import ClientConfig from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ClientConfig from any_llm.providers.gemini.base import GoogleProvider if TYPE_CHECKING: diff --git a/src/any_llm/utils/api.py b/src/any_llm/utils/api.py index 9600add8..1005f93f 100644 --- a/src/any_llm/utils/api.py +++ b/src/any_llm/utils/api.py @@ -3,7 +3,10 @@ from pydantic import BaseModel -from any_llm.provider import ClientConfig, Provider, ProviderFactory, ProviderName +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName +from any_llm.factory import ProviderFactory +from any_llm.provider import Provider from any_llm.tools import prepare_tools from any_llm.types.completion import ChatCompletionMessage, CompletionParams diff --git a/tests/conftest.py b/tests/conftest.py index 7960497e..829b6920 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pytest -from any_llm.provider import ProviderName +from any_llm.constants import ProviderName from tests.constants import INCLUDE_LOCAL_PROVIDERS, INCLUDE_NON_LOCAL_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/constants.py b/tests/constants.py index 43c2b80e..e99ea89f 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,6 +1,6 @@ import os -from any_llm.provider import ProviderName +from any_llm.constants import ProviderName LOCAL_PROVIDERS = [ProviderName.LLAMACPP, ProviderName.OLLAMA, ProviderName.LMSTUDIO, ProviderName.LLAMAFILE] diff --git a/tests/docs/test_all.py b/tests/docs/test_all.py index 1466ced2..ffa3b200 100644 --- a/tests/docs/test_all.py +++ b/tests/docs/test_all.py @@ -14,6 +14,6 @@ def test_all_docs(doc_file: pathlib.Path, monkeypatch: Any) -> None: monkeypatch.setenv("MISTRAL_API_KEY", "test_key") with ( - patch("any_llm.provider.ProviderFactory.create_provider"), + patch("any_llm.factory.ProviderFactory.create_provider"), ): check_md_file(fpath=doc_file, memory=True) diff --git a/tests/integration/test_completion.py b/tests/integration/test_completion.py index 1d70432d..6fff4e23 100644 --- a/tests/integration/test_completion.py +++ b/tests/integration/test_completion.py @@ -10,7 +10,7 @@ from any_llm import ProviderName, acompletion from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.completion import ChatCompletion, ChatCompletionMessage from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_embedding.py b/tests/integration/test_embedding.py index 4c0f5194..e22b295e 100644 --- a/tests/integration/test_embedding.py +++ b/tests/integration/test_embedding.py @@ -6,7 +6,7 @@ from any_llm import ProviderName, aembedding from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.completion import CreateEmbeddingResponse from tests.constants import EXPECTED_PROVIDERS diff --git a/tests/integration/test_list_models.py b/tests/integration/test_list_models.py index a31be85b..21464d3b 100644 --- a/tests/integration/test_list_models.py +++ b/tests/integration/test_list_models.py @@ -5,8 +5,9 @@ from openai import APIConnectionError from any_llm import list_models +from any_llm.constants import ProviderName from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory, ProviderName +from any_llm.factory import ProviderFactory from any_llm.types.model import Model from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_reasoning.py b/tests/integration/test_reasoning.py index 73e69270..e2f5f853 100644 --- a/tests/integration/test_reasoning.py +++ b/tests/integration/test_reasoning.py @@ -7,7 +7,7 @@ from any_llm import ProviderName, acompletion from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.completion import ChatCompletion, ChatCompletionChunk from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_response_format.py b/tests/integration/test_response_format.py index 6494acf2..7dbd853c 100644 --- a/tests/integration/test_response_format.py +++ b/tests/integration/test_response_format.py @@ -7,7 +7,7 @@ from any_llm import ProviderName, acompletion from any_llm.exceptions import MissingApiKeyError, UnsupportedParameterError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.completion import ChatCompletion from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_responses.py b/tests/integration/test_responses.py index d3f813df..47e64bee 100644 --- a/tests/integration/test_responses.py +++ b/tests/integration/test_responses.py @@ -6,7 +6,7 @@ from any_llm import ProviderName, aresponses from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.responses import Response from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_streaming.py b/tests/integration/test_streaming.py index baf3cee3..d0a2fd0a 100644 --- a/tests/integration/test_streaming.py +++ b/tests/integration/test_streaming.py @@ -7,7 +7,7 @@ from any_llm import ProviderName, acompletion from any_llm.exceptions import MissingApiKeyError, UnsupportedParameterError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from any_llm.types.completion import ChatCompletionChunk from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS diff --git a/tests/integration/test_tool.py b/tests/integration/test_tool.py index 1a837381..1e499b8c 100644 --- a/tests/integration/test_tool.py +++ b/tests/integration/test_tool.py @@ -7,7 +7,7 @@ from any_llm import ProviderName, acompletion from any_llm.exceptions import MissingApiKeyError -from any_llm.provider import ProviderFactory +from any_llm.factory import ProviderFactory from tests.constants import EXPECTED_PROVIDERS, LOCAL_PROVIDERS if TYPE_CHECKING: diff --git a/tests/unit/providers/test_anthropic_provider.py b/tests/unit/providers/test_anthropic_provider.py index 3a4b16a1..3f19917c 100644 --- a/tests/unit/providers/test_anthropic_provider.py +++ b/tests/unit/providers/test_anthropic_provider.py @@ -4,8 +4,8 @@ import pytest +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.providers.anthropic.anthropic import AnthropicProvider from any_llm.providers.anthropic.utils import DEFAULT_MAX_TOKENS, REASONING_EFFORT_TO_THINKING_BUDGETS from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_aws_provider.py b/tests/unit/providers/test_aws_provider.py index 1164acbe..ebba58bd 100644 --- a/tests/unit/providers/test_aws_provider.py +++ b/tests/unit/providers/test_aws_provider.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from unittest.mock import Mock, patch -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.bedrock import BedrockProvider from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_azure_provider.py b/tests/unit/providers/test_azure_provider.py index dee19bca..cef267ab 100644 --- a/tests/unit/providers/test_azure_provider.py +++ b/tests/unit/providers/test_azure_provider.py @@ -3,7 +3,7 @@ import pytest -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.azure.azure import AzureProvider from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_cerebras_provider.py b/tests/unit/providers/test_cerebras_provider.py index dde36bfb..0ba9a839 100644 --- a/tests/unit/providers/test_cerebras_provider.py +++ b/tests/unit/providers/test_cerebras_provider.py @@ -1,7 +1,7 @@ import pytest +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.providers.cerebras.cerebras import CerebrasProvider diff --git a/tests/unit/providers/test_cohere_provider.py b/tests/unit/providers/test_cohere_provider.py index c849a51a..f09415b5 100644 --- a/tests/unit/providers/test_cohere_provider.py +++ b/tests/unit/providers/test_cohere_provider.py @@ -3,8 +3,8 @@ import pytest from pydantic import BaseModel +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.providers.cohere.utils import _patch_messages from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_google_provider.py b/tests/unit/providers/test_google_provider.py index c909ef87..0d47805b 100644 --- a/tests/unit/providers/test_google_provider.py +++ b/tests/unit/providers/test_google_provider.py @@ -5,8 +5,9 @@ import pytest from google.genai import types +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider from any_llm.providers.gemini import GeminiProvider from any_llm.providers.gemini.base import REASONING_EFFORT_TO_THINKING_BUDGETS from any_llm.providers.gemini.utils import _convert_response_to_response_dict diff --git a/tests/unit/providers/test_groq_provider.py b/tests/unit/providers/test_groq_provider.py index 9e12e322..0e435cf4 100644 --- a/tests/unit/providers/test_groq_provider.py +++ b/tests/unit/providers/test_groq_provider.py @@ -3,8 +3,8 @@ import pytest from pydantic import BaseModel +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_huggingface_provider.py b/tests/unit/providers/test_huggingface_provider.py index 1420242b..e060b21f 100644 --- a/tests/unit/providers/test_huggingface_provider.py +++ b/tests/unit/providers/test_huggingface_provider.py @@ -2,7 +2,7 @@ from typing import Any from unittest.mock import patch -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.huggingface.huggingface import HuggingfaceProvider from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_llamafile.py b/tests/unit/providers/test_llamafile.py index ac32d69b..3f60246f 100644 --- a/tests/unit/providers/test_llamafile.py +++ b/tests/unit/providers/test_llamafile.py @@ -2,8 +2,8 @@ import pytest +from any_llm.config import ClientConfig from any_llm.exceptions import UnsupportedParameterError -from any_llm.provider import ClientConfig from any_llm.providers.llamafile.llamafile import LlamafileProvider from any_llm.providers.openai.base import BaseOpenAIProvider from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_openai_base_provider.py b/tests/unit/providers/test_openai_base_provider.py index f86669cc..0397112d 100644 --- a/tests/unit/providers/test_openai_base_provider.py +++ b/tests/unit/providers/test_openai_base_provider.py @@ -1,6 +1,6 @@ from unittest.mock import MagicMock, patch -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.openai.base import BaseOpenAIProvider from any_llm.types.model import Model diff --git a/tests/unit/providers/test_perplexity_provider.py b/tests/unit/providers/test_perplexity_provider.py index ababc7dc..2eb97d1a 100644 --- a/tests/unit/providers/test_perplexity_provider.py +++ b/tests/unit/providers/test_perplexity_provider.py @@ -1,6 +1,7 @@ import pytest -from any_llm.provider import ClientConfig, ProviderFactory +from any_llm.config import ClientConfig +from any_llm.factory import ProviderFactory from any_llm.providers.perplexity import PerplexityProvider diff --git a/tests/unit/providers/test_watsonx_provider.py b/tests/unit/providers/test_watsonx_provider.py index 8e4b872a..551bde5f 100644 --- a/tests/unit/providers/test_watsonx_provider.py +++ b/tests/unit/providers/test_watsonx_provider.py @@ -5,7 +5,7 @@ import pytest -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.providers.watsonx.watsonx import WatsonxProvider from any_llm.types.completion import CompletionParams diff --git a/tests/unit/providers/test_xai_provider.py b/tests/unit/providers/test_xai_provider.py index 2b94fc7e..4b52b3cf 100644 --- a/tests/unit/providers/test_xai_provider.py +++ b/tests/unit/providers/test_xai_provider.py @@ -4,7 +4,7 @@ import pytest -from any_llm.provider import ClientConfig +from any_llm.config import ClientConfig from any_llm.types.completion import ChatCompletion, CompletionParams diff --git a/tests/unit/test_api_config.py b/tests/unit/test_api_config.py index 7f8f6f03..21f6cbef 100644 --- a/tests/unit/test_api_config.py +++ b/tests/unit/test_api_config.py @@ -2,7 +2,8 @@ from unittest.mock import Mock, patch from any_llm import completion -from any_llm.provider import ClientConfig, ProviderName +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName from any_llm.types.completion import CompletionParams diff --git a/tests/unit/test_completion.py b/tests/unit/test_completion.py index cb6b7211..2ca3225c 100644 --- a/tests/unit/test_completion.py +++ b/tests/unit/test_completion.py @@ -3,7 +3,10 @@ import pytest from any_llm import acompletion -from any_llm.provider import ClientConfig, Provider, ProviderFactory, ProviderName +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName +from any_llm.factory import ProviderFactory +from any_llm.provider import Provider from any_llm.types.completion import ChatCompletionMessage, CompletionParams, Reasoning diff --git a/tests/unit/test_embedding.py b/tests/unit/test_embedding.py index b6a43c26..df650cb4 100644 --- a/tests/unit/test_embedding.py +++ b/tests/unit/test_embedding.py @@ -3,7 +3,8 @@ import pytest from any_llm import aembedding -from any_llm.provider import ProviderFactory, ProviderName +from any_llm.constants import ProviderName +from any_llm.factory import ProviderFactory from any_llm.types.completion import CreateEmbeddingResponse, Embedding, Usage diff --git a/tests/unit/test_model_syntax.py b/tests/unit/test_model_syntax.py index 858cb649..542f831d 100644 --- a/tests/unit/test_model_syntax.py +++ b/tests/unit/test_model_syntax.py @@ -4,8 +4,9 @@ import pytest from any_llm import completion +from any_llm.constants import ProviderName from any_llm.exceptions import UnsupportedProviderError -from any_llm.provider import ProviderFactory, ProviderName +from any_llm.factory import ProviderFactory def test_colon_syntax_valid() -> None: diff --git a/tests/unit/test_provider.py b/tests/unit/test_provider.py index 07247d9c..0a5773c3 100644 --- a/tests/unit/test_provider.py +++ b/tests/unit/test_provider.py @@ -5,8 +5,10 @@ import pytest +from any_llm.config import ClientConfig +from any_llm.constants import ProviderName from any_llm.exceptions import MissingApiKeyError, UnsupportedProviderError -from any_llm.provider import ClientConfig, ProviderFactory, ProviderName +from any_llm.factory import ProviderFactory def test_all_providers_in_enum() -> None: diff --git a/tests/unit/test_provider_pyproject_options.py b/tests/unit/test_provider_pyproject_options.py index f9f1df77..bacc0744 100644 --- a/tests/unit/test_provider_pyproject_options.py +++ b/tests/unit/test_provider_pyproject_options.py @@ -1,7 +1,7 @@ import tomllib from pathlib import Path -from any_llm.provider import ProviderName +from any_llm.constants import ProviderName def test_all_providers_have_pyproject_options() -> None: diff --git a/tests/unit/test_responses.py b/tests/unit/test_responses.py index d00709b6..23cf742b 100644 --- a/tests/unit/test_responses.py +++ b/tests/unit/test_responses.py @@ -3,7 +3,7 @@ import pytest from any_llm import aresponses -from any_llm.provider import ProviderName +from any_llm.constants import ProviderName @pytest.mark.asyncio