diff --git a/src/any_llm/providers/fireworks/fireworks.py b/src/any_llm/providers/fireworks/fireworks.py index 50f78ed0..4b1eb70e 100644 --- a/src/any_llm/providers/fireworks/fireworks.py +++ b/src/any_llm/providers/fireworks/fireworks.py @@ -4,9 +4,10 @@ from openai import AsyncStream from any_llm.providers.openai.base import BaseOpenAIProvider -from any_llm.types.completion import Reasoning from any_llm.types.responses import Response, ResponsesParams, ResponseStreamEvent +from .utils import extract_reasoning_from_response + class FireworksProvider(BaseOpenAIProvider): PROVIDER_NAME = "fireworks" @@ -27,13 +28,6 @@ async def _aresponses( ) -> Response | AsyncIterator[ResponseStreamEvent]: """Call Fireworks Responses API and normalize into ChatCompletion/Chunks.""" response = await super()._aresponses(params, **kwargs) - if isinstance(response, Response) and not isinstance(response, AsyncStream): - # See https://fireworks.ai/blog/response-api for details about Fireworks Responses API support - reasoning = response.output[-1].content[0].text.split("")[-1] # type: ignore[union-attr,index] - if reasoning: - reasoning = reasoning.strip() - response.output[-1].content[0].text = response.output[-1].content[0].text.split("")[0] # type: ignore[union-attr,index] - response.reasoning = Reasoning(content=reasoning) if reasoning else None # type: ignore[assignment] - + return extract_reasoning_from_response(response) return response diff --git a/src/any_llm/providers/fireworks/utils.py b/src/any_llm/providers/fireworks/utils.py new file mode 100644 index 00000000..16ffd088 --- /dev/null +++ b/src/any_llm/providers/fireworks/utils.py @@ -0,0 +1,29 @@ +# mypy: disable-error-code="union-attr" +from any_llm.types.completion import Reasoning +from any_llm.types.responses import Response + + +def extract_reasoning_from_response(response: Response) -> Response: + """Extract content from Fireworks response and set reasoning field. + + Fireworks Responses API includes reasoning content within tags. + This function extracts that content and moves it to the reasoning field. + + Args: + response: The Response object to process + + Returns: + The modified Response object with reasoning extracted + """ + if not response.output or not response.output[-1].content: + return response + + content_text = response.output[-1].content[0].text + if "" in content_text and "" in content_text: + reasoning = content_text.split("")[1].split("")[0].strip() + # Skip case where reasoning is empty but tags are present + if reasoning: + response.reasoning = Reasoning(content=reasoning) # type: ignore[assignment] + response.output[-1].content[0].text = content_text.split("")[1].strip() + + return response diff --git a/tests/unit/providers/test_fireworks_provider.py b/tests/unit/providers/test_fireworks_provider.py new file mode 100644 index 00000000..df37b6f5 --- /dev/null +++ b/tests/unit/providers/test_fireworks_provider.py @@ -0,0 +1,74 @@ +from unittest.mock import Mock + +from any_llm.providers.fireworks.utils import extract_reasoning_from_response +from any_llm.types.completion import Reasoning +from any_llm.types.responses import Response + + +def test_extract_reasoning_from_response_with_think_tags() -> None: + """Test that content is correctly extracted into reasoning field.""" + # Create a mock Response with tags in content + mock_content = Mock() + mock_content.text = "This is my reasoning processThis is the actual response" + + mock_output_item = Mock() + mock_output_item.content = [mock_content] + + mock_response = Mock(spec=Response) + mock_response.output = [mock_output_item] + mock_response.reasoning = None + + result = extract_reasoning_from_response(mock_response) + + assert result.reasoning is not None + assert isinstance(result.reasoning, Reasoning) + assert result.reasoning.content == "This is my reasoning process" + assert mock_content.text == "This is the actual response" + + +def test_extract_reasoning_from_response_without_think_tags() -> None: + """Test that responses without tags are returned unchanged.""" + mock_content = Mock() + mock_content.text = "This is just a regular response" + + mock_output_item = Mock() + mock_output_item.content = [mock_content] + + mock_response = Mock(spec=Response) + mock_response.output = [mock_output_item] + mock_response.reasoning = None + + result = extract_reasoning_from_response(mock_response) + + assert result.reasoning is None + assert mock_content.text == "This is just a regular response" + + +def test_extract_reasoning_from_response_empty_reasoning() -> None: + """Test that empty reasoning content is handled correctly.""" + mock_content = Mock() + mock_content.text = "This is the actual response" + + mock_output_item = Mock() + mock_output_item.content = [mock_content] + + mock_response = Mock(spec=Response) + mock_response.output = [mock_output_item] + mock_response.reasoning = None + + result = extract_reasoning_from_response(mock_response) + + assert result.reasoning is None + assert mock_content.text == "This is the actual response" + + +def test_extract_reasoning_from_response_empty_output() -> None: + """Test that responses with empty output are handled gracefully.""" + mock_response = Mock(spec=Response) + mock_response.output = [] + mock_response.reasoning = None + + result = extract_reasoning_from_response(mock_response) + + assert result.reasoning is None + assert result == mock_response