Skip to content

Commit b338e9d

Browse files
committed
fix(fireworks): Properly handle reasoning content.
1 parent 3d684e9 commit b338e9d

File tree

3 files changed

+106
-9
lines changed

3 files changed

+106
-9
lines changed

src/any_llm/providers/fireworks/fireworks.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from openai import AsyncStream
55

66
from any_llm.providers.openai.base import BaseOpenAIProvider
7-
from any_llm.types.completion import Reasoning
87
from any_llm.types.responses import Response, ResponsesParams, ResponseStreamEvent
98

9+
from .utils import extract_reasoning_from_response
10+
1011

1112
class FireworksProvider(BaseOpenAIProvider):
1213
PROVIDER_NAME = "fireworks"
@@ -27,13 +28,6 @@ async def _aresponses(
2728
) -> Response | AsyncIterator[ResponseStreamEvent]:
2829
"""Call Fireworks Responses API and normalize into ChatCompletion/Chunks."""
2930
response = await super()._aresponses(params, **kwargs)
30-
3131
if isinstance(response, Response) and not isinstance(response, AsyncStream):
32-
# See https://fireworks.ai/blog/response-api for details about Fireworks Responses API support
33-
reasoning = response.output[-1].content[0].text.split("</think>")[-1] # type: ignore[union-attr,index]
34-
if reasoning:
35-
reasoning = reasoning.strip()
36-
response.output[-1].content[0].text = response.output[-1].content[0].text.split("</think>")[0] # type: ignore[union-attr,index]
37-
response.reasoning = Reasoning(content=reasoning) if reasoning else None # type: ignore[assignment]
38-
32+
return extract_reasoning_from_response(response)
3933
return response
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# mypy: disable-error-code="union-attr"
2+
from any_llm.types.completion import Reasoning
3+
from any_llm.types.responses import Response
4+
5+
6+
def extract_reasoning_from_response(response: Response) -> Response:
7+
"""Extract <think> content from Fireworks response and set reasoning field.
8+
9+
Fireworks Responses API includes reasoning content within <think></think> tags.
10+
This function extracts that content and moves it to the reasoning field.
11+
12+
Args:
13+
response: The Response object to process
14+
15+
Returns:
16+
The modified Response object with reasoning extracted
17+
"""
18+
if not response.output or not response.output[-1].content:
19+
return response
20+
21+
content_text = response.output[-1].content[0].text
22+
if "<think>" in content_text and "</think>" in content_text:
23+
reasoning = content_text.split("<think>")[1].split("</think>")[0].strip()
24+
# Skip case where reasoning is empty but tags are present
25+
if reasoning:
26+
response.reasoning = Reasoning(content=reasoning) # type: ignore[assignment]
27+
response.output[-1].content[0].text = content_text.split("</think>")[1].strip()
28+
29+
return response
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from unittest.mock import Mock
2+
3+
from any_llm.providers.fireworks.utils import extract_reasoning_from_response
4+
from any_llm.types.completion import Reasoning
5+
from any_llm.types.responses import Response
6+
7+
8+
def test_extract_reasoning_from_response_with_think_tags() -> None:
9+
"""Test that <think> content is correctly extracted into reasoning field."""
10+
# Create a mock Response with <think> tags in content
11+
mock_content = Mock()
12+
mock_content.text = "<think>This is my reasoning process</think>This is the actual response"
13+
14+
mock_output_item = Mock()
15+
mock_output_item.content = [mock_content]
16+
17+
mock_response = Mock(spec=Response)
18+
mock_response.output = [mock_output_item]
19+
mock_response.reasoning = None
20+
21+
result = extract_reasoning_from_response(mock_response)
22+
23+
assert result.reasoning is not None
24+
assert isinstance(result.reasoning, Reasoning)
25+
assert result.reasoning.content == "This is my reasoning process"
26+
assert mock_content.text == "This is the actual response"
27+
28+
29+
def test_extract_reasoning_from_response_without_think_tags() -> None:
30+
"""Test that responses without <think> tags are returned unchanged."""
31+
mock_content = Mock()
32+
mock_content.text = "This is just a regular response"
33+
34+
mock_output_item = Mock()
35+
mock_output_item.content = [mock_content]
36+
37+
mock_response = Mock(spec=Response)
38+
mock_response.output = [mock_output_item]
39+
mock_response.reasoning = None
40+
41+
result = extract_reasoning_from_response(mock_response)
42+
43+
assert result.reasoning is None
44+
assert mock_content.text == "This is just a regular response"
45+
46+
47+
def test_extract_reasoning_from_response_empty_reasoning() -> None:
48+
"""Test that empty reasoning content is handled correctly."""
49+
mock_content = Mock()
50+
mock_content.text = "<think></think>This is the actual response"
51+
52+
mock_output_item = Mock()
53+
mock_output_item.content = [mock_content]
54+
55+
mock_response = Mock(spec=Response)
56+
mock_response.output = [mock_output_item]
57+
mock_response.reasoning = None
58+
59+
result = extract_reasoning_from_response(mock_response)
60+
61+
assert result.reasoning is None
62+
assert mock_content.text == "This is the actual response"
63+
64+
65+
def test_extract_reasoning_from_response_empty_output() -> None:
66+
"""Test that responses with empty output are handled gracefully."""
67+
mock_response = Mock(spec=Response)
68+
mock_response.output = []
69+
mock_response.reasoning = None
70+
71+
result = extract_reasoning_from_response(mock_response)
72+
73+
assert result.reasoning is None
74+
assert result == mock_response

0 commit comments

Comments
 (0)