diff --git a/src/any_llm/providers/cerebras/cerebras.py b/src/any_llm/providers/cerebras/cerebras.py index 84296e13..24e35e9b 100644 --- a/src/any_llm/providers/cerebras/cerebras.py +++ b/src/any_llm/providers/cerebras/cerebras.py @@ -38,7 +38,7 @@ class CerebrasProvider(AnyLLM): SUPPORTS_COMPLETION_STREAMING = True SUPPORTS_COMPLETION = True SUPPORTS_RESPONSES = False - SUPPORTS_COMPLETION_REASONING = False + SUPPORTS_COMPLETION_REASONING = True SUPPORTS_COMPLETION_IMAGE = False SUPPORTS_COMPLETION_PDF = False SUPPORTS_EMBEDDING = False diff --git a/src/any_llm/providers/cerebras/utils.py b/src/any_llm/providers/cerebras/utils.py index 32f4527d..281fc516 100644 --- a/src/any_llm/providers/cerebras/utils.py +++ b/src/any_llm/providers/cerebras/utils.py @@ -12,6 +12,7 @@ Choice, CompletionUsage, Function, + Reasoning, ) from any_llm.types.model import Model @@ -75,6 +76,10 @@ def _create_openai_chunk_from_cerebras_chunk(chunk: ChatChunkResponse) -> ChatCo tool_calls_list.append(tool_call_dict) delta["tool_calls"] = tool_calls_list + reasoning_content = getattr(choice_delta, "reasoning", None) + if reasoning_content: + delta["reasoning"] = {"content": reasoning_content} + usage = getattr(chunk, "usage", None) if usage: chunk_dict["usage"] = { @@ -117,10 +122,13 @@ def _convert_response(response_data: dict[str, Any]) -> ChatCompletion: ) ) tool_calls = tool_calls_list + + reasoning_content = message_data.get("reasoning", None) message = ChatCompletionMessage( role=message_data.get("role", "assistant"), content=message_data.get("content"), tool_calls=tool_calls, + reasoning=Reasoning(content=reasoning_content) if reasoning_content else None, ) from typing import Literal, cast diff --git a/tests/conftest.py b/tests/conftest.py index 1525216a..5417c6ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]: LLMProvider.LLAMACPP: "N/A", LLMProvider.LMSTUDIO: "openai/gpt-oss-20b", # You must have LM Studio running and the server enabled LLMProvider.AZUREOPENAI: "azure/", + LLMProvider.CEREBRAS: "gpt-oss-120b", LLMProvider.COHERE: "command-a-reasoning-08-2025", LLMProvider.DEEPSEEK: "deepseek-reasoner", } diff --git a/tests/unit/providers/test_cerebras_provider.py b/tests/unit/providers/test_cerebras_provider.py index 70b5cf29..25fb97af 100644 --- a/tests/unit/providers/test_cerebras_provider.py +++ b/tests/unit/providers/test_cerebras_provider.py @@ -2,6 +2,7 @@ from any_llm.exceptions import UnsupportedParameterError from any_llm.providers.cerebras.cerebras import CerebrasProvider +from any_llm.providers.cerebras.utils import _convert_response, _create_openai_chunk_from_cerebras_chunk @pytest.mark.asyncio @@ -20,3 +21,113 @@ async def test_stream_with_response_format_raises() -> None: with pytest.raises(UnsupportedParameterError): async for _ in chunks: pass + + +def test_convert_response_extracts_reasoning() -> None: + response_data = { + "id": "test-id", + "model": "llama-3.3-70b", + "created": 1234567890, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello!", + "reasoning": "The user asked me to say hello, so I will respond with a greeting.", + "tool_calls": None, + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + result = _convert_response(response_data) + + assert result.choices[0].message.content == "Hello!" + assert result.choices[0].message.reasoning is not None + assert ( + result.choices[0].message.reasoning.content + == "The user asked me to say hello, so I will respond with a greeting." + ) + + +def test_convert_response_without_reasoning() -> None: + response_data = { + "id": "test-id", + "model": "llama-3.3-70b", + "created": 1234567890, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello!", + "tool_calls": None, + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + result = _convert_response(response_data) + + assert result.choices[0].message.content == "Hello!" + assert result.choices[0].message.reasoning is None + + +def test_convert_chunk_extracts_reasoning() -> None: + from unittest.mock import Mock + + mock_chunk = Mock() + mock_chunk.id = "test-chunk-id" + mock_chunk.model = "llama-3.3-70b" + mock_chunk.created = 1234567890 + + mock_delta = Mock() + mock_delta.content = "Hello!" + mock_delta.role = "assistant" + mock_delta.reasoning = "Thinking about the greeting..." + mock_delta.tool_calls = None + + mock_choice = Mock() + mock_choice.delta = mock_delta + mock_choice.finish_reason = None + + mock_chunk.choices = [mock_choice] + mock_chunk.usage = None + + result = _create_openai_chunk_from_cerebras_chunk(mock_chunk) + + assert result.choices[0].delta.content == "Hello!" + assert result.choices[0].delta.reasoning is not None + assert result.choices[0].delta.reasoning.content == "Thinking about the greeting..." + + +def test_convert_chunk_without_reasoning() -> None: + from unittest.mock import Mock + + mock_chunk = Mock() + mock_chunk.id = "test-chunk-id" + mock_chunk.model = "llama-3.3-70b" + mock_chunk.created = 1234567890 + + mock_delta = Mock() + mock_delta.content = "Hello!" + mock_delta.role = "assistant" + mock_delta.reasoning = None + mock_delta.tool_calls = None + + mock_choice = Mock() + mock_choice.delta = mock_delta + mock_choice.finish_reason = None + + mock_chunk.choices = [mock_choice] + mock_chunk.usage = None + + result = _create_openai_chunk_from_cerebras_chunk(mock_chunk) + + assert result.choices[0].delta.content == "Hello!" + assert "reasoning" not in result.choices[0].delta.model_dump(exclude_none=True)