Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/any_llm/providers/cerebras/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CerebrasProvider(AnyLLM):
SUPPORTS_COMPLETION_STREAMING = True
SUPPORTS_COMPLETION = True
SUPPORTS_RESPONSES = False
SUPPORTS_COMPLETION_REASONING = False
SUPPORTS_COMPLETION_REASONING = True
SUPPORTS_COMPLETION_IMAGE = False
SUPPORTS_COMPLETION_PDF = False
SUPPORTS_EMBEDDING = False
Expand Down
8 changes: 8 additions & 0 deletions src/any_llm/providers/cerebras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Choice,
CompletionUsage,
Function,
Reasoning,
)
from any_llm.types.model import Model

Expand Down Expand Up @@ -75,6 +76,10 @@ def _create_openai_chunk_from_cerebras_chunk(chunk: ChatChunkResponse) -> ChatCo
tool_calls_list.append(tool_call_dict)
delta["tool_calls"] = tool_calls_list

reasoning_content = getattr(choice_delta, "reasoning", None)
if reasoning_content:
delta["reasoning"] = {"content": reasoning_content}

usage = getattr(chunk, "usage", None)
if usage:
chunk_dict["usage"] = {
Expand Down Expand Up @@ -117,10 +122,13 @@ def _convert_response(response_data: dict[str, Any]) -> ChatCompletion:
)
)
tool_calls = tool_calls_list

reasoning_content = message_data.get("reasoning", None)
message = ChatCompletionMessage(
role=message_data.get("role", "assistant"),
content=message_data.get("content"),
tool_calls=tool_calls,
reasoning=Reasoning(content=reasoning_content) if reasoning_content else None,
)
from typing import Literal, cast

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def provider_reasoning_model_map() -> dict[LLMProvider, str]:
LLMProvider.LLAMACPP: "N/A",
LLMProvider.LMSTUDIO: "openai/gpt-oss-20b", # You must have LM Studio running and the server enabled
LLMProvider.AZUREOPENAI: "azure/<your_deployment_name>",
LLMProvider.CEREBRAS: "gpt-oss-120b",
LLMProvider.COHERE: "command-a-reasoning-08-2025",
LLMProvider.DEEPSEEK: "deepseek-reasoner",
}
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/providers/test_cerebras_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from unittest.mock import Mock

from any_llm.exceptions import UnsupportedParameterError
from any_llm.providers.cerebras.cerebras import CerebrasProvider
from any_llm.providers.cerebras.utils import _convert_response, _create_openai_chunk_from_cerebras_chunk


@pytest.mark.asyncio
Expand All @@ -20,3 +21,113 @@ async def test_stream_with_response_format_raises() -> None:
with pytest.raises(UnsupportedParameterError):
async for _ in chunks:
pass


def test_convert_response_extracts_reasoning() -> None:
    """A ``reasoning`` string in the provider message is surfaced as Reasoning content."""
    expected_reasoning = "The user asked me to say hello, so I will respond with a greeting."
    message_payload = {
        "role": "assistant",
        "content": "Hello!",
        "reasoning": expected_reasoning,
        "tool_calls": None,
    }
    response_data = {
        "id": "test-id",
        "model": "llama-3.3-70b",
        "created": 1234567890,
        "choices": [{"index": 0, "message": message_payload, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
    }

    result = _convert_response(response_data)

    converted_message = result.choices[0].message
    assert converted_message.content == "Hello!"
    assert converted_message.reasoning is not None
    assert converted_message.reasoning.content == expected_reasoning


def test_convert_response_without_reasoning() -> None:
    """When the provider message carries no ``reasoning`` key, the converted message has none."""
    message_payload = {
        "role": "assistant",
        "content": "Hello!",
        "tool_calls": None,
    }
    response_data = {
        "id": "test-id",
        "model": "llama-3.3-70b",
        "created": 1234567890,
        "choices": [{"index": 0, "message": message_payload, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
    }

    result = _convert_response(response_data)

    converted_message = result.choices[0].message
    assert converted_message.content == "Hello!"
    assert converted_message.reasoning is None


def test_convert_chunk_extracts_reasoning() -> None:
    """A streaming delta with a ``reasoning`` attribute is converted into Reasoning content.

    Fix: the ``Mock`` import is hoisted to the module-level import block (PEP 8)
    instead of being re-imported inside each test function.
    """
    mock_chunk = Mock()
    mock_chunk.id = "test-chunk-id"
    mock_chunk.model = "llama-3.3-70b"
    mock_chunk.created = 1234567890

    # Delta carrying both visible content and a reasoning trace.
    mock_delta = Mock()
    mock_delta.content = "Hello!"
    mock_delta.role = "assistant"
    mock_delta.reasoning = "Thinking about the greeting..."
    mock_delta.tool_calls = None

    mock_choice = Mock()
    mock_choice.delta = mock_delta
    mock_choice.finish_reason = None

    mock_chunk.choices = [mock_choice]
    mock_chunk.usage = None

    result = _create_openai_chunk_from_cerebras_chunk(mock_chunk)

    assert result.choices[0].delta.content == "Hello!"
    assert result.choices[0].delta.reasoning is not None
    assert result.choices[0].delta.reasoning.content == "Thinking about the greeting..."


def test_convert_chunk_without_reasoning() -> None:
    """A streaming delta whose ``reasoning`` is None produces no reasoning field.

    Fix: the ``Mock`` import is hoisted to the module-level import block (PEP 8)
    instead of being re-imported inside each test function.
    """
    mock_chunk = Mock()
    mock_chunk.id = "test-chunk-id"
    mock_chunk.model = "llama-3.3-70b"
    mock_chunk.created = 1234567890

    # Delta with content only; reasoning explicitly absent (None).
    mock_delta = Mock()
    mock_delta.content = "Hello!"
    mock_delta.role = "assistant"
    mock_delta.reasoning = None
    mock_delta.tool_calls = None

    mock_choice = Mock()
    mock_choice.delta = mock_delta
    mock_choice.finish_reason = None

    mock_chunk.choices = [mock_choice]
    mock_chunk.usage = None

    result = _create_openai_chunk_from_cerebras_chunk(mock_chunk)

    assert result.choices[0].delta.content == "Hello!"
    # exclude_none drops unset fields, so "reasoning" must not appear at all.
    assert "reasoning" not in result.choices[0].delta.model_dump(exclude_none=True)