diff --git a/scripts/check_missing_api_keys.py b/scripts/check_missing_api_keys.py new file mode 100755 index 00000000..4fd2b29e --- /dev/null +++ b/scripts/check_missing_api_keys.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +Script to check which providers are being skipped due to missing API keys. + +This script attempts to instantiate each provider and reports which ones +fail due to missing API keys, missing packages, or other issues. +""" + +import os +import sys +from pathlib import Path + +from any_llm.exceptions import MissingApiKeyError +from any_llm.provider import ProviderFactory, ProviderName + +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + + +def check_provider_status(): + """Check the status of all providers and categorize them.""" + available_providers = [] + missing_api_keys = [] + missing_packages = [] + other_errors = [] + + print("Checking provider status...") + print("=" * 50) + + for provider_name in ProviderName: + try: + provider_class = ProviderFactory.get_provider_class(provider_name) + + if provider_class.MISSING_PACKAGES_ERROR is not None: + missing_packages.append( + { + "name": provider_name.value, + "error": str(provider_class.MISSING_PACKAGES_ERROR), + "env_var": provider_class.ENV_API_KEY_NAME, + } + ) + continue + + available_providers.append( + { + "name": provider_name.value, + "env_var": provider_class.ENV_API_KEY_NAME, + "api_key_set": bool(os.getenv(provider_class.ENV_API_KEY_NAME)), + } + ) + + except MissingApiKeyError as e: + missing_api_keys.append({"name": provider_name.value, "env_var": e.env_var_name, "error": str(e)}) + except ImportError as e: + missing_packages.append({"name": provider_name.value, "error": str(e), "env_var": "N/A"}) + except Exception as e: + other_errors.append({"name": provider_name.value, "error": str(e), "error_type": type(e).__name__}) + + return available_providers, missing_api_keys, missing_packages, other_errors + + +def 
print_results(available_providers, missing_api_keys, missing_packages, other_errors): + """Print formatted results of the provider status check.""" + + if available_providers: + print(f"āœ… Available Providers ({len(available_providers)}):") + for provider in sorted(available_providers, key=lambda x: x["name"]): + key_status = "šŸ”‘" if provider["api_key_set"] else "šŸ”“" + print(f" {key_status} {provider['name']} (env: {provider['env_var']})") + print() + + if missing_api_keys: + print(f"šŸ”‘ Missing API Keys ({len(missing_api_keys)}):") + for provider in sorted(missing_api_keys, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} - Set {provider['env_var']}") + print() + + if missing_packages: + print(f"šŸ“¦ Missing Packages ({len(missing_packages)}):") + for provider in sorted(missing_packages, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} - {provider['error']}") + print() + + if other_errors: + print(f"āš ļø Other Errors ({len(other_errors)}):") + for provider in sorted(other_errors, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} ({provider['error_type']}) - {provider['error']}") + print() + + total_providers = len(list(ProviderName)) + available_count = len(available_providers) + missing_keys_count = len(missing_api_keys) + missing_packages_count = len(missing_packages) + other_errors_count = len(other_errors) + + print("šŸ“Š Summary:") + print(f" Total providers: {total_providers}") + print(f" Available: {available_count}") + print(f" Missing API keys: {missing_keys_count}") + print(f" Missing packages: {missing_packages_count}") + print(f" Other errors: {other_errors_count}") + + if missing_api_keys: + print("\nšŸ’” To fix missing API keys, set these environment variables:") + for provider in sorted(missing_api_keys, key=lambda x: x["name"]): + print(f" export {provider['env_var']}='your-api-key-here'") + + +def main(): + """Main function to run the provider status check.""" + if len(sys.argv) > 1 and sys.argv[1] in ["-h", 
"--help"]: + print(__doc__) + print("\nUsage:") + print(" python scripts/check_missing_api_keys.py") + print(" python scripts/check_missing_api_keys.py --help") + return + + try: + available, missing_keys, missing_packages, other_errors = check_provider_status() + print_results(available, missing_keys, missing_packages, other_errors) + except Exception as e: + print(f"Error running provider check: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/any_llm/provider.py b/src/any_llm/provider.py index dbad02aa..7cc562ee 100644 --- a/src/any_llm/provider.py +++ b/src/any_llm/provider.py @@ -153,6 +153,42 @@ def _verify_and_set_api_key(self, config: ClientConfig) -> ClientConfig: raise MissingApiKeyError(self.PROVIDER_NAME, self.ENV_API_KEY_NAME) return config + @staticmethod + @abstractmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + @classmethod def 
get_provider_metadata(cls) -> ProviderMetadata: """Get provider metadata without requiring instantiation. diff --git a/src/any_llm/providers/anthropic/anthropic.py b/src/any_llm/providers/anthropic/anthropic.py index 699cd1b0..dc27900d 100644 --- a/src/any_llm/providers/anthropic/anthropic.py +++ b/src/any_llm/providers/anthropic/anthropic.py @@ -1,8 +1,8 @@ from collections.abc import AsyncIterator, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -18,6 +18,11 @@ except ImportError as e: MISSING_PACKAGES_ERROR = e +if TYPE_CHECKING: + from anthropic.pagination import SyncPage + from anthropic.types import Message + from anthropic.types.model_info import ModelInfo as AnthropicModelInfo + class AnthropicProvider(Provider): """ @@ -40,6 +45,39 @@ class AnthropicProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Anthropic API.""" + return _convert_params(params, **kwargs) + + @staticmethod + def _convert_completion_response(response: "Message") -> ChatCompletion: + """Convert Anthropic Message to OpenAI ChatCompletion format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Anthropic streaming chunk to OpenAI ChatCompletionChunk format.""" + model_id = kwargs.get("model_id", "unknown") + return _create_openai_chunk_from_anthropic_chunk(response, model_id) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + 
"""Anthropic does not support embeddings.""" + msg = "Anthropic does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Anthropic does not support embeddings.""" + msg = "Anthropic does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]: + """Convert Anthropic models list to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: "AsyncAnthropic", **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -48,7 +86,7 @@ async def _stream_completion_async( **kwargs, ) as anthropic_stream: async for event in anthropic_stream: - yield _create_openai_chunk_from_anthropic_chunk(event, kwargs.get("model", "unknown")) + yield self._convert_completion_chunk_response(event, model_id=kwargs.get("model", "unknown")) async def acompletion( self, @@ -63,14 +101,14 @@ async def acompletion( ) kwargs["provider_name"] = self.PROVIDER_NAME - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, **kwargs) if converted_kwargs.pop("stream", False): return self._stream_completion_async(client, **converted_kwargs) message = await client.messages.create(**converted_kwargs) - return _convert_response(message) + return self._convert_completion_response(message) def list_models(self, **kwargs: Any) -> Sequence[Model]: """List available models from Anthropic.""" @@ -80,4 +118,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: **(self.config.client_args if self.config.client_args else {}), ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/azure/azure.py b/src/any_llm/providers/azure/azure.py index 9c171b45..f3938aad 
100644 --- a/src/any_llm/providers/azure/azure.py +++ b/src/any_llm/providers/azure/azure.py @@ -3,7 +3,7 @@ import os from typing import TYPE_CHECKING, Any, cast -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider MISSING_PACKAGES_ERROR = None try: @@ -20,12 +20,13 @@ MISSING_PACKAGES_ERROR = e if TYPE_CHECKING: - from collections.abc import AsyncIterable, AsyncIterator + from collections.abc import AsyncIterable, AsyncIterator, Sequence from azure.ai.inference import aio # noqa: TC004 from azure.ai.inference.models import ChatCompletions, EmbeddingsResult, StreamingChatCompletionsUpdate from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse + from any_llm.types.model import Model class AzureProvider(Provider): @@ -44,10 +45,6 @@ class AzureProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR - def __init__(self, config: ClientConfig) -> None: - """Initialize Azure provider.""" - super().__init__(config) - def _get_endpoint(self) -> str: """Get the Azure endpoint URL.""" if self.config.api_base: @@ -95,7 +92,7 @@ async def _stream_completion_async( ) async for chunk in azure_stream: - yield _create_openai_chunk_from_azure_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -106,16 +103,7 @@ async def acompletion( api_version = os.getenv("AZURE_API_VERSION", kwargs.pop("api_version", "2024-02-15-preview")) client: aio.ChatCompletionsClient = self._create_chat_client_async(api_version) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - - azure_response_format = None - if params.response_format: - azure_response_format = _convert_response_format(params.response_format) - - call_kwargs = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) - if azure_response_format: - call_kwargs["response_format"] = azure_response_format + call_kwargs = 
self._convert_completion_params(params, **kwargs) if params.stream: return self._stream_completion_async( @@ -123,7 +111,6 @@ async def acompletion( params.model_id, params.messages, **call_kwargs, - **kwargs, ) response: ChatCompletions = cast( @@ -132,11 +119,10 @@ async def acompletion( model=params.model_id, messages=params.messages, **call_kwargs, - **kwargs, ), ) - return _convert_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -148,10 +134,59 @@ async def aembedding( api_version = os.getenv("AZURE_API_VERSION", kwargs.pop("api_version", "2024-02-15-preview")) client: aio.EmbeddingsClient = self._create_embeddings_client_async(api_version) + embedding_kwargs = self._convert_embedding_params({}, **kwargs) + response: EmbeddingsResult = await client.embed( model=model, input=inputs if isinstance(inputs, list) else [inputs], - **kwargs, + **embedding_kwargs, ) + return self._convert_embedding_response(response) + + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to Azure AI Inference format.""" + if params.reasoning_effort == "auto": + params.reasoning_effort = None + + azure_response_format = None + if params.response_format: + azure_response_format = _convert_response_format(params.response_format) + + call_kwargs = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) + if azure_response_format: + call_kwargs["response_format"] = azure_response_format + + call_kwargs.update(kwargs) + return call_kwargs + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Azure ChatCompletions response to OpenAI ChatCompletion format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Azure StreamingChatCompletionsUpdate to OpenAI 
ChatCompletionChunk format.""" + return _create_openai_chunk_from_azure_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters to Azure AI Inference format.""" + embedding_kwargs = {} + if isinstance(params, dict): + embedding_kwargs.update(params) + embedding_kwargs.update(kwargs) + return embedding_kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Azure EmbeddingsResult to OpenAI CreateEmbeddingResponse format.""" return _create_openai_embedding_response_from_azure(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Azure list models response to OpenAI format. Not supported by Azure.""" + msg = "Azure provider does not support listing models" + raise NotImplementedError(msg) diff --git a/src/any_llm/providers/bedrock/bedrock.py b/src/any_llm/providers/bedrock/bedrock.py index 9f087860..5536b0fe 100644 --- a/src/any_llm/providers/bedrock/bedrock.py +++ b/src/any_llm/providers/bedrock/bedrock.py @@ -46,6 +46,48 @@ class BedrockProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for AWS API.""" + return _convert_params(params, kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert AWS Bedrock response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert AWS Bedrock chunk response to OpenAI format.""" + model = kwargs.get("model", "") + chunk = _create_openai_chunk_from_aws_chunk(response, model) + if chunk is None: + msg = "Failed to convert AWS chunk to OpenAI format" + raise ValueError(msg) + return chunk + + 
@staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for AWS Bedrock.""" + # For bedrock, we don't need to convert the params, just pass them through + return kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert AWS Bedrock embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_aws( + response["embedding_data"], response["model"], response["total_tokens"] + ) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert AWS Bedrock list models response to OpenAI format.""" + models_list = response.get("modelSummaries", []) + # AWS doesn't provide a creation date for models + # AWS doesn't provide typing, but per https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html + # the modelId is a string and will not be None + return [Model(id=model["modelId"], object="model", created=0, owned_by="aws") for model in models_list] + def __init__(self, config: ClientConfig) -> None: """Initialize AWS Bedrock provider.""" # This intentionally does not call super().__init__(config) because AWS has a different way of handling credentials @@ -104,7 +146,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - completion_kwargs = _convert_params(params, kwargs) + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.response_format: if params.stream: @@ -130,15 +172,13 @@ def completion( ) stream_generator = response_stream["stream"] return ( - chunk - for chunk in ( - _create_openai_chunk_from_aws_chunk(item, model=params.model_id) for item in stream_generator - ) - if chunk is not None + self._convert_completion_chunk_response(item, model=params.model_id) + for item in stream_generator + if _create_openai_chunk_from_aws_chunk(item, model=params.model_id) is 
not None ) response = client.converse(**completion_kwargs) - return _convert_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -194,7 +234,8 @@ def embedding( total_tokens += response_body.get("inputTextTokenCount", 0) - return _create_openai_embedding_response_from_aws(embedding_data, model, total_tokens) + response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens} + return self._convert_embedding_response(response_data) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -206,8 +247,5 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: region_name=self.region_name, **(self.config.client_args if self.config.client_args else {}), ) # type: ignore[no-untyped-call] - models_list = client.list_foundation_models(**kwargs).get("modelSummaries", []) - # AWS doesn't provide a creation date for models - # AWS doesn't provide typing, but per https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html - # the modelId is a string and will not be None - return [Model(id=model["modelId"], object="model", created=0, owned_by="aws") for model in models_list] + response = client.list_foundation_models(**kwargs) + return self._convert_list_models_response(response) diff --git a/src/any_llm/providers/cerebras/cerebras.py b/src/any_llm/providers/cerebras/cerebras.py index 32d20e50..2ede3352 100644 --- a/src/any_llm/providers/cerebras/cerebras.py +++ b/src/any_llm/providers/cerebras/cerebras.py @@ -5,7 +5,7 @@ from any_llm.exceptions import UnsupportedParameterError from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -39,6 +39,46 @@ class 
CerebrasProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Cerebras API.""" + # Cerebras does not support providing reasoning effort + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Cerebras response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Cerebras chunk response to OpenAI format.""" + if isinstance(response, ChatChunkResponse): + return _create_openai_chunk_from_cerebras_chunk(response) + msg = f"Unsupported chunk type: {type(response)}" + raise ValueError(msg) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Cerebras does not support embeddings.""" + msg = "Cerebras does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Cerebras does not support embeddings.""" + msg = "Cerebras does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Cerebras list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model: str, @@ -62,11 +102,7 @@ async def _stream_completion_async( ) async for chunk in cast("cerebras.AsyncStream[ChatCompletion]", cerebras_stream): - if isinstance(chunk, ChatChunkResponse): - yield 
_create_openai_chunk_from_cerebras_chunk(chunk) - else: - msg = f"Unsupported chunk type: {type(chunk)}" - raise ValueError(msg) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -75,22 +111,6 @@ async def acompletion( ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: """Create a chat completion using Cerebras with instructor support for structured outputs.""" - # Cerebras does not support providing reasoning effort - if params.reasoning_effort == "auto": - params.reasoning_effort = None - - if params.stream: - return self._stream_completion_async( - params.model_id, - params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, - ) - - client = cerebras.AsyncCerebras( - api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) - ) - if params.response_format: # See https://inference-docs.cerebras.ai/capabilities/structured-outputs for guide to creating schema if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): @@ -103,11 +123,23 @@ async def acompletion( }, } + completion_kwargs = self._convert_completion_params(params, **kwargs) + + if params.stream: + return self._stream_completion_async( + params.model_id, + params.messages, + **completion_kwargs, + ) + + client = cerebras.AsyncCerebras( + api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) + ) + response = await client.chat.completions.create( model=params.model_id, messages=params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) if hasattr(response, "model_dump"): @@ -116,7 +148,7 @@ async def acompletion( msg = "Streaming responses are not supported in this context" raise ValueError(msg) - return _convert_response(response_data) + return self._convert_completion_response(response_data) def list_models(self, 
**kwargs: Any) -> Sequence[Model]: """ @@ -126,4 +158,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/cohere/cohere.py b/src/any_llm/providers/cohere/cohere.py index 14faa19e..39a23a0c 100644 --- a/src/any_llm/providers/cohere/cohere.py +++ b/src/any_llm/providers/cohere/cohere.py @@ -5,7 +5,7 @@ from any_llm.exceptions import UnsupportedParameterError from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -39,6 +39,47 @@ class CohereProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Cohere API.""" + # Cohere does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any, **kwargs: Any) -> ChatCompletion: + """Convert Cohere response to OpenAI format.""" + # We need the model parameter for conversion + model = kwargs.get("model", getattr(response, "model", "cohere-model")) + return _convert_response(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + 
"""Convert Cohere chunk response to OpenAI format.""" + return _create_openai_chunk_from_cohere_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Cohere.""" + msg = "Cohere does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Cohere embedding response to OpenAI format.""" + msg = "Cohere does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Cohere list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model: str, messages: list[dict[str, Any]], **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -53,7 +94,7 @@ async def _stream_completion_async( ) async for chunk in cohere_stream: - yield _create_openai_chunk_from_cohere_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) @staticmethod def _preprocess_response_format(response_format: type[BaseModel] | dict[str, Any]) -> dict[str, Any]: @@ -74,9 +115,6 @@ def _preprocess_response_format(response_format: type[BaseModel] | dict[str, Any async def acompletion( self, params: CompletionParams, **kwargs: Any ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format is not None: kwargs["response_format"] = self._preprocess_response_format(params.response_format) if params.stream and params.response_format is not None: @@ -86,14 +124,14 @@ async def acompletion( msg = "parallel_tool_calls" raise UnsupportedParameterError(msg, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) patched_messages = _patch_messages(params.messages) if params.stream: return 
self._stream_completion_async( params.model_id, patched_messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) client = cohere.AsyncClientV2( @@ -104,11 +142,10 @@ async def acompletion( response = await client.chat( model=params.model_id, messages=patched_messages, # type: ignore[arg-type] - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream", "response_format"}), - **kwargs, + **completion_kwargs, ) - return _convert_response(response, params.model_id) + return self._convert_completion_response(response, model=params.model_id) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -118,4 +155,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) model_list = client.models.list(**kwargs) - return _convert_models_list(model_list) + return self._convert_list_models_response(model_list) diff --git a/src/any_llm/providers/gemini/base.py b/src/any_llm/providers/gemini/base.py index 1c3a2ce2..3b749b2e 100644 --- a/src/any_llm/providers/gemini/base.py +++ b/src/any_llm/providers/gemini/base.py @@ -1,6 +1,6 @@ from abc import abstractmethod from collections.abc import AsyncIterator, Sequence -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel @@ -57,6 +57,111 @@ class GoogleProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Google API.""" + base_kwargs = params.model_dump( + exclude_none=True, + exclude={ + "model_id", + "messages", + "response_format", + "stream", + "tools", + "tool_choice", + "reasoning_effort", + "max_tokens", + }, + ) + if params.max_tokens is not None: + 
base_kwargs["max_output_tokens"] = params.max_tokens + base_kwargs.update(kwargs) + return base_kwargs + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Google response data to OpenAI ChatCompletion format.""" + # Expect response to be a tuple of (response_dict, model_id) + response_dict, model_id = response + choices_out: list[Choice] = [] + for i, choice_item in enumerate(response_dict.get("choices", [])): + message_dict: dict[str, Any] = choice_item.get("message", {}) + tool_calls: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] | None = None + if message_dict.get("tool_calls"): + tool_calls_list: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] = [] + for tc in message_dict["tool_calls"]: + tool_calls_list.append( + ChatCompletionMessageFunctionToolCall( + id=tc.get("id"), + type="function", + function=Function( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + ) + ) + tool_calls = tool_calls_list + + reasoning_content = message_dict.get("reasoning") + message = ChatCompletionMessage( + role="assistant", + content=message_dict.get("content"), + tool_calls=tool_calls, + reasoning=Reasoning(content=reasoning_content) if reasoning_content else None, + ) + from typing import Literal + + choices_out.append( + Choice( + index=i, + finish_reason=cast( + "Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']", + choice_item.get("finish_reason", "stop"), + ), + message=message, + ) + ) + + usage_dict = response_dict.get("usage", {}) + usage = CompletionUsage( + prompt_tokens=usage_dict.get("prompt_tokens", 0), + completion_tokens=usage_dict.get("completion_tokens", 0), + total_tokens=usage_dict.get("total_tokens", 0), + ) + + return ChatCompletion( + id=response_dict.get("id", ""), + model=model_id, + created=response_dict.get("created", 0), + object="chat.completion", + choices=choices_out, + usage=usage, + ) + + 
@staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Google chunk response to OpenAI format.""" + return _create_openai_chunk_from_google_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Google API.""" + converted_params = {"contents": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Google embedding response to OpenAI format.""" + # We need the model parameter for conversion + model = response.get("model", "google-model") + return _create_openai_embedding_response_from_google(model, response["result"]) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Google list models response to OpenAI format.""" + return _convert_models_list(response) + @abstractmethod def _get_client(self, config: ClientConfig) -> "genai.Client": """Get the appropriate client for this provider implementation.""" @@ -68,13 +173,14 @@ async def aembedding( **kwargs: Any, ) -> CreateEmbeddingResponse: client = self._get_client(self.config) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) result = await client.aio.models.embed_content( model=model, - contents=inputs, # type: ignore[arg-type] - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_google(model, result) + response_data = {"model": model, "result": result} + return self._convert_embedding_response(response_data) async def acompletion( self, @@ -105,24 +211,7 @@ async def acompletion( stream = bool(params.stream) response_format = params.response_format - base_kwargs = params.model_dump( - exclude_none=True, - exclude={ - "model_id", - "messages", - "response_format", - "stream", - "tools", - "tool_choice", - "reasoning_effort", - 
"max_tokens", - }, - ) - - if params.max_tokens is not None: - base_kwargs["max_output_tokens"] = params.max_tokens - - base_kwargs.update(kwargs) + base_kwargs = self._convert_completion_params(params, **kwargs) generation_config = types.GenerateContentConfig(**base_kwargs) if isinstance(response_format, type) and issubclass(response_format, BaseModel): generation_config.response_mime_type = "application/json" @@ -142,7 +231,7 @@ async def acompletion( async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for chunk in response_stream: - yield _create_openai_chunk_from_google_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -153,62 +242,10 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]: ) response_dict = _convert_response_to_response_dict(response) - - choices_out: list[Choice] = [] - for i, choice_item in enumerate(response_dict.get("choices", [])): - message_dict: dict[str, Any] = choice_item.get("message", {}) - tool_calls: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] | None = None - if message_dict.get("tool_calls"): - tool_calls_list: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] = [] - for tc in message_dict["tool_calls"]: - tool_calls_list.append( - ChatCompletionMessageFunctionToolCall( - id=tc.get("id"), - type="function", - function=Function( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) - ) - tool_calls = tool_calls_list - - reasoning_content = message_dict.get("reasoning") - message = ChatCompletionMessage( - role="assistant", - content=message_dict.get("content"), - tool_calls=tool_calls, - reasoning=Reasoning(content=reasoning_content) if reasoning_content else None, - ) - choices_out.append( - Choice( - index=i, - finish_reason=cast( - "Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']", - choice_item.get("finish_reason", "stop"), - ), - message=message, - ) - ) - - 
usage_dict = response_dict.get("usage", {}) - usage = CompletionUsage( - prompt_tokens=usage_dict.get("prompt_tokens", 0), - completion_tokens=usage_dict.get("completion_tokens", 0), - total_tokens=usage_dict.get("total_tokens", 0), - ) - - return ChatCompletion( - id=response_dict.get("id", ""), - model=params.model_id, - created=response_dict.get("created", 0), - object="chat.completion", - choices=choices_out, - usage=usage, - ) + return self._convert_completion_response((response_dict, params.model_id)) def list_models(self, **kwargs: Any) -> Sequence[Model]: """Fetch available models from the /v1/models endpoint.""" client = self._get_client(self.config) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/groq/groq.py b/src/any_llm/providers/groq/groq.py index 222efa7c..27b5665c 100644 --- a/src/any_llm/providers/groq/groq.py +++ b/src/any_llm/providers/groq/groq.py @@ -9,6 +9,9 @@ from any_llm.provider import Provider from any_llm.types.responses import Response, ResponseStreamEvent +if TYPE_CHECKING: + from any_llm.types.completion import CreateEmbeddingResponse + MISSING_PACKAGES_ERROR = None try: import groq @@ -50,6 +53,43 @@ class GroqProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Groq API.""" + # Groq does not support providing reasoning effort + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages"}) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Groq response to OpenAI format.""" + return 
to_chat_completion(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Groq chunk response to OpenAI format.""" + return _create_openai_chunk_from_groq_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Groq.""" + msg = "Groq does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Groq embedding response to OpenAI format.""" + msg = "Groq does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Groq list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_async_completion( self, client: groq.AsyncGroq, params: CompletionParams, **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -57,16 +97,16 @@ async def _stream_async_completion( msg = "stream and response_format" raise UnsupportedParameterError(msg, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) stream: GroqAsyncStream[GroqChatCompletionChunk] = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages"}), - **kwargs, + **completion_kwargs, ) async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for chunk in stream: - yield _create_openai_chunk_from_groq_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -78,9 +118,6 @@ async def acompletion( api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: if 
isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): kwargs["response_format"] = { @@ -93,6 +130,8 @@ async def acompletion( else: kwargs["response_format"] = params.response_format + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return await self._stream_async_completion( client, @@ -102,11 +141,10 @@ async def acompletion( response: GroqChatCompletion = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) - return to_chat_completion(response) + return self._convert_completion_response(response) async def aresponses( self, model: str, input_data: Any, **kwargs: Any @@ -142,4 +180,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = groq.Groq(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/huggingface/huggingface.py b/src/any_llm/providers/huggingface/huggingface.py index 1142ebf4..c62da8e5 100644 --- a/src/any_llm/providers/huggingface/huggingface.py +++ b/src/any_llm/providers/huggingface/huggingface.py @@ -9,6 +9,7 @@ Choice, CompletionParams, CompletionUsage, + CreateEmbeddingResponse, ) from any_llm.types.model import Model @@ -47,6 +48,42 @@ class HuggingfaceProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for HuggingFace API.""" + return _convert_params(params, **kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert HuggingFace 
response to OpenAI format.""" + # If it's already our ChatCompletion type, return it + if isinstance(response, ChatCompletion): + return response + # Otherwise, validate it as our type + return ChatCompletion.model_validate(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert HuggingFace chunk response to OpenAI format.""" + return _create_openai_chunk_from_huggingface_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for HuggingFace.""" + msg = "HuggingFace does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert HuggingFace embedding response to OpenAI format.""" + msg = "HuggingFace does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert HuggingFace list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: "AsyncInferenceClient", @@ -56,7 +93,7 @@ async def _stream_completion_async( response: AsyncIterator[HuggingFaceChatCompletionStreamOutput] = await client.chat_completion(**kwargs) async for chunk in response: - yield _create_openai_chunk_from_huggingface_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -68,7 +105,7 @@ def _stream_completion( **kwargs, ) for chunk in response: - yield _create_openai_chunk_from_huggingface_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -82,7 +119,7 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, 
**kwargs) if params.stream: converted_kwargs["stream"] = True @@ -131,7 +168,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: converted_kwargs["stream"] = True @@ -181,4 +218,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: if kwargs.get("limit") is None: kwargs["limit"] = 20 models_list = client.list_models(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/mistral/mistral.py b/src/any_llm/providers/mistral/mistral.py index ee69c115..8f809f68 100644 --- a/src/any_llm/providers/mistral/mistral.py +++ b/src/any_llm/providers/mistral/mistral.py @@ -48,13 +48,54 @@ class MistralProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Mistral API.""" + # Mistral does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Mistral response to OpenAI format.""" + # We need the model parameter for conversion + model = getattr(response, "model", "mistral-model") + return _create_mistral_completion_from_response(response_data=response, model=model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Mistral chunk response to OpenAI format.""" + return 
_create_openai_chunk_from_mistral_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Mistral.""" + converted_params = {"inputs": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Mistral embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_mistral(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Mistral list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: Mistral, model: str, messages: list[dict[str, Any]], **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: mistral_stream = await client.chat.stream_async(model=model, messages=messages, **kwargs) # type: ignore[arg-type] async for event in mistral_stream: - yield _create_openai_chunk_from_mistral_chunk(event) + yield self._convert_completion_chunk_response(event) async def acompletion( self, params: CompletionParams, **kwargs: Any @@ -77,26 +118,23 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion_async( client, params.model_id, patched_messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) response = await client.chat.complete_async( model=params.model_id, messages=patched_messages, # type: ignore[arg-type] - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) - return _create_mistral_completion_from_response( - response_data=response, - model=params.model_id, - ) + 
return self._convert_completion_response(response) async def aembedding( self, @@ -109,13 +147,13 @@ async def aembedding( server_url=self.config.api_base, **(self.config.client_args if self.config.client_args else {}), ) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) result: EmbeddingResponse = await client.embeddings.create_async( model=model, - inputs=inputs, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_mistral(result) + return self._convert_embedding_response(result) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -127,4 +165,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: **(self.config.client_args if self.config.client_args else {}), ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/ollama/ollama.py b/src/any_llm/providers/ollama/ollama.py index b8e544f0..fd0f5ad8 100644 --- a/src/any_llm/providers/ollama/ollama.py +++ b/src/any_llm/providers/ollama/ollama.py @@ -53,6 +53,46 @@ class OllamaProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Ollama API.""" + # Ollama does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + converted_params["num_ctx"] = converted_params.get("num_ctx", 32000) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Ollama response to OpenAI format.""" + return _create_chat_completion_from_ollama_response(response) + + @staticmethod 
+ def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Ollama chunk response to OpenAI format.""" + return _create_openai_chunk_from_ollama_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Ollama.""" + converted_params = {"input": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Ollama embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_ollama(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Ollama list models response to OpenAI format.""" + return _convert_models_list(response) + def __init__(self, config: ClientConfig) -> None: """We don't use the Provider init because by default we don't require an API key.""" self._verify_no_missing_packages() @@ -106,7 +146,7 @@ async def _stream_completion_async( options=kwargs, ) async for chunk in response: - yield _create_openai_chunk_from_ollama_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -152,33 +192,25 @@ async def acompletion( cleaned_messages.append(cleaned_message) - if params.reasoning_effort == "auto": - params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) - kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - - kwargs["num_ctx"] = kwargs.get("num_ctx", 32000) - - if params.reasoning_effort is not None: - kwargs["think"] = True + if completion_kwargs.get("reasoning_effort") is not None: + completion_kwargs["think"] = True client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {})) if params.stream: - return 
self._stream_completion_async(client, params.model_id, cleaned_messages, **kwargs) + return self._stream_completion_async(client, params.model_id, cleaned_messages, **completion_kwargs) response: OllamaChatResponse = await client.chat( model=params.model_id, - tools=kwargs.pop("tools", None), - think=kwargs.pop("think", None), + tools=completion_kwargs.pop("tools", None), + think=completion_kwargs.pop("think", None), messages=cleaned_messages, format=output_format, - options=kwargs, + options=completion_kwargs, ) - return _create_chat_completion_from_ollama_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -189,12 +221,12 @@ async def aembedding( """Generate embeddings using Ollama.""" client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {})) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) response = await client.embed( model=model, - input=inputs, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_ollama(response) + return self._convert_embedding_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -202,4 +234,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = Client(host=self.url, **(self.config.client_args if self.config.client_args else {})) models_list = client.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/openai/base.py b/src/any_llm/providers/openai/base.py index 1a799a34..a661ca50 100644 --- a/src/any_llm/providers/openai/base.py +++ b/src/any_llm/providers/openai/base.py @@ -38,6 +38,68 @@ class BaseOpenAIProvider(Provider, ABC): _DEFAULT_REASONING_EFFORT: Literal["minimal", "low", "medium", "high", "auto"] | None = None + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert 
CompletionParams to kwargs for OpenAI API.""" + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages"}) + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert OpenAI response to OpenAI format (passthrough).""" + if isinstance(response, OpenAIChatCompletion): + return _convert_chat_completion(response) + # If it's already our ChatCompletion type, return it + if isinstance(response, ChatCompletion): + return response + # Otherwise, validate it as our type + return ChatCompletion.model_validate(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert OpenAI chunk response to OpenAI format (passthrough).""" + if isinstance(response, OpenAIChatCompletionChunk): + if not isinstance(response.created, int): + logger.warning( + "API returned an unexpected created type: %s. Setting to int.", + type(response.created), + ) + response.created = int(response.created) + normalized_chunk = _normalize_openai_dict_response(response.model_dump()) + return ChatCompletionChunk.model_validate(normalized_chunk) + # If it's already our ChatCompletionChunk type, return it + if isinstance(response, ChatCompletionChunk): + return response + # Otherwise, validate it as our type + return ChatCompletionChunk.model_validate(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for OpenAI API.""" + converted_params = {"input": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert OpenAI embedding response to OpenAI format (passthrough).""" + if isinstance(response, CreateEmbeddingResponse): + return response + return CreateEmbeddingResponse.model_validate(response) + + 
@staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert OpenAI list models response to OpenAI format (passthrough).""" + if hasattr(response, "data"): + # Validate each model in the data + return [Model.model_validate(model) if not isinstance(model, Model) else model for model in response.data] + # If it's already a sequence of our Model type, return it + if isinstance(response, (list, tuple)) and all(isinstance(item, Model) for item in response): + return response + # Otherwise, validate each item + return [Model.model_validate(item) if not isinstance(item, Model) else item for item in response] + def _get_client(self, sync: bool = False) -> AsyncOpenAI | OpenAI: _client_class = OpenAI if sync else AsyncOpenAI return _client_class( @@ -51,19 +113,11 @@ def _convert_completion_response_async( ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: """Convert an OpenAI completion response to an AnyLLM completion response.""" if isinstance(response, OpenAIChatCompletion): - return _convert_chat_completion(response) + return self._convert_completion_response(response) async def chunk_iterator() -> AsyncIterator[ChatCompletionChunk]: async for chunk in response: - if not isinstance(chunk.created, int): - logger.warning( - "API returned an unexpected created type: %s. 
Setting to int.", - type(chunk.created), - ) - chunk.created = int(chunk.created) - - normalized_chunk = _normalize_openai_dict_response(chunk.model_dump()) - yield ChatCompletionChunk.model_validate(normalized_chunk) + yield self._convert_completion_chunk_response(chunk) return chunk_iterator() @@ -75,23 +129,23 @@ async def acompletion( if params.reasoning_effort == "auto": params.reasoning_effort = self._DEFAULT_REASONING_EFFORT + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.response_format: if params.stream: msg = "stream is not supported for response_format" raise ValueError(msg) - + completion_kwargs.pop("stream") response = await client.chat.completions.parse( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) else: response = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages"}), - **kwargs, + **completion_kwargs, ) return self._convert_completion_response_async(response) @@ -124,11 +178,13 @@ async def aembedding( client = cast("AsyncOpenAI", self._get_client()) - return await client.embeddings.create( - model=model, - input=inputs, - dimensions=kwargs.get("dimensions", NOT_GIVEN), - **kwargs, + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) + return self._convert_embedding_response( + await client.embeddings.create( + model=model, + dimensions=kwargs.get("dimensions", NOT_GIVEN), + **embedding_kwargs, + ) ) def list_models(self, **kwargs: Any) -> Sequence[Model]: @@ -139,4 +195,5 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: message = f"{self.PROVIDER_NAME} does not support listing models." 
raise NotImplementedError(message) client = cast("OpenAI", self._get_client(sync=True)) - return client.models.list(**kwargs).data + response = client.models.list(**kwargs) + return self._convert_list_models_response(response) diff --git a/src/any_llm/providers/sagemaker/sagemaker.py b/src/any_llm/providers/sagemaker/sagemaker.py index 06e98ad7..472e0fa4 100644 --- a/src/any_llm/providers/sagemaker/sagemaker.py +++ b/src/any_llm/providers/sagemaker/sagemaker.py @@ -2,7 +2,7 @@ import functools import json import os -from collections.abc import AsyncIterator, Callable, Iterator +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from typing import Any from pydantic import BaseModel @@ -11,6 +11,7 @@ from any_llm.logging import logger from any_llm.provider import ClientConfig, Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse +from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response MISSING_PACKAGES_ERROR = None @@ -44,6 +45,44 @@ class SagemakerProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for SageMaker API.""" + return _convert_params(params, kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert SageMaker response to OpenAI format.""" + model = response.get("model", "") + return _convert_response(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert SageMaker chunk response to OpenAI format.""" + model = kwargs.get("model", "") + chunk = _create_openai_chunk_from_sagemaker_chunk(response, model) + if chunk is None: + msg = "Failed to convert SageMaker chunk to OpenAI format" + raise ValueError(msg) + return chunk + + 
@staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for SageMaker.""" + return kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert SageMaker embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_sagemaker( + response["embedding_data"], response["model"], response["total_tokens"] + ) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert SageMaker list models response to OpenAI format.""" + return [] + def __init__(self, config: ClientConfig) -> None: """Initialize AWS SageMaker provider.""" logger.warning( @@ -104,7 +143,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - completion_kwargs = _convert_params(params, kwargs) + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.response_format: if params.stream: @@ -128,7 +167,7 @@ def completion( return _convert_instructor_response(structured_response, params.model_id, "aws") except (ValueError, TypeError) as e: logger.warning("Failed to parse structured response: %s", e) - return _convert_response(response_body, params.model_id) + return self._convert_completion_response({"model": params.model_id, **response_body}) if params.stream: response = client.invoke_endpoint_with_response_stream( @@ -139,11 +178,9 @@ def completion( event_stream = response["Body"] return ( - chunk - for chunk in ( - _create_openai_chunk_from_sagemaker_chunk(event, model=params.model_id) for event in event_stream - ) - if chunk is not None + self._convert_completion_chunk_response(event, model=params.model_id) + for event in event_stream + if _create_openai_chunk_from_sagemaker_chunk(event, model=params.model_id) is not None ) response = client.invoke_endpoint( @@ -153,7 +190,7 @@ def completion( ) response_body = json.loads(response["Body"].read()) - return 
_convert_response(response_body, params.model_id) + return self._convert_completion_response({"model": params.model_id, **response_body}) async def aembedding( self, @@ -222,4 +259,5 @@ def embedding( embedding_data.append({"embedding": embedding, "index": index}) total_tokens += response_body.get("usage", {}).get("prompt_tokens", len(text.split())) - return _create_openai_embedding_response_from_sagemaker(embedding_data, model, total_tokens) + response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens} + return self._convert_embedding_response(response_data) diff --git a/src/any_llm/providers/together/together.py b/src/any_llm/providers/together/together.py index 86c2b85d..b4c9baaa 100644 --- a/src/any_llm/providers/together/together.py +++ b/src/any_llm/providers/together/together.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterator, Sequence from typing import TYPE_CHECKING, Any, cast from any_llm.provider import Provider @@ -6,7 +6,9 @@ ChatCompletion, ChatCompletionChunk, CompletionParams, + CreateEmbeddingResponse, ) +from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response MISSING_PACKAGES_ERROR = None @@ -43,6 +45,46 @@ class TogetherProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Together API.""" + # Together does not support providing reasoning effort + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Together response 
to OpenAI format.""" + # We need the model parameter for conversion + model = response.get("model", "together-model") + return _convert_together_response_to_chat_completion(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Together chunk response to OpenAI format.""" + return _create_openai_chunk_from_together_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Together.""" + msg = "Together does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Together embedding response to OpenAI format.""" + msg = "Together does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Together list models response to OpenAI format.""" + msg = "Together does not support listing models" + raise NotImplementedError(msg) + async def _stream_completion_async( self, client: "together.AsyncTogether", @@ -62,7 +104,7 @@ async def _stream_completion_async( ), ) async for chunk in response: - yield _create_openai_chunk_from_together_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -83,7 +125,7 @@ def _stream_completion( ), ) for chunk in response: - yield _create_openai_chunk_from_together_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def completion( self, @@ -97,9 +139,6 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: instructor_client = instructor.patch(client, mode=instructor.Mode.JSON) # type: ignore [call-overload] @@ -113,13 +152,14 @@ def completion( return 
_convert_instructor_response(instructor_response, params.model_id, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion( client, params.model_id, params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) response = cast( @@ -127,12 +167,11 @@ def completion( client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}), - **kwargs, + **completion_kwargs, ), ) - return _convert_together_response_to_chat_completion(response.model_dump(), params.model_id) + return self._convert_completion_response(response.model_dump()) async def acompletion( self, @@ -146,9 +185,6 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: instructor_client = instructor.patch(client, mode=instructor.Mode.JSON) # type: ignore [call-overload] @@ -165,13 +201,14 @@ async def acompletion( return _convert_instructor_response(instructor_response, params.model_id, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion_async( client, params.model_id, params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "reasoning_effort", "stream"}), - **kwargs, + **completion_kwargs, ) response = cast( @@ -179,9 +216,8 @@ async def acompletion( await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}), - **kwargs, + **completion_kwargs, ), ) - return _convert_together_response_to_chat_completion(response.model_dump(), 
params.model_id) + return self._convert_completion_response(response.model_dump()) diff --git a/src/any_llm/providers/voyage/voyage.py b/src/any_llm/providers/voyage/voyage.py index 9dc62e28..e0c64658 100644 --- a/src/any_llm/providers/voyage/voyage.py +++ b/src/any_llm/providers/voyage/voyage.py @@ -1,8 +1,9 @@ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from typing import Any from any_llm.provider import Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse +from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None try: @@ -34,24 +35,62 @@ class VoyageProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Voyage API.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Voyage response to OpenAI format.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Voyage chunk response to OpenAI format.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Voyage.""" + if isinstance(params, str): + params = [params] + converted_params = {"texts": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Voyage embedding response to OpenAI format.""" + # We need the model parameter for conversion + model = 
response.get("model", "voyage-model") + return _create_openai_embedding_response_from_voyage(model, response["result"]) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Voyage list models response to OpenAI format.""" + msg = "Voyage does not support listing models" + raise NotImplementedError(msg) + async def aembedding( self, model: str, inputs: str | list[str], **kwargs: Any, ) -> CreateEmbeddingResponse: - if isinstance(inputs, str): - inputs = [inputs] - + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) client = AsyncClient( api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) result = await client.embed( - texts=inputs, model=model, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_voyage(model, result) + response_data = {"model": model, "result": result} + return self._convert_embedding_response(response_data) async def acompletion( self, params: CompletionParams, **kwargs: Any diff --git a/src/any_llm/providers/watsonx/watsonx.py b/src/any_llm/providers/watsonx/watsonx.py index 049dfa1d..84f4459e 100644 --- a/src/any_llm/providers/watsonx/watsonx.py +++ b/src/any_llm/providers/watsonx/watsonx.py @@ -27,7 +27,7 @@ from ibm_watsonx_ai import APIClient as WatsonxClient # noqa: TC004 - from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams + from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model @@ -48,6 +48,45 @@ class WatsonxProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Watsonx API.""" + # Watsonx does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, 
exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Watsonx response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Watsonx chunk response to OpenAI format.""" + return _convert_streaming_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Watsonx.""" + msg = "Watsonx does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Watsonx embedding response to OpenAI format.""" + msg = "Watsonx does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Watsonx list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model_inference: ModelInference, @@ -60,7 +99,7 @@ async def _stream_completion_async( params=kwargs, ) async for chunk in response_stream: - yield _convert_streaming_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -74,7 +113,7 @@ def _stream_completion( params=kwargs, ) for chunk in response_stream: - yield _convert_streaming_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -101,19 +140,17 @@ async def acompletion( if params.reasoning_effort == "auto": params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: - 
kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - return self._stream_completion_async(model_inference, params.messages, **kwargs) + return self._stream_completion_async(model_inference, params.messages, **completion_kwargs) response = await model_inference.achat( messages=params.messages, - params=params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), + params=completion_kwargs, ) - return _convert_response(response) + return self._convert_completion_response(response) def completion( self, @@ -140,19 +177,17 @@ def completion( if params.reasoning_effort == "auto": params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: - kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - return self._stream_completion(model_inference, params.messages, **kwargs) + return self._stream_completion(model_inference, params.messages, **completion_kwargs) response = model_inference.chat( messages=params.messages, - params=params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), + params=completion_kwargs, ) - return _convert_response(response) + return self._convert_completion_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -178,4 +213,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: else: models_data = {"resources": [models_response]} - return _convert_models_list(models_data) + return self._convert_list_models_response(models_data) diff --git a/src/any_llm/providers/xai/xai.py b/src/any_llm/providers/xai/xai.py index a423d284..b1b69ff4 100644 --- a/src/any_llm/providers/xai/xai.py +++ b/src/any_llm/providers/xai/xai.py @@ -2,7 +2,7 @@ from typing import Any from any_llm.provider import Provider -from 
any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -40,6 +40,53 @@ class XaiProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for xAI API.""" + # xAI does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, + exclude={ + "model_id", + "messages", + "stream", + "response_format", + "tools", + "tool_choice", + }, + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params.pop("reasoning_effort") + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert xAI response to OpenAI format.""" + return _convert_xai_completion_to_anyllm_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert xAI chunk response to OpenAI format.""" + return _convert_xai_chunk_to_anyllm_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for xAI.""" + msg = "xAI does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert xAI embedding response to OpenAI format.""" + msg = "xAI does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert xAI list models response to OpenAI format.""" + return _convert_models_list(response) + async def acompletion( self, 
params: CompletionParams, **kwargs: Any ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: @@ -74,24 +121,12 @@ async def acompletion( elif tool_choice is not None: kwargs["tool_choice"] = tool_choice - if params.reasoning_effort == "auto": - params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) chat = client.chat.create( model=params.model_id, messages=xai_messages, - **params.model_dump( - exclude_none=True, - exclude={ - "model_id", - "messages", - "stream", - "response_format", - "tools", - "tool_choice", - }, - ), - **kwargs, + **completion_kwargs, ) if params.stream: if params.response_format: @@ -101,7 +136,7 @@ async def acompletion( async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for _, chunk in stream_iter: - yield _convert_xai_chunk_to_anyllm_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -110,7 +145,7 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]: else: response = await chat.sample() - return _convert_xai_completion_to_anyllm_response(response) + return self._convert_completion_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -118,4 +153,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = XaiClient(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})) models_list = client.models.list_language_models() - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/tests/conftest.py b/tests/conftest.py index 284c34f8..59f5a5e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,7 +53,7 @@ def provider_model_map() -> dict[ProviderName, str]: ProviderName.SAGEMAKER: "", ProviderName.WATSONX: "ibm/granite-3-8b-instruct", ProviderName.FIREWORKS: "accounts/fireworks/models/llama4-scout-instruct-basic", - ProviderName.GROQ: "llama-3.1-8b-instant", + ProviderName.GROQ: 
"openai/gpt-oss-20b", ProviderName.PORTKEY: "@first-integrati-d8a10f/gpt-4.1-mini", # Owned by njbrake in portkey UI ProviderName.LLAMA: "Llama-4-Maverick-17B-128E-Instruct-FP8", ProviderName.AZURE: "openai/gpt-4.1-nano",