From 2beb3f21c0437b22a6b9062541a99e4002f49d49 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 10:43:04 -0400 Subject: [PATCH 1/6] initial refactor --- src/any_llm/provider.py | 36 ++++ src/any_llm/providers/anthropic/anthropic.py | 43 +++- src/any_llm/providers/azure/azure.py | 77 +++++-- src/any_llm/providers/bedrock/bedrock.py | 64 ++++-- src/any_llm/providers/cerebras/cerebras.py | 62 ++++-- src/any_llm/providers/cohere/cohere.py | 59 +++++- src/any_llm/providers/gemini/base.py | 191 +++++++++++------- src/any_llm/providers/groq/groq.py | 60 +++++- .../providers/huggingface/huggingface.py | 47 ++++- src/any_llm/providers/mistral/mistral.py | 64 ++++-- src/any_llm/providers/ollama/ollama.py | 72 +++++-- src/any_llm/providers/openai/base.py | 97 +++++++-- src/any_llm/providers/together/together.py | 74 +++++-- src/any_llm/providers/voyage/voyage.py | 53 ++++- src/any_llm/providers/watsonx/watsonx.py | 71 +++++-- src/any_llm/providers/xai/xai.py | 71 +++++-- 16 files changed, 869 insertions(+), 272 deletions(-) diff --git a/src/any_llm/provider.py b/src/any_llm/provider.py index f7d1e1da..5df5fc44 100644 --- a/src/any_llm/provider.py +++ b/src/any_llm/provider.py @@ -149,6 +149,42 @@ def _verify_and_set_api_key(self, config: ClientConfig) -> ClientConfig: raise MissingApiKeyError(self.PROVIDER_NAME, self.ENV_API_KEY_NAME) return config + @staticmethod + @abstractmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + + @staticmethod + @abstractmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) + @classmethod def get_provider_metadata(cls) -> ProviderMetadata: """Get provider metadata without requiring instantiation. diff --git a/src/any_llm/providers/anthropic/anthropic.py b/src/any_llm/providers/anthropic/anthropic.py index 0918c9a0..f7e42ce0 100644 --- a/src/any_llm/providers/anthropic/anthropic.py +++ b/src/any_llm/providers/anthropic/anthropic.py @@ -2,7 +2,7 @@ from typing import Any from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -39,6 +39,39 @@ class AnthropicProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Anthropic API.""" + return _convert_params(params, **kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Anthropic Message to OpenAI ChatCompletion format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Anthropic streaming chunk to OpenAI ChatCompletionChunk format.""" + model_id = kwargs.get("model_id", "unknown") + return _create_openai_chunk_from_anthropic_chunk(response, model_id) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Anthropic does not support embeddings.""" + msg = "Anthropic does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Anthropic does not support embeddings.""" + msg = "Anthropic does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Anthropic models list to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: "AsyncAnthropic", **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -47,7 +80,7 @@ async def _stream_completion_async( **kwargs, ) as anthropic_stream: async for event in anthropic_stream: - yield _create_openai_chunk_from_anthropic_chunk(event, kwargs.get("model", "unknown")) + yield self._convert_completion_chunk_response(event, model_id=kwargs.get("model", "unknown")) async def acompletion( self, @@ -62,14 +95,14 @@ async def acompletion( ) kwargs["provider_name"] = self.PROVIDER_NAME - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, **kwargs) if converted_kwargs.pop("stream", False): return self._stream_completion_async(client, **converted_kwargs) message = await client.messages.create(**converted_kwargs) - return _convert_response(message) + return self._convert_completion_response(message) def list_models(self, **kwargs: Any) -> Sequence[Model]: """List available models from Anthropic.""" @@ -79,4 +112,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: **(self.config.client_args if self.config.client_args else {}), ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/azure/azure.py b/src/any_llm/providers/azure/azure.py index 417320be..8195c811 100644 --- a/src/any_llm/providers/azure/azure.py +++ b/src/any_llm/providers/azure/azure.py @@ -3,7 +3,7 @@ import os from typing import TYPE_CHECKING, Any, cast -from any_llm.provider import ClientConfig, Provider +from any_llm.provider import Provider MISSING_PACKAGES_ERROR = None try: @@ -20,12 +20,13 @@ MISSING_PACKAGES_ERROR = e if TYPE_CHECKING: - from collections.abc import AsyncIterable, AsyncIterator + from collections.abc import AsyncIterable, AsyncIterator, Sequence from azure.ai.inference import aio # noqa: TC004 from azure.ai.inference.models import ChatCompletions, EmbeddingsResult, StreamingChatCompletionsUpdate from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse + from any_llm.types.model import Model class AzureProvider(Provider): @@ -43,10 +44,6 @@ class AzureProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR - def __init__(self, config: ClientConfig) -> None: - """Initialize Azure provider.""" - super().__init__(config) - def _get_endpoint(self) -> str: """Get the Azure endpoint URL.""" if self.config.api_base: @@ -94,7 +91,7 @@ async def _stream_completion_async( ) async for chunk in azure_stream: - yield _create_openai_chunk_from_azure_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -105,16 +102,7 @@ async def acompletion( api_version = os.getenv("AZURE_API_VERSION", kwargs.pop("api_version", "2024-02-15-preview")) client: aio.ChatCompletionsClient = self._create_chat_client_async(api_version) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - - azure_response_format = None - if params.response_format: - azure_response_format = _convert_response_format(params.response_format) - - call_kwargs = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) - if azure_response_format: - call_kwargs["response_format"] = azure_response_format + call_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: return self._stream_completion_async( @@ -122,7 +110,6 @@ async def acompletion( params.model_id, params.messages, **call_kwargs, - **kwargs, ) response: ChatCompletions = cast( @@ -131,11 +118,10 @@ async def acompletion( model=params.model_id, messages=params.messages, **call_kwargs, - **kwargs, ), ) - return _convert_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -147,10 +133,59 @@ async def aembedding( api_version = os.getenv("AZURE_API_VERSION", kwargs.pop("api_version", "2024-02-15-preview")) client: aio.EmbeddingsClient = self._create_embeddings_client_async(api_version) + embedding_kwargs = self._convert_embedding_params({}, **kwargs) + response: EmbeddingsResult = await client.embed( model=model, input=inputs if isinstance(inputs, list) else [inputs], - **kwargs, + **embedding_kwargs, ) + return self._convert_embedding_response(response) + + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to Azure AI Inference format.""" + if params.reasoning_effort == "auto": + params.reasoning_effort = None + + azure_response_format = None + if params.response_format: + azure_response_format = _convert_response_format(params.response_format) + + call_kwargs = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) + if azure_response_format: + call_kwargs["response_format"] = azure_response_format + + call_kwargs.update(kwargs) + return call_kwargs + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Azure ChatCompletions response to OpenAI ChatCompletion format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Azure StreamingChatCompletionsUpdate to OpenAI ChatCompletionChunk format.""" + return _create_openai_chunk_from_azure_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters to Azure AI Inference format.""" + embedding_kwargs = {} + if isinstance(params, dict): + embedding_kwargs.update(params) + embedding_kwargs.update(kwargs) + return embedding_kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Azure EmbeddingsResult to OpenAI CreateEmbeddingResponse format.""" return _create_openai_embedding_response_from_azure(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Azure list models response to OpenAI format. Not supported by Azure.""" + msg = "Azure provider does not support listing models" + raise NotImplementedError(msg) diff --git a/src/any_llm/providers/bedrock/bedrock.py b/src/any_llm/providers/bedrock/bedrock.py index ab2f0b87..ffb72659 100644 --- a/src/any_llm/providers/bedrock/bedrock.py +++ b/src/any_llm/providers/bedrock/bedrock.py @@ -45,6 +45,48 @@ class BedrockProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for AWS API.""" + return _convert_params(params, kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert AWS Bedrock response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert AWS Bedrock chunk response to OpenAI format.""" + model = kwargs.get("model", "") + chunk = _create_openai_chunk_from_aws_chunk(response, model) + if chunk is None: + msg = "Failed to convert AWS chunk to OpenAI format" + raise ValueError(msg) + return chunk + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for AWS Bedrock.""" + # For bedrock, we don't need to convert the params, just pass them through + return kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert AWS Bedrock embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_aws( + response["embedding_data"], response["model"], response["total_tokens"] + ) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert AWS Bedrock list models response to OpenAI format.""" + models_list = response.get("modelSummaries", []) + # AWS doesn't provide a creation date for models + # AWS doesn't provide typing, but per https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html + # the modelId is a string and will not be None + return [Model(id=model["modelId"], object="model", created=0, owned_by="aws") for model in models_list] + def __init__(self, config: ClientConfig) -> None: """Initialize AWS Bedrock provider.""" # This intentionally does not call super().__init__(config) because AWS has a different way of handling credentials @@ -103,7 +145,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - completion_kwargs = _convert_params(params, kwargs) + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.response_format: if params.stream: @@ -129,15 +171,13 @@ def completion( ) stream_generator = response_stream["stream"] return ( - chunk - for chunk in ( - _create_openai_chunk_from_aws_chunk(item, model=params.model_id) for item in stream_generator - ) - if chunk is not None + self._convert_completion_chunk_response(item, model=params.model_id) + for item in stream_generator + if _create_openai_chunk_from_aws_chunk(item, model=params.model_id) is not None ) response = client.converse(**completion_kwargs) - return _convert_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -193,7 +233,8 @@ def embedding( total_tokens += response_body.get("inputTextTokenCount", 0) - return _create_openai_embedding_response_from_aws(embedding_data, model, total_tokens) + response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens} + return self._convert_embedding_response(response_data) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -205,8 +246,5 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: region_name=self.region_name, **(self.config.client_args if self.config.client_args else {}), ) # type: ignore[no-untyped-call] - models_list = client.list_foundation_models(**kwargs).get("modelSummaries", []) - # AWS doesn't provide a creation date for models - # AWS doesn't provide typing, but per https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/list_foundation_models.html - # the modelId is a string and will not be None - return [Model(id=model["modelId"], object="model", created=0, owned_by="aws") for model in models_list] + response = client.list_foundation_models(**kwargs) + return self._convert_list_models_response(response) diff --git a/src/any_llm/providers/cerebras/cerebras.py b/src/any_llm/providers/cerebras/cerebras.py index a2fe7067..ff564818 100644 --- a/src/any_llm/providers/cerebras/cerebras.py +++ b/src/any_llm/providers/cerebras/cerebras.py @@ -5,7 +5,7 @@ from any_llm.exceptions import UnsupportedParameterError from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -38,6 +38,46 @@ class CerebrasProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Cerebras API.""" + # Cerebras does not support providing reasoning effort + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Cerebras response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Cerebras chunk response to OpenAI format.""" + if isinstance(response, ChatChunkResponse): + return _create_openai_chunk_from_cerebras_chunk(response) + msg = f"Unsupported chunk type: {type(response)}" + raise ValueError(msg) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Cerebras.""" + msg = "Cerebras does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Cerebras embedding response to OpenAI format.""" + msg = "Cerebras does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Cerebras list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model: str, @@ -61,11 +101,7 @@ async def _stream_completion_async( ) async for chunk in cast("cerebras.AsyncStream[ChatCompletion]", cerebras_stream): - if isinstance(chunk, ChatChunkResponse): - yield _create_openai_chunk_from_cerebras_chunk(chunk) - else: - msg = f"Unsupported chunk type: {type(chunk)}" - raise ValueError(msg) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -74,16 +110,13 @@ async def acompletion( ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: """Create a chat completion using Cerebras with instructor support for structured outputs.""" - # Cerebras does not support providing reasoning effort - if params.reasoning_effort == "auto": - params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: return self._stream_completion_async( params.model_id, params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) client = cerebras.AsyncCerebras( @@ -105,8 +138,7 @@ async def acompletion( response = await client.chat.completions.create( model=params.model_id, messages=params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) if hasattr(response, "model_dump"): @@ -115,7 +147,7 @@ async def acompletion( msg = "Streaming responses are not supported in this context" raise ValueError(msg) - return _convert_response(response_data) + return self._convert_completion_response(response_data) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -125,4 +157,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/cohere/cohere.py b/src/any_llm/providers/cohere/cohere.py index 5ab1a63a..d79229a8 100644 --- a/src/any_llm/providers/cohere/cohere.py +++ b/src/any_llm/providers/cohere/cohere.py @@ -5,7 +5,7 @@ from any_llm.exceptions import UnsupportedParameterError from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -38,6 +38,47 @@ class CohereProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Cohere API.""" + # Cohere does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any, **kwargs: Any) -> ChatCompletion: + """Convert Cohere response to OpenAI format.""" + # We need the model parameter for conversion + model = kwargs.get("model", getattr(response, "model", "cohere-model")) + return _convert_response(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Cohere chunk response to OpenAI format.""" + return _create_openai_chunk_from_cohere_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Cohere.""" + msg = "Cohere does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Cohere embedding response to OpenAI format.""" + msg = "Cohere does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Cohere list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model: str, messages: list[dict[str, Any]], **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -52,7 +93,7 @@ async def _stream_completion_async( ) async for chunk in cohere_stream: - yield _create_openai_chunk_from_cohere_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) @staticmethod def _preprocess_response_format(response_format: type[BaseModel] | dict[str, Any]) -> dict[str, Any]: @@ -73,9 +114,6 @@ def _preprocess_response_format(response_format: type[BaseModel] | dict[str, Any async def acompletion( self, params: CompletionParams, **kwargs: Any ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format is not None: kwargs["response_format"] = self._preprocess_response_format(params.response_format) if params.stream and params.response_format is not None: @@ -85,14 +123,14 @@ async def acompletion( msg = "parallel_tool_calls" raise UnsupportedParameterError(msg, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) patched_messages = _patch_messages(params.messages) if params.stream: return self._stream_completion_async( params.model_id, patched_messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) client = cohere.AsyncClientV2( @@ -103,11 +141,10 @@ async def acompletion( response = await client.chat( model=params.model_id, messages=patched_messages, # type: ignore[arg-type] - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream", "response_format"}), - **kwargs, + **completion_kwargs, ) - return _convert_response(response, params.model_id) + return self._convert_completion_response(response, model=params.model_id) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -117,4 +154,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) model_list = client.models.list(**kwargs) - return _convert_models_list(model_list) + return self._convert_list_models_response(model_list) diff --git a/src/any_llm/providers/gemini/base.py b/src/any_llm/providers/gemini/base.py index bc4f00ad..9c2fed4a 100644 --- a/src/any_llm/providers/gemini/base.py +++ b/src/any_llm/providers/gemini/base.py @@ -1,6 +1,6 @@ from abc import abstractmethod from collections.abc import AsyncIterator, Sequence -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel @@ -56,6 +56,111 @@ class GoogleProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Google API.""" + base_kwargs = params.model_dump( + exclude_none=True, + exclude={ + "model_id", + "messages", + "response_format", + "stream", + "tools", + "tool_choice", + "reasoning_effort", + "max_tokens", + }, + ) + if params.max_tokens is not None: + base_kwargs["max_output_tokens"] = params.max_tokens + base_kwargs.update(kwargs) + return base_kwargs + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Google response data to OpenAI ChatCompletion format.""" + # Expect response to be a tuple of (response_dict, model_id) + response_dict, model_id = response + choices_out: list[Choice] = [] + for i, choice_item in enumerate(response_dict.get("choices", [])): + message_dict: dict[str, Any] = choice_item.get("message", {}) + tool_calls: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] | None = None + if message_dict.get("tool_calls"): + tool_calls_list: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] = [] + for tc in message_dict["tool_calls"]: + tool_calls_list.append( + ChatCompletionMessageFunctionToolCall( + id=tc.get("id"), + type="function", + function=Function( + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + ) + ) + tool_calls = tool_calls_list + + reasoning_content = message_dict.get("reasoning") + message = ChatCompletionMessage( + role="assistant", + content=message_dict.get("content"), + tool_calls=tool_calls, + reasoning=Reasoning(content=reasoning_content) if reasoning_content else None, + ) + from typing import Literal + + choices_out.append( + Choice( + index=i, + finish_reason=cast( + "Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']", + choice_item.get("finish_reason", "stop"), + ), + message=message, + ) + ) + + usage_dict = response_dict.get("usage", {}) + usage = CompletionUsage( + prompt_tokens=usage_dict.get("prompt_tokens", 0), + completion_tokens=usage_dict.get("completion_tokens", 0), + total_tokens=usage_dict.get("total_tokens", 0), + ) + + return ChatCompletion( + id=response_dict.get("id", ""), + model=model_id, + created=response_dict.get("created", 0), + object="chat.completion", + choices=choices_out, + usage=usage, + ) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Google chunk response to OpenAI format.""" + return _create_openai_chunk_from_google_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Google API.""" + converted_params = {"contents": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Google embedding response to OpenAI format.""" + # We need the model parameter for conversion + model = response.get("model", "google-model") + return _create_openai_embedding_response_from_google(model, response["result"]) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Google list models response to OpenAI format.""" + return _convert_models_list(response) + @abstractmethod def _get_client(self, config: ClientConfig) -> "genai.Client": """Get the appropriate client for this provider implementation.""" @@ -67,13 +172,14 @@ async def aembedding( **kwargs: Any, ) -> CreateEmbeddingResponse: client = self._get_client(self.config) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) result = await client.aio.models.embed_content( model=model, - contents=inputs, # type: ignore[arg-type] - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_google(model, result) + response_data = {"model": model, "result": result} + return self._convert_embedding_response(response_data) async def acompletion( self, @@ -104,24 +210,7 @@ async def acompletion( stream = bool(params.stream) response_format = params.response_format - base_kwargs = params.model_dump( - exclude_none=True, - exclude={ - "model_id", - "messages", - "response_format", - "stream", - "tools", - "tool_choice", - "reasoning_effort", - "max_tokens", - }, - ) - - if params.max_tokens is not None: - base_kwargs["max_output_tokens"] = params.max_tokens - - base_kwargs.update(kwargs) + base_kwargs = self._convert_completion_params(params, **kwargs) generation_config = types.GenerateContentConfig(**base_kwargs) if isinstance(response_format, type) and issubclass(response_format, BaseModel): generation_config.response_mime_type = "application/json" @@ -141,7 +230,7 @@ async def acompletion( async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for chunk in response_stream: - yield _create_openai_chunk_from_google_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -152,62 +241,10 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]: ) response_dict = _convert_response_to_response_dict(response) - - choices_out: list[Choice] = [] - for i, choice_item in enumerate(response_dict.get("choices", [])): - message_dict: dict[str, Any] = choice_item.get("message", {}) - tool_calls: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] | None = None - if message_dict.get("tool_calls"): - tool_calls_list: list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageToolCall] = [] - for tc in message_dict["tool_calls"]: - tool_calls_list.append( - ChatCompletionMessageFunctionToolCall( - id=tc.get("id"), - type="function", - function=Function( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - ) - ) - tool_calls = tool_calls_list - - reasoning_content = message_dict.get("reasoning") - message = ChatCompletionMessage( - role="assistant", - content=message_dict.get("content"), - tool_calls=tool_calls, - reasoning=Reasoning(content=reasoning_content) if reasoning_content else None, - ) - choices_out.append( - Choice( - index=i, - finish_reason=cast( - "Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']", - choice_item.get("finish_reason", "stop"), - ), - message=message, - ) - ) - - usage_dict = response_dict.get("usage", {}) - usage = CompletionUsage( - prompt_tokens=usage_dict.get("prompt_tokens", 0), - completion_tokens=usage_dict.get("completion_tokens", 0), - total_tokens=usage_dict.get("total_tokens", 0), - ) - - return ChatCompletion( - id=response_dict.get("id", ""), - model=params.model_id, - created=response_dict.get("created", 0), - object="chat.completion", - choices=choices_out, - usage=usage, - ) + return self._convert_completion_response((response_dict, params.model_id)) def list_models(self, **kwargs: Any) -> Sequence[Model]: """Fetch available models from the /v1/models endpoint.""" client = self._get_client(self.config) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/groq/groq.py b/src/any_llm/providers/groq/groq.py index 22988d8d..2122eac3 100644 --- a/src/any_llm/providers/groq/groq.py +++ b/src/any_llm/providers/groq/groq.py @@ -9,6 +9,9 @@ from any_llm.provider import Provider from any_llm.types.responses import Response, ResponseStreamEvent +if TYPE_CHECKING: + from any_llm.types.completion import CreateEmbeddingResponse + MISSING_PACKAGES_ERROR = None try: import groq @@ -49,6 +52,45 @@ class GroqProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Groq API.""" + # Groq does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Groq response to OpenAI format.""" + return to_chat_completion(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Groq chunk response to OpenAI format.""" + return _create_openai_chunk_from_groq_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Groq.""" + msg = "Groq does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Groq embedding response to OpenAI format.""" + msg = "Groq does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Groq list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_async_completion( self, client: groq.AsyncGroq, params: CompletionParams, **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: @@ -56,16 +98,16 @@ async def _stream_async_completion( msg = "stream and response_format" raise UnsupportedParameterError(msg, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) stream: GroqAsyncStream[GroqChatCompletionChunk] = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages"}), - **kwargs, + **completion_kwargs, ) async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for chunk in stream: - yield _create_openai_chunk_from_groq_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -77,9 +119,6 @@ async def acompletion( api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): kwargs["response_format"] = { @@ -92,6 +131,8 @@ async def acompletion( else: kwargs["response_format"] = params.response_format + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return await self._stream_async_completion( client, @@ -101,11 +142,10 @@ async def acompletion( response: GroqChatCompletion = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) - return to_chat_completion(response) + return self._convert_completion_response(response) async def aresponses( self, model: str, input_data: Any, **kwargs: Any @@ -141,4 +181,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = groq.Groq(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/huggingface/huggingface.py b/src/any_llm/providers/huggingface/huggingface.py index ec63e9ea..cab57484 100644 --- a/src/any_llm/providers/huggingface/huggingface.py +++ b/src/any_llm/providers/huggingface/huggingface.py @@ -9,6 +9,7 @@ Choice, CompletionParams, CompletionUsage, + CreateEmbeddingResponse, ) from any_llm.types.model import Model @@ -46,6 +47,42 @@ class HuggingfaceProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for HuggingFace API.""" + return _convert_params(params, **kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert HuggingFace response to OpenAI format.""" + # If it's already our ChatCompletion type, return it + if isinstance(response, ChatCompletion): + return response + # Otherwise, validate it as our type + return ChatCompletion.model_validate(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert HuggingFace chunk response to OpenAI format.""" + return _create_openai_chunk_from_huggingface_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for HuggingFace.""" + msg = "HuggingFace does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert HuggingFace embedding response to OpenAI format.""" + msg = "HuggingFace does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert HuggingFace list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: "AsyncInferenceClient", @@ -55,7 +92,7 @@ async def _stream_completion_async( response: AsyncIterator[HuggingFaceChatCompletionStreamOutput] = await client.chat_completion(**kwargs) async for chunk in response: - yield _create_openai_chunk_from_huggingface_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -67,7 +104,7 @@ def _stream_completion( **kwargs, ) for chunk in response: - yield _create_openai_chunk_from_huggingface_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -81,7 +118,7 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: converted_kwargs["stream"] = True @@ -130,7 +167,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - converted_kwargs = _convert_params(params, **kwargs) + converted_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: converted_kwargs["stream"] = True @@ -180,4 +217,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: if kwargs.get("limit") is None: kwargs["limit"] = 20 models_list = client.list_models(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/mistral/mistral.py b/src/any_llm/providers/mistral/mistral.py index 7c13278a..454b020a 100644 --- a/src/any_llm/providers/mistral/mistral.py +++ b/src/any_llm/providers/mistral/mistral.py @@ -47,13 +47,54 @@ class MistralProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Mistral API.""" + # Mistral does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Mistral response to OpenAI format.""" + # We need the model parameter for conversion + model = getattr(response, "model", "mistral-model") + return _create_mistral_completion_from_response(response_data=response, model=model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Mistral chunk response to OpenAI format.""" + return _create_openai_chunk_from_mistral_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Mistral.""" + converted_params = {"inputs": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Mistral embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_mistral(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Mistral list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, client: Mistral, model: str, messages: list[dict[str, Any]], **kwargs: Any ) -> AsyncIterator[ChatCompletionChunk]: mistral_stream = await client.chat.stream_async(model=model, messages=messages, **kwargs) # type: ignore[arg-type] async for event in mistral_stream: - yield _create_openai_chunk_from_mistral_chunk(event) + yield self._convert_completion_chunk_response(event) async def acompletion( self, params: CompletionParams, **kwargs: Any @@ -76,26 +117,23 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion_async( client, params.model_id, patched_messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) response = await client.chat.complete_async( model=params.model_id, messages=patched_messages, # type: ignore[arg-type] - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, + **completion_kwargs, ) - return _create_mistral_completion_from_response( - response_data=response, - model=params.model_id, - ) + return self._convert_completion_response(response) async def aembedding( self, @@ -108,13 +146,13 @@ async def aembedding( server_url=self.config.api_base, **(self.config.client_args if self.config.client_args else {}), ) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) result: EmbeddingResponse = await client.embeddings.create_async( model=model, - inputs=inputs, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_mistral(result) + return self._convert_embedding_response(result) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -126,4 +164,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: **(self.config.client_args if self.config.client_args else {}), ) models_list = client.models.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/ollama/ollama.py b/src/any_llm/providers/ollama/ollama.py index c10362c3..a45e98eb 100644 --- a/src/any_llm/providers/ollama/ollama.py +++ b/src/any_llm/providers/ollama/ollama.py @@ -52,6 +52,46 @@ class OllamaProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Ollama API.""" + # Ollama does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + converted_params["num_ctx"] = converted_params.get("num_ctx", 32000) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Ollama response to OpenAI format.""" + return _create_chat_completion_from_ollama_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Ollama chunk response to OpenAI format.""" + return _create_openai_chunk_from_ollama_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Ollama.""" + converted_params = {"input": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Ollama embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_ollama(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Ollama list models response to OpenAI format.""" + return _convert_models_list(response) + def __init__(self, config: ClientConfig) -> None: """We don't use the Provider init because by default we don't require an API key.""" self._verify_no_missing_packages() @@ -76,7 +116,7 @@ async def _stream_completion_async( options=kwargs, ) async for chunk in response: - yield _create_openai_chunk_from_ollama_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -115,33 +155,25 @@ async def acompletion( cleaned_messages.append(cleaned_message) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - - kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - - kwargs["num_ctx"] = kwargs.get("num_ctx", 32000) + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.reasoning_effort is not None: - kwargs["think"] = True + completion_kwargs["think"] = True client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {})) if params.stream: - return self._stream_completion_async(client, params.model_id, cleaned_messages, **kwargs) + return self._stream_completion_async(client, params.model_id, cleaned_messages, **completion_kwargs) response: OllamaChatResponse = await client.chat( model=params.model_id, - tools=kwargs.pop("tools", None), - think=kwargs.pop("think", None), + tools=completion_kwargs.pop("tools", None), + think=completion_kwargs.pop("think", None), messages=cleaned_messages, format=output_format, - options=kwargs, + options=completion_kwargs, ) - return _create_chat_completion_from_ollama_response(response) + return self._convert_completion_response(response) async def aembedding( self, @@ -152,12 +184,12 @@ async def aembedding( """Generate embeddings using Ollama.""" client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {})) + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) response = await client.embed( model=model, - input=inputs, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_ollama(response) + return self._convert_embedding_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -165,4 +197,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = Client(host=self.url, **(self.config.client_args if self.config.client_args else {})) models_list = client.list(**kwargs) - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) diff --git a/src/any_llm/providers/openai/base.py b/src/any_llm/providers/openai/base.py index 0db8e016..1f7f1c26 100644 --- a/src/any_llm/providers/openai/base.py +++ b/src/any_llm/providers/openai/base.py @@ -37,6 +37,68 @@ class BaseOpenAIProvider(Provider, ABC): _DEFAULT_REASONING_EFFORT: Literal["minimal", "low", "medium", "high", "auto"] | None = None + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for OpenAI API.""" + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages"}) + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert OpenAI response to OpenAI format (passthrough).""" + if isinstance(response, OpenAIChatCompletion): + return _convert_chat_completion(response) + # If it's already our ChatCompletion type, return it + if isinstance(response, ChatCompletion): + return response + # Otherwise, validate it as our type + return ChatCompletion.model_validate(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert OpenAI chunk response to OpenAI format (passthrough).""" + if isinstance(response, OpenAIChatCompletionChunk): + if not isinstance(response.created, int): + logger.warning( + "API returned an unexpected created type: %s. Setting to int.", + type(response.created), + ) + response.created = int(response.created) + normalized_chunk = _normalize_openai_dict_response(response.model_dump()) + return ChatCompletionChunk.model_validate(normalized_chunk) + # If it's already our ChatCompletionChunk type, return it + if isinstance(response, ChatCompletionChunk): + return response + # Otherwise, validate it as our type + return ChatCompletionChunk.model_validate(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for OpenAI API.""" + converted_params = {"input": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert OpenAI embedding response to OpenAI format (passthrough).""" + if isinstance(response, CreateEmbeddingResponse): + return response + return CreateEmbeddingResponse.model_validate(response) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert OpenAI list models response to OpenAI format (passthrough).""" + if hasattr(response, "data"): + # Validate each model in the data + return [Model.model_validate(model) if not isinstance(model, Model) else model for model in response.data] + # If it's already a sequence of our Model type, return it + if isinstance(response, (list, tuple)) and all(isinstance(item, Model) for item in response): + return response + # Otherwise, validate each item + return [Model.model_validate(item) if not isinstance(item, Model) else item for item in response] + def _get_client(self, sync: bool = False) -> AsyncOpenAI | OpenAI: _client_class = OpenAI if sync else AsyncOpenAI return _client_class( @@ -50,19 +112,11 @@ def _convert_completion_response_async( ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: """Convert an OpenAI completion response to an AnyLLM completion response.""" if isinstance(response, OpenAIChatCompletion): - return _convert_chat_completion(response) + return self._convert_completion_response(response) async def chunk_iterator() -> AsyncIterator[ChatCompletionChunk]: async for chunk in response: - if not isinstance(chunk.created, int): - logger.warning( - "API returned an unexpected created type: %s. Setting to int.", - type(chunk.created), - ) - chunk.created = int(chunk.created) - - normalized_chunk = _normalize_openai_dict_response(chunk.model_dump()) - yield ChatCompletionChunk.model_validate(normalized_chunk) + yield self._convert_completion_chunk_response(chunk) return chunk_iterator() @@ -74,6 +128,8 @@ async def acompletion( if params.reasoning_effort == "auto": params.reasoning_effort = self._DEFAULT_REASONING_EFFORT + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.response_format: if params.stream: msg = "stream is not supported for response_format" @@ -82,15 +138,13 @@ async def acompletion( response = await client.chat.completions.parse( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) else: response = await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages"}), - **kwargs, + **completion_kwargs, ) return self._convert_completion_response_async(response) @@ -123,11 +177,13 @@ async def aembedding( client = cast("AsyncOpenAI", self._get_client()) - return await client.embeddings.create( - model=model, - input=inputs, - dimensions=kwargs.get("dimensions", NOT_GIVEN), - **kwargs, + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) + return self._convert_embedding_response( + await client.embeddings.create( + model=model, + dimensions=kwargs.get("dimensions", NOT_GIVEN), + **embedding_kwargs, + ) ) def list_models(self, **kwargs: Any) -> Sequence[Model]: @@ -138,4 +194,5 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: message = f"{self.PROVIDER_NAME} does not support listing models." raise NotImplementedError(message) client = cast("OpenAI", self._get_client(sync=True)) - return client.models.list(**kwargs).data + response = client.models.list(**kwargs) + return self._convert_list_models_response(response) diff --git a/src/any_llm/providers/together/together.py b/src/any_llm/providers/together/together.py index f68ae564..39dc52c2 100644 --- a/src/any_llm/providers/together/together.py +++ b/src/any_llm/providers/together/together.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Iterator, Sequence from typing import TYPE_CHECKING, Any, cast from any_llm.provider import Provider @@ -6,7 +6,9 @@ ChatCompletion, ChatCompletionChunk, CompletionParams, + CreateEmbeddingResponse, ) +from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response MISSING_PACKAGES_ERROR = None @@ -42,6 +44,46 @@ class TogetherProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Together API.""" + # Together does not support providing reasoning effort + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Together response to OpenAI format.""" + # We need the model parameter for conversion + model = response.get("model", "together-model") + return _convert_together_response_to_chat_completion(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Together chunk response to OpenAI format.""" + return _create_openai_chunk_from_together_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Together.""" + msg = "Together does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Together embedding response to OpenAI format.""" + msg = "Together does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Together list models response to OpenAI format.""" + msg = "Together does not support listing models" + raise NotImplementedError(msg) + async def _stream_completion_async( self, client: "together.AsyncTogether", @@ -61,7 +103,7 @@ async def _stream_completion_async( ), ) async for chunk in response: - yield _create_openai_chunk_from_together_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -82,7 +124,7 @@ def _stream_completion( ), ) for chunk in response: - yield _create_openai_chunk_from_together_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def completion( self, @@ -96,9 +138,6 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: instructor_client = instructor.patch(client, mode=instructor.Mode.JSON) # type: ignore [call-overload] @@ -112,13 +151,14 @@ def completion( return _convert_instructor_response(instructor_response, params.model_id, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion( client, params.model_id, params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}), - **kwargs, + **completion_kwargs, ) response = cast( @@ -126,12 +166,11 @@ def completion( client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}), - **kwargs, + **completion_kwargs, ), ) - return _convert_together_response_to_chat_completion(response.model_dump(), params.model_id) + return self._convert_completion_response(response.model_dump()) async def acompletion( self, @@ -145,9 +184,6 @@ async def acompletion( **(self.config.client_args if self.config.client_args else {}), ) - if params.reasoning_effort == "auto": - params.reasoning_effort = None - if params.response_format: instructor_client = instructor.patch(client, mode=instructor.Mode.JSON) # type: ignore [call-overload] @@ -164,13 +200,14 @@ async def acompletion( return _convert_instructor_response(instructor_response, params.model_id, self.PROVIDER_NAME) + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: return self._stream_completion_async( client, params.model_id, params.messages, - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "reasoning_effort", "stream"}), - **kwargs, + **completion_kwargs, ) response = cast( @@ -178,9 +215,8 @@ async def acompletion( await client.chat.completions.create( model=params.model_id, messages=cast("Any", params.messages), - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}), - **kwargs, + **completion_kwargs, ), ) - return _convert_together_response_to_chat_completion(response.model_dump(), params.model_id) + return self._convert_completion_response(response.model_dump()) diff --git a/src/any_llm/providers/voyage/voyage.py b/src/any_llm/providers/voyage/voyage.py index 37aa5613..dd853dcf 100644 --- a/src/any_llm/providers/voyage/voyage.py +++ b/src/any_llm/providers/voyage/voyage.py @@ -1,8 +1,9 @@ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from typing import Any from any_llm.provider import Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse +from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None try: @@ -33,24 +34,62 @@ class VoyageProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Voyage API.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Voyage response to OpenAI format.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Voyage chunk response to OpenAI format.""" + msg = "Voyage does not support completions" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Voyage.""" + if isinstance(params, str): + params = [params] + converted_params = {"texts": params} + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Voyage embedding response to OpenAI format.""" + # We need the model parameter for conversion + model = response.get("model", "voyage-model") + return _create_openai_embedding_response_from_voyage(model, response["result"]) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Voyage list models response to OpenAI format.""" + msg = "Voyage does not support listing models" + raise NotImplementedError(msg) + async def aembedding( self, model: str, inputs: str | list[str], **kwargs: Any, ) -> CreateEmbeddingResponse: - if isinstance(inputs, str): - inputs = [inputs] - + embedding_kwargs = self._convert_embedding_params(inputs, **kwargs) client = AsyncClient( api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) result = await client.embed( - texts=inputs, model=model, - **kwargs, + **embedding_kwargs, ) - return _create_openai_embedding_response_from_voyage(model, result) + response_data = {"model": model, "result": result} + return self._convert_embedding_response(response_data) async def acompletion( self, params: CompletionParams, **kwargs: Any diff --git a/src/any_llm/providers/watsonx/watsonx.py b/src/any_llm/providers/watsonx/watsonx.py index 652f956a..dcc82c67 100644 --- a/src/any_llm/providers/watsonx/watsonx.py +++ b/src/any_llm/providers/watsonx/watsonx.py @@ -27,7 +27,7 @@ from ibm_watsonx_ai import APIClient as WatsonxClient # noqa: TC004 - from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams + from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model @@ -47,6 +47,45 @@ class WatsonxProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for Watsonx API.""" + # Watsonx does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert Watsonx response to OpenAI format.""" + return _convert_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert Watsonx chunk response to OpenAI format.""" + return _convert_streaming_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for Watsonx.""" + msg = "Watsonx does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert Watsonx embedding response to OpenAI format.""" + msg = "Watsonx does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert Watsonx list models response to OpenAI format.""" + return _convert_models_list(response) + async def _stream_completion_async( self, model_inference: ModelInference, @@ -59,7 +98,7 @@ async def _stream_completion_async( params=kwargs, ) async for chunk in response_stream: - yield _convert_streaming_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) def _stream_completion( self, @@ -73,7 +112,7 @@ def _stream_completion( params=kwargs, ) for chunk in response_stream: - yield _convert_streaming_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) async def acompletion( self, @@ -100,19 +139,17 @@ async def acompletion( if params.reasoning_effort == "auto": params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: - kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - return self._stream_completion_async(model_inference, params.messages, **kwargs) + return self._stream_completion_async(model_inference, params.messages, **completion_kwargs) response = await model_inference.achat( messages=params.messages, - params=params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), + params=completion_kwargs, ) - return _convert_response(response) + return self._convert_completion_response(response) def completion( self, @@ -139,19 +176,17 @@ def completion( if params.reasoning_effort == "auto": params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) + if params.stream: - kwargs = { - **params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), - **kwargs, - } - return self._stream_completion(model_inference, params.messages, **kwargs) + return self._stream_completion(model_inference, params.messages, **completion_kwargs) response = model_inference.chat( messages=params.messages, - params=params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"}), + params=completion_kwargs, ) - return _convert_response(response) + return self._convert_completion_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -177,4 +212,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: else: models_data = {"resources": [models_response]} - return _convert_models_list(models_data) + return self._convert_list_models_response(models_data) diff --git a/src/any_llm/providers/xai/xai.py b/src/any_llm/providers/xai/xai.py index 593b051b..db973060 100644 --- a/src/any_llm/providers/xai/xai.py +++ b/src/any_llm/providers/xai/xai.py @@ -2,7 +2,7 @@ from typing import Any from any_llm.provider import Provider -from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams +from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse from any_llm.types.model import Model MISSING_PACKAGES_ERROR = None @@ -39,6 +39,53 @@ class XaiProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for xAI API.""" + # xAI does not support providing reasoning effort + converted_params = params.model_dump( + exclude_none=True, + exclude={ + "model_id", + "messages", + "stream", + "response_format", + "tools", + "tool_choice", + }, + ) + if converted_params.get("reasoning_effort") == "auto": + converted_params["reasoning_effort"] = None + converted_params.update(kwargs) + return converted_params + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert xAI response to OpenAI format.""" + return _convert_xai_completion_to_anyllm_response(response) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert xAI chunk response to OpenAI format.""" + return _convert_xai_chunk_to_anyllm_chunk(response) + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for xAI.""" + msg = "xAI does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert xAI embedding response to OpenAI format.""" + msg = "xAI does not support embeddings" + raise NotImplementedError(msg) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert xAI list models response to OpenAI format.""" + return _convert_models_list(response) + async def acompletion( self, params: CompletionParams, **kwargs: Any ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: @@ -73,24 +120,12 @@ async def acompletion( elif tool_choice is not None: kwargs["tool_choice"] = tool_choice - if params.reasoning_effort == "auto": - params.reasoning_effort = None + completion_kwargs = self._convert_completion_params(params, **kwargs) chat = client.chat.create( model=params.model_id, messages=xai_messages, - **params.model_dump( - exclude_none=True, - exclude={ - "model_id", - "messages", - "stream", - "response_format", - "tools", - "tool_choice", - }, - ), - **kwargs, + **completion_kwargs, ) if params.stream: if params.response_format: @@ -100,7 +135,7 @@ async def acompletion( async def _stream() -> AsyncIterator[ChatCompletionChunk]: async for _, chunk in stream_iter: - yield _convert_xai_chunk_to_anyllm_chunk(chunk) + yield self._convert_completion_chunk_response(chunk) return _stream() @@ -109,7 +144,7 @@ async def _stream() -> AsyncIterator[ChatCompletionChunk]: else: response = await chat.sample() - return _convert_xai_completion_to_anyllm_response(response) + return self._convert_completion_response(response) def list_models(self, **kwargs: Any) -> Sequence[Model]: """ @@ -117,4 +152,4 @@ def list_models(self, **kwargs: Any) -> Sequence[Model]: """ client = XaiClient(api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {})) models_list = client.models.list_language_models() - return _convert_models_list(models_list) + return self._convert_list_models_response(models_list) From 5e4eec8f48cc652a7229b7c57fc6a5988a3d75cf Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 15:43:40 -0400 Subject: [PATCH 2/6] add sagemaker --- src/any_llm/providers/anthropic/anthropic.py | 7 ++- src/any_llm/providers/sagemaker/sagemaker.py | 58 ++++++++++++++++---- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/src/any_llm/providers/anthropic/anthropic.py b/src/any_llm/providers/anthropic/anthropic.py index f7e42ce0..71314a16 100644 --- a/src/any_llm/providers/anthropic/anthropic.py +++ b/src/any_llm/providers/anthropic/anthropic.py @@ -8,6 +8,9 @@ MISSING_PACKAGES_ERROR = None try: from anthropic import Anthropic, AsyncAnthropic + from anthropic.pagination import SyncPage + from anthropic.types import Message + from anthropic.types.model_info import ModelInfo as AnthropicModelInfo from .utils import ( _convert_models_list, @@ -45,7 +48,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ return _convert_params(params, **kwargs) @staticmethod - def _convert_completion_response(response: Any) -> ChatCompletion: + def _convert_completion_response(response: "Message") -> ChatCompletion: """Convert Anthropic Message to OpenAI ChatCompletion format.""" return _convert_response(response) @@ -68,7 +71,7 @@ def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: raise NotImplementedError(msg) @staticmethod - def _convert_list_models_response(response: Any) -> Sequence[Model]: + def _convert_list_models_response(response: "Message") -> Sequence[Model]: """Convert Anthropic models list to OpenAI format.""" return _convert_models_list(response) diff --git a/src/any_llm/providers/sagemaker/sagemaker.py b/src/any_llm/providers/sagemaker/sagemaker.py index 6e7d6d2f..b61115fc 100644 --- a/src/any_llm/providers/sagemaker/sagemaker.py +++ b/src/any_llm/providers/sagemaker/sagemaker.py @@ -2,7 +2,7 @@ import functools import json import os -from collections.abc import AsyncIterator, Callable, Iterator +from collections.abc import AsyncIterator, Callable, Iterator, Sequence from typing import Any from pydantic import BaseModel @@ -11,6 +11,7 @@ from any_llm.logging import logger from any_llm.provider import ClientConfig, Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse +from any_llm.types.model import Model from any_llm.utils.instructor import _convert_instructor_response MISSING_PACKAGES_ERROR = None @@ -43,6 +44,44 @@ class SagemakerProvider(Provider): MISSING_PACKAGES_ERROR = MISSING_PACKAGES_ERROR + @staticmethod + def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: + """Convert CompletionParams to kwargs for SageMaker API.""" + return _convert_params(params, kwargs) + + @staticmethod + def _convert_completion_response(response: Any) -> ChatCompletion: + """Convert SageMaker response to OpenAI format.""" + model = response.get("model", "") + return _convert_response(response, model) + + @staticmethod + def _convert_completion_chunk_response(response: Any, **kwargs: Any) -> ChatCompletionChunk: + """Convert SageMaker chunk response to OpenAI format.""" + model = kwargs.get("model", "") + chunk = _create_openai_chunk_from_sagemaker_chunk(response, model) + if chunk is None: + msg = "Failed to convert SageMaker chunk to OpenAI format" + raise ValueError(msg) + return chunk + + @staticmethod + def _convert_embedding_params(params: Any, **kwargs: Any) -> dict[str, Any]: + """Convert embedding parameters for SageMaker.""" + return kwargs + + @staticmethod + def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: + """Convert SageMaker embedding response to OpenAI format.""" + return _create_openai_embedding_response_from_sagemaker( + response["embedding_data"], response["model"], response["total_tokens"] + ) + + @staticmethod + def _convert_list_models_response(response: Any) -> Sequence[Model]: + """Convert SageMaker list models response to OpenAI format.""" + return [] + def __init__(self, config: ClientConfig) -> None: """Initialize AWS SageMaker provider.""" logger.warning( @@ -103,7 +142,7 @@ def completion( **(self.config.client_args if self.config.client_args else {}), ) - completion_kwargs = _convert_params(params, kwargs) + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.response_format: if params.stream: @@ -127,7 +166,7 @@ def completion( return _convert_instructor_response(structured_response, params.model_id, "aws") except (ValueError, TypeError) as e: logger.warning("Failed to parse structured response: %s", e) - return _convert_response(response_body, params.model_id) + return self._convert_completion_response({"model": params.model_id, **response_body}) if params.stream: response = client.invoke_endpoint_with_response_stream( @@ -138,11 +177,9 @@ def completion( event_stream = response["Body"] return ( - chunk - for chunk in ( - _create_openai_chunk_from_sagemaker_chunk(event, model=params.model_id) for event in event_stream - ) - if chunk is not None + self._convert_completion_chunk_response(event, model=params.model_id) + for event in event_stream + if _create_openai_chunk_from_sagemaker_chunk(event, model=params.model_id) is not None ) response = client.invoke_endpoint( @@ -152,7 +189,7 @@ def completion( ) response_body = json.loads(response["Body"].read()) - return _convert_response(response_body, params.model_id) + return self._convert_completion_response({"model": params.model_id, **response_body}) async def aembedding( self, @@ -221,4 +258,5 @@ def embedding( embedding_data.append({"embedding": embedding, "index": index}) total_tokens += response_body.get("usage", {}).get("prompt_tokens", len(text.split())) - return _create_openai_embedding_response_from_sagemaker(embedding_data, model, total_tokens) + response_data = {"embedding_data": embedding_data, "model": model, "total_tokens": total_tokens} + return self._convert_embedding_response(response_data) From b601f09ab6796d98c0d00144e0d55a13e90ca429 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 15:45:35 -0400 Subject: [PATCH 3/6] lint --- src/any_llm/providers/anthropic/anthropic.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/any_llm/providers/anthropic/anthropic.py b/src/any_llm/providers/anthropic/anthropic.py index 71314a16..f27f0d73 100644 --- a/src/any_llm/providers/anthropic/anthropic.py +++ b/src/any_llm/providers/anthropic/anthropic.py @@ -1,5 +1,5 @@ from collections.abc import AsyncIterator, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from any_llm.provider import Provider from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionParams, CreateEmbeddingResponse @@ -8,9 +8,6 @@ MISSING_PACKAGES_ERROR = None try: from anthropic import Anthropic, AsyncAnthropic - from anthropic.pagination import SyncPage - from anthropic.types import Message - from anthropic.types.model_info import ModelInfo as AnthropicModelInfo from .utils import ( _convert_models_list, @@ -21,6 +18,11 @@ except ImportError as e: MISSING_PACKAGES_ERROR = e +if TYPE_CHECKING: + from anthropic.pagination import SyncPage + from anthropic.types import Message + from anthropic.types.model_info import ModelInfo as AnthropicModelInfo + class AnthropicProvider(Provider): """ @@ -71,7 +73,7 @@ def _convert_embedding_response(response: Any) -> CreateEmbeddingResponse: raise NotImplementedError(msg) @staticmethod - def _convert_list_models_response(response: "Message") -> Sequence[Model]: + def _convert_list_models_response(response: "SyncPage[AnthropicModelInfo]") -> Sequence[Model]: """Convert Anthropic models list to OpenAI format.""" return _convert_models_list(response) From 5bf6c18d36ba2b14eb70524bca002856bd764f5f Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 16:16:26 -0400 Subject: [PATCH 4/6] Fixes based on failing integration tests --- src/any_llm/providers/cerebras/cerebras.py | 26 +++++++++++----------- src/any_llm/providers/cohere/cohere.py | 2 +- src/any_llm/providers/groq/groq.py | 6 ++--- src/any_llm/providers/mistral/mistral.py | 2 +- src/any_llm/providers/ollama/ollama.py | 4 ++-- src/any_llm/providers/openai/base.py | 2 +- src/any_llm/providers/together/together.py | 2 +- src/any_llm/providers/watsonx/watsonx.py | 2 +- src/any_llm/providers/xai/xai.py | 2 +- tests/conftest.py | 2 +- 10 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/any_llm/providers/cerebras/cerebras.py b/src/any_llm/providers/cerebras/cerebras.py index ff564818..7d5c6000 100644 --- a/src/any_llm/providers/cerebras/cerebras.py +++ b/src/any_llm/providers/cerebras/cerebras.py @@ -44,7 +44,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ # Cerebras does not support providing reasoning effort converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "stream"}) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params @@ -110,6 +110,18 @@ async def acompletion( ) -> ChatCompletion | AsyncIterator[ChatCompletionChunk]: """Create a chat completion using Cerebras with instructor support for structured outputs.""" + if params.response_format: + # See https://inference-docs.cerebras.ai/capabilities/structured-outputs for guide to creating schema + if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): + params.response_format = { + "type": "json_schema", + "json_schema": { + "name": "response_schema", + "schema": params.response_format.model_json_schema(), + "strict": True, + }, + } + completion_kwargs = self._convert_completion_params(params, **kwargs) if params.stream: @@ -123,18 +135,6 @@ async def acompletion( api_key=self.config.api_key, **(self.config.client_args if self.config.client_args else {}) ) - if params.response_format: - # See https://inference-docs.cerebras.ai/capabilities/structured-outputs for guide to creating schema - if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): - params.response_format = { - "type": "json_schema", - "json_schema": { - "name": "response_schema", - "schema": params.response_format.model_json_schema(), - "strict": True, - }, - } - response = await client.chat.completions.create( model=params.model_id, messages=params.messages, diff --git a/src/any_llm/providers/cohere/cohere.py b/src/any_llm/providers/cohere/cohere.py index d79229a8..2a91f95a 100644 --- a/src/any_llm/providers/cohere/cohere.py +++ b/src/any_llm/providers/cohere/cohere.py @@ -46,7 +46,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} ) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/src/any_llm/providers/groq/groq.py b/src/any_llm/providers/groq/groq.py index 2122eac3..d73b64eb 100644 --- a/src/any_llm/providers/groq/groq.py +++ b/src/any_llm/providers/groq/groq.py @@ -56,11 +56,9 @@ class GroqProvider(Provider): def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]: """Convert CompletionParams to kwargs for Groq API.""" # Groq does not support providing reasoning effort - converted_params = params.model_dump( - exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} - ) + converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages"}) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/src/any_llm/providers/mistral/mistral.py b/src/any_llm/providers/mistral/mistral.py index 454b020a..93c897e4 100644 --- a/src/any_llm/providers/mistral/mistral.py +++ b/src/any_llm/providers/mistral/mistral.py @@ -55,7 +55,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} ) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/src/any_llm/providers/ollama/ollama.py b/src/any_llm/providers/ollama/ollama.py index a45e98eb..f46d24be 100644 --- a/src/any_llm/providers/ollama/ollama.py +++ b/src/any_llm/providers/ollama/ollama.py @@ -60,7 +60,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} ) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) converted_params["num_ctx"] = converted_params.get("num_ctx", 32000) return converted_params @@ -157,7 +157,7 @@ async def acompletion( completion_kwargs = self._convert_completion_params(params, **kwargs) - if params.reasoning_effort is not None: + if completion_kwargs.get("reasoning_effort") is not None: completion_kwargs["think"] = True client = AsyncClient(host=self.url, **(self.config.client_args if self.config.client_args else {})) diff --git a/src/any_llm/providers/openai/base.py b/src/any_llm/providers/openai/base.py index 1f7f1c26..8884f952 100644 --- a/src/any_llm/providers/openai/base.py +++ b/src/any_llm/providers/openai/base.py @@ -134,7 +134,7 @@ async def acompletion( if params.stream: msg = "stream is not supported for response_format" raise ValueError(msg) - + completion_kwargs.pop("stream") response = await client.chat.completions.parse( model=params.model_id, messages=cast("Any", params.messages), diff --git a/src/any_llm/providers/together/together.py b/src/any_llm/providers/together/together.py index 39dc52c2..984f70f7 100644 --- a/src/any_llm/providers/together/together.py +++ b/src/any_llm/providers/together/together.py @@ -50,7 +50,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ # Together does not support providing reasoning effort converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages", "response_format"}) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/src/any_llm/providers/watsonx/watsonx.py b/src/any_llm/providers/watsonx/watsonx.py index dcc82c67..9d8ee4c6 100644 --- a/src/any_llm/providers/watsonx/watsonx.py +++ b/src/any_llm/providers/watsonx/watsonx.py @@ -55,7 +55,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ exclude_none=True, exclude={"model_id", "messages", "response_format", "stream"} ) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/src/any_llm/providers/xai/xai.py b/src/any_llm/providers/xai/xai.py index db973060..c8c3fbe7 100644 --- a/src/any_llm/providers/xai/xai.py +++ b/src/any_llm/providers/xai/xai.py @@ -55,7 +55,7 @@ def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[ }, ) if converted_params.get("reasoning_effort") == "auto": - converted_params["reasoning_effort"] = None + converted_params.pop("reasoning_effort") converted_params.update(kwargs) return converted_params diff --git a/tests/conftest.py b/tests/conftest.py index 11acac5b..61d6ea17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,7 +53,7 @@ def provider_model_map() -> dict[ProviderName, str]: ProviderName.SAGEMAKER: "", ProviderName.WATSONX: "ibm/granite-3-8b-instruct", ProviderName.FIREWORKS: "accounts/fireworks/models/llama4-scout-instruct-basic", - ProviderName.GROQ: "llama-3.1-8b-instant", + ProviderName.GROQ: "openai/gpt-oss-20b", ProviderName.PORTKEY: "@first-integrati-d8a10f/gpt-4.1-mini", # Owned by njbrake in portkey UI ProviderName.LLAMA: "Llama-4-Maverick-17B-128E-Instruct-FP8", ProviderName.AZURE: "openai/gpt-4.1-nano", From b02be970006b54b0c748e426b373a6d91be17746 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 16:24:35 -0400 Subject: [PATCH 5/6] add helper script --- scripts/check_missing_api_keys.py | 129 ++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 scripts/check_missing_api_keys.py diff --git a/scripts/check_missing_api_keys.py b/scripts/check_missing_api_keys.py new file mode 100644 index 00000000..8b350e20 --- /dev/null +++ b/scripts/check_missing_api_keys.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Script to check which providers are being skipped due to missing API keys. + +This script attempts to instantiate each provider and reports which ones +fail due to missing API keys, missing packages, or other issues. +""" + +import os +import sys +from pathlib import Path + +src_path = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(src_path)) + +from any_llm.provider import ProviderFactory, ProviderName +from any_llm.exceptions import MissingApiKeyError + + +def check_provider_status(): + """Check the status of all providers and categorize them.""" + available_providers = [] + missing_api_keys = [] + missing_packages = [] + other_errors = [] + + print("Checking provider status...") + print("=" * 50) + + for provider_name in ProviderName: + try: + provider_class = ProviderFactory.get_provider_class(provider_name) + + if provider_class.MISSING_PACKAGES_ERROR is not None: + missing_packages.append( + { + "name": provider_name.value, + "error": str(provider_class.MISSING_PACKAGES_ERROR), + "env_var": provider_class.ENV_API_KEY_NAME, + } + ) + continue + + available_providers.append( + { + "name": provider_name.value, + "env_var": provider_class.ENV_API_KEY_NAME, + "api_key_set": bool(os.getenv(provider_class.ENV_API_KEY_NAME)), + } + ) + + except MissingApiKeyError as e: + missing_api_keys.append({"name": provider_name.value, "env_var": e.env_var_name, "error": str(e)}) + except ImportError as e: + missing_packages.append({"name": provider_name.value, "error": str(e), "env_var": "N/A"}) + except Exception as e: + other_errors.append({"name": provider_name.value, "error": str(e), "error_type": type(e).__name__}) + + return available_providers, missing_api_keys, missing_packages, other_errors + + +def print_results(available_providers, missing_api_keys, missing_packages, other_errors): + """Print formatted results of the provider status check.""" + + if available_providers: + print(f"āœ… Available Providers ({len(available_providers)}):") + for provider in sorted(available_providers, key=lambda x: x["name"]): + key_status = "šŸ”‘" if provider["api_key_set"] else "šŸ”“" + print(f" {key_status} {provider['name']} (env: {provider['env_var']})") + print() + + if missing_api_keys: + print(f"šŸ”‘ Missing API Keys ({len(missing_api_keys)}):") + for provider in sorted(missing_api_keys, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} - Set {provider['env_var']}") + print() + + if missing_packages: + print(f"šŸ“¦ Missing Packages ({len(missing_packages)}):") + for provider in sorted(missing_packages, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} - {provider['error']}") + print() + + + if other_errors: + print(f"āš ļø Other Errors ({len(other_errors)}):") + for provider in sorted(other_errors, key=lambda x: x["name"]): + print(f" āŒ {provider['name']} ({provider['error_type']}) - {provider['error']}") + print() + + + total_providers = len(list(ProviderName)) + available_count = len(available_providers) + missing_keys_count = len(missing_api_keys) + missing_packages_count = len(missing_packages) + other_errors_count = len(other_errors) + + print("šŸ“Š Summary:") + print(f" Total providers: {total_providers}") + print(f" Available: {available_count}") + print(f" Missing API keys: {missing_keys_count}") + print(f" Missing packages: {missing_packages_count}") + print(f" Other errors: {other_errors_count}") + + if missing_api_keys: + print("\nšŸ’” To fix missing API keys, set these environment variables:") + for provider in sorted(missing_api_keys, key=lambda x: x["name"]): + print(f" export {provider['env_var']}='your-api-key-here'") + + +def main(): + """Main function to run the provider status check.""" + if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help"]: + print(__doc__) + print("\nUsage:") + print(" python scripts/check_missing_api_keys.py") + print(" python scripts/check_missing_api_keys.py --help") + return + + try: + available, missing_keys, missing_packages, other_errors = check_provider_status() + print_results(available, missing_keys, missing_packages, other_errors) + except Exception as e: + print(f"Error running provider check: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() From 725963bd23e626c102f2502736a5ec2c7c546e48 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 8 Sep 2025 16:26:26 -0400 Subject: [PATCH 6/6] lint --- scripts/check_missing_api_keys.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) mode change 100644 => 100755 scripts/check_missing_api_keys.py diff --git a/scripts/check_missing_api_keys.py b/scripts/check_missing_api_keys.py old mode 100644 new mode 100755 index 8b350e20..4fd2b29e --- a/scripts/check_missing_api_keys.py +++ b/scripts/check_missing_api_keys.py @@ -10,12 +10,12 @@ import sys from pathlib import Path +from any_llm.exceptions import MissingApiKeyError +from any_llm.provider import ProviderFactory, ProviderName + src_path = Path(__file__).parent.parent / "src" sys.path.insert(0, str(src_path)) -from any_llm.provider import ProviderFactory, ProviderName -from any_llm.exceptions import MissingApiKeyError - def check_provider_status(): """Check the status of all providers and categorize them.""" @@ -81,14 +81,12 @@ def print_results(available_providers, missing_api_keys, missing_packages, other print(f" āŒ {provider['name']} - {provider['error']}") print() - if other_errors: print(f"āš ļø Other Errors ({len(other_errors)}):") for provider in sorted(other_errors, key=lambda x: x["name"]): print(f" āŒ {provider['name']} ({provider['error_type']}) - {provider['error']}") print() - total_providers = len(list(ProviderName)) available_count = len(available_providers) missing_keys_count = len(missing_api_keys)