From a5d6de76554cb98e0a2e814136f2fb5c4d74629e Mon Sep 17 00:00:00 2001
From: daavoo
Date: Tue, 30 Sep 2025 14:59:52 +0200
Subject: [PATCH] fix(sambanova): Convert response_format to expected type.

---
 src/any_llm/providers/sambanova/sambanova.py  |  20 ++++
 .../unit/providers/test_sambanova_provider.py | 102 ++++++++++++++++++
 2 files changed, 122 insertions(+)
 create mode 100644 tests/unit/providers/test_sambanova_provider.py

diff --git a/src/any_llm/providers/sambanova/sambanova.py b/src/any_llm/providers/sambanova/sambanova.py
index 282dc179..324720ff 100644
--- a/src/any_llm/providers/sambanova/sambanova.py
+++ b/src/any_llm/providers/sambanova/sambanova.py
@@ -1,4 +1,9 @@
+from typing import Any
+
+from pydantic import BaseModel
+
 from any_llm.providers.openai.base import BaseOpenAIProvider
+from any_llm.types.completion import CompletionParams
 
 
 class SambanovaProvider(BaseOpenAIProvider):
@@ -8,3 +13,18 @@ class SambanovaProvider(BaseOpenAIProvider):
     PROVIDER_DOCUMENTATION_URL = "https://sambanova.ai/"
 
     SUPPORTS_COMPLETION_PDF = False
+
+    @staticmethod
+    def _convert_completion_params(params: CompletionParams, **kwargs: Any) -> dict[str, Any]:
+        """Convert CompletionParams to kwargs for OpenAI API."""
+        if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel):
+            params.response_format = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "response_schema",
+                    "schema": params.response_format.model_json_schema(),
+                },
+            }
+        converted_params = params.model_dump(exclude_none=True, exclude={"model_id", "messages"})
+        converted_params.update(kwargs)
+        return converted_params
diff --git a/tests/unit/providers/test_sambanova_provider.py b/tests/unit/providers/test_sambanova_provider.py
new file mode 100644
index 00000000..57766dd5
--- /dev/null
+++ b/tests/unit/providers/test_sambanova_provider.py
@@ -0,0 +1,102 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from pydantic import BaseModel
+
+from any_llm.providers.sambanova.sambanova import SambanovaProvider
+from any_llm.types.completion import CompletionParams
+
+
+class TestResponseSchema(BaseModel):
+    name: str
+    age: int
+
+
+@patch("any_llm.providers.openai.base.AsyncOpenAI")
+@pytest.mark.asyncio
+async def test_sambanova_converts_pydantic_response_format(mock_openai_class: MagicMock) -> None:
+    """Test that Pydantic BaseModel response_format is converted to JSON schema format."""
+    mock_client = AsyncMock()
+    mock_openai_class.return_value = mock_client
+
+    # Mock the response
+    mock_response = MagicMock()
+    mock_client.chat.completions.parse = AsyncMock(return_value=mock_response)
+
+    provider = SambanovaProvider(api_key="test-key")
+
+    messages = [{"role": "user", "content": "Hello"}]
+    params = CompletionParams(model_id="test-model", messages=messages, response_format=TestResponseSchema)
+
+    await provider._acompletion(params)
+
+    # Verify the client was called with the converted response_format
+    mock_client.chat.completions.parse.assert_called_once()
+    call_args = mock_client.chat.completions.parse.call_args
+
+    assert call_args is not None
+    kwargs = call_args.kwargs
+
+    expected_response_format = {
+        "type": "json_schema",
+        "json_schema": {
+            "name": "response_schema",
+            "schema": TestResponseSchema.model_json_schema(),
+        },
+    }
+
+    assert kwargs["response_format"] == expected_response_format
+    assert kwargs["model"] == "test-model"
+
+
+@patch("any_llm.providers.openai.base.AsyncOpenAI")
+@pytest.mark.asyncio
+async def test_sambanova_preserves_dict_response_format(mock_openai_class: MagicMock) -> None:
+    """Test that dict response_format is passed through unchanged."""
+    mock_client = AsyncMock()
+    mock_openai_class.return_value = mock_client
+
+    # Mock the response
+    mock_response = MagicMock()
+    mock_client.chat.completions.parse = AsyncMock(return_value=mock_response)
+
+    provider = SambanovaProvider(api_key="test-key")
+
+    messages = [{"role": "user", "content": "Hello"}]
+    dict_response_format = {"type": "json_object"}
+    params = CompletionParams(model_id="test-model", messages=messages, response_format=dict_response_format)
+
+    await provider._acompletion(params)
+
+    # Verify the client was called with the original dict response_format
+    mock_client.chat.completions.parse.assert_called_once()
+    call_args = mock_client.chat.completions.parse.call_args
+
+    assert call_args is not None
+    kwargs = call_args.kwargs
+
+    assert kwargs["response_format"] == dict_response_format
+    assert kwargs["model"] == "test-model"
+
+
+@patch("any_llm.providers.openai.base.AsyncOpenAI")
+@pytest.mark.asyncio
+async def test_sambanova_without_response_format(mock_openai_class: MagicMock) -> None:
+    """Test normal completion without response_format."""
+    mock_client = AsyncMock()
+    mock_openai_class.return_value = mock_client
+
+    # Mock the response
+    mock_response = MagicMock()
+    mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
+
+    provider = SambanovaProvider(api_key="test-key")
+
+    messages = [{"role": "user", "content": "Hello"}]
+    params = CompletionParams(model_id="test-model", messages=messages)
+
+    await provider._acompletion(params)
+
+    # Verify the normal create method was called
+    mock_client.chat.completions.create.assert_called_once()
+    mock_client.chat.completions.parse.assert_not_called()
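
A minimal offline sketch (not part of the patch) of what the new _convert_completion_params produces, mirroring the assertions in the unit tests above. The Person model and its fields are made up for illustration; SambanovaProvider, CompletionParams, and the expected "json_schema" shape come from the patch itself.

from pydantic import BaseModel

from any_llm.providers.sambanova.sambanova import SambanovaProvider
from any_llm.types.completion import CompletionParams


# Hypothetical schema, used only for this sketch.
class Person(BaseModel):
    name: str
    age: int


params = CompletionParams(
    model_id="test-model",
    messages=[{"role": "user", "content": "Extract: Ada, 36"}],
    response_format=Person,
)

# The static method rewrites the Pydantic class into the dict shape SambaNova expects.
kwargs = SambanovaProvider._convert_completion_params(params)
assert kwargs["response_format"]["type"] == "json_schema"
assert kwargs["response_format"]["json_schema"]["schema"] == Person.model_json_schema()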