Skip to content

Commit 773dfd8

Browse files
dirkbrndmanuhortetysolanky
authored
fix: Model Response Unbound issue (#5303)
## Summary This fixes issues with assigning streaming values to the assistant message object. Fixes #5298 ## Type of change - [x] Bug fix - [ ] New feature - [ ] Breaking change - [ ] Improvement - [ ] Model update - [ ] Other: --- ## Checklist - [ ] Code complies with style guidelines - [ ] Ran format/validation scripts (`./scripts/format.sh` and `./scripts/validate.sh`) - [ ] Self-review completed - [ ] Documentation updated (comments, docstrings) - [ ] Examples and guides: Relevant cookbook examples have been included or updated (if applicable) - [ ] Tested in clean environment - [ ] Tests added/updated (if applicable) --- ## Additional Notes Add any important context (deployment instructions, screenshots, security considerations, etc.) --------- Co-authored-by: manu <[email protected]> Co-authored-by: ysolanky <[email protected]>
1 parent 2dec373 commit 773dfd8

File tree

7 files changed

+139
-115
lines changed

7 files changed

+139
-115
lines changed

cookbook/models/openai/chat/audio_output_stream.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Iterator
44

55
from agno.agent import Agent, RunOutputEvent # noqa
6+
from agno.db.in_memory import InMemoryDb
67
from agno.models.openai import OpenAIChat
78

89
# Audio Configuration
@@ -20,6 +21,7 @@
2021
"format": "pcm16",
2122
}, # Only pcm16 is supported with streaming
2223
),
24+
db=InMemoryDb(),
2325
)
2426
output_stream: Iterator[RunOutputEvent] = agent.run(
2527
"Tell me a 10 second story", stream=True
@@ -48,3 +50,6 @@
4850
print(f"Error decoding audio: {e}")
4951
print()
5052
print(f"Saved audio to {filename}")
53+
54+
print("Metrics:")
55+
print(agent.get_last_run_output().metrics)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Iterator # noqa
2+
from agno.agent import Agent, RunOutputEvent # noqa
3+
from agno.models.openai import OpenAIChat
4+
from agno.db.in_memory import InMemoryDb
5+
6+
agent = Agent(model=OpenAIChat(id="gpt-4o"), db=InMemoryDb(), markdown=True)
7+
8+
# Get the response in a variable
9+
# run_response: Iterator[RunOutputEvent] = agent.run("Share a 2 sentence horror story", stream=True)
10+
# for chunk in run_response:
11+
# print(chunk.content)
12+
13+
# Print the response in the terminal
14+
agent.print_response("Share a 2 sentence horror story", stream=True)
15+
16+
run_output = agent.get_last_run_output()
17+
print("Metrics:")
18+
print(run_output.metrics)
19+
20+
print("Message Metrics:")
21+
for message in run_output.messages:
22+
if message.role == "assistant":
23+
print(message.role)
24+
print(message.metrics)

libs/agno/agno/models/base.py

Lines changed: 52 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class MessageData:
5353
response_video: Optional[Video] = None
5454
response_file: Optional[File] = None
5555

56+
response_metrics: Optional[Metrics] = None
57+
5658
# Data from the provider that we might need on subsequent messages
5759
response_provider_data: Optional[Dict[str, Any]] = None
5860

@@ -759,7 +761,6 @@ def _populate_assistant_message(
759761
Returns:
760762
Message: The populated assistant message
761763
"""
762-
# Add role to assistant message
763764
if provider_response.role is not None:
764765
assistant_message.role = provider_response.role
765766

@@ -837,14 +838,14 @@ def process_response_stream(
837838
tool_choice=tool_choice or self._tool_choice,
838839
run_response=run_response,
839840
):
840-
yield from self._populate_stream_data_and_assistant_message(
841+
for model_response_delta in self._populate_stream_data(
841842
stream_data=stream_data,
842-
assistant_message=assistant_message,
843843
model_response_delta=response_delta,
844-
)
844+
):
845+
yield model_response_delta
845846

846-
# Add final metrics to assistant message
847-
self._populate_assistant_message(assistant_message=assistant_message, provider_response=response_delta)
847+
# Populate assistant message from stream data after the stream ends
848+
self._populate_assistant_message_from_stream_data(assistant_message=assistant_message, stream_data=stream_data)
848849

849850
def response_stream(
850851
self,
@@ -908,22 +909,6 @@ def response_stream(
908909
streaming_responses.append(response)
909910
yield response
910911

911-
# Populate assistant message from stream data
912-
if stream_data.response_content:
913-
assistant_message.content = stream_data.response_content
914-
if stream_data.response_reasoning_content:
915-
assistant_message.reasoning_content = stream_data.response_reasoning_content
916-
if stream_data.response_redacted_reasoning_content:
917-
assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
918-
if stream_data.response_provider_data:
919-
assistant_message.provider_data = stream_data.response_provider_data
920-
if stream_data.response_citations:
921-
assistant_message.citations = stream_data.response_citations
922-
if stream_data.response_audio:
923-
assistant_message.audio_output = stream_data.response_audio
924-
if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
925-
assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
926-
927912
else:
928913
self._process_model_response(
929914
messages=messages,
@@ -1035,15 +1020,14 @@ async def aprocess_response_stream(
10351020
tool_choice=tool_choice or self._tool_choice,
10361021
run_response=run_response,
10371022
): # type: ignore
1038-
for model_response in self._populate_stream_data_and_assistant_message(
1023+
for model_response_delta in self._populate_stream_data(
10391024
stream_data=stream_data,
1040-
assistant_message=assistant_message,
10411025
model_response_delta=response_delta,
10421026
):
1043-
yield model_response
1027+
yield model_response_delta
10441028

1045-
# Populate the assistant message
1046-
self._populate_assistant_message(assistant_message=assistant_message, provider_response=model_response)
1029+
# Populate assistant message from stream data after the stream ends
1030+
self._populate_assistant_message_from_stream_data(assistant_message=assistant_message, stream_data=stream_data)
10471031

10481032
async def aresponse_stream(
10491033
self,
@@ -1107,20 +1091,6 @@ async def aresponse_stream(
11071091
streaming_responses.append(model_response)
11081092
yield model_response
11091093

1110-
# Populate assistant message from stream data
1111-
if stream_data.response_content:
1112-
assistant_message.content = stream_data.response_content
1113-
if stream_data.response_reasoning_content:
1114-
assistant_message.reasoning_content = stream_data.response_reasoning_content
1115-
if stream_data.response_redacted_reasoning_content:
1116-
assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
1117-
if stream_data.response_provider_data:
1118-
assistant_message.provider_data = stream_data.response_provider_data
1119-
if stream_data.response_audio:
1120-
assistant_message.audio_output = stream_data.response_audio
1121-
if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
1122-
assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
1123-
11241094
else:
11251095
await self._aprocess_model_response(
11261096
messages=messages,
@@ -1212,15 +1182,51 @@ async def aresponse_stream(
12121182
if self.cache_response and cache_key and streaming_responses:
12131183
self._save_streaming_responses_to_cache(cache_key, streaming_responses)
12141184

1215-
def _populate_stream_data_and_assistant_message(
1216-
self, stream_data: MessageData, assistant_message: Message, model_response_delta: ModelResponse
1185+
def _populate_assistant_message_from_stream_data(
1186+
self, assistant_message: Message, stream_data: MessageData
1187+
) -> None:
1188+
"""
1189+
Populate an assistant message with the stream data.
1190+
"""
1191+
if stream_data.response_role is not None:
1192+
assistant_message.role = stream_data.response_role
1193+
if stream_data.response_metrics is not None:
1194+
assistant_message.metrics = stream_data.response_metrics
1195+
if stream_data.response_content:
1196+
assistant_message.content = stream_data.response_content
1197+
if stream_data.response_reasoning_content:
1198+
assistant_message.reasoning_content = stream_data.response_reasoning_content
1199+
if stream_data.response_redacted_reasoning_content:
1200+
assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
1201+
if stream_data.response_provider_data:
1202+
assistant_message.provider_data = stream_data.response_provider_data
1203+
if stream_data.response_citations:
1204+
assistant_message.citations = stream_data.response_citations
1205+
if stream_data.response_audio:
1206+
assistant_message.audio_output = stream_data.response_audio
1207+
if stream_data.response_image:
1208+
assistant_message.image_output = stream_data.response_image
1209+
if stream_data.response_video:
1210+
assistant_message.video_output = stream_data.response_video
1211+
if stream_data.response_file:
1212+
assistant_message.file_output = stream_data.response_file
1213+
if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
1214+
assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
1215+
1216+
def _populate_stream_data(
1217+
self, stream_data: MessageData, model_response_delta: ModelResponse
12171218
) -> Iterator[ModelResponse]:
12181219
"""Update the stream data and assistant message with the model response."""
1219-
# Add role to assistant message
1220-
if model_response_delta.role is not None:
1221-
assistant_message.role = model_response_delta.role
12221220

12231221
should_yield = False
1222+
if model_response_delta.role is not None:
1223+
stream_data.response_role = model_response_delta.role # type: ignore
1224+
1225+
if model_response_delta.response_usage is not None:
1226+
if stream_data.response_metrics is None:
1227+
stream_data.response_metrics = Metrics()
1228+
stream_data.response_metrics += model_response_delta.response_usage
1229+
12241230
# Update stream_data content
12251231
if model_response_delta.content is not None:
12261232
stream_data.response_content += model_response_delta.content

libs/agno/agno/models/openai/responses.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from agno.exceptions import ModelProviderError
1010
from agno.media import File
11-
from agno.models.base import MessageData, Model
11+
from agno.models.base import Model
1212
from agno.models.message import Citations, Message, UrlCitation
1313
from agno.models.metrics import Metrics
1414
from agno.models.response import ModelResponse
@@ -810,63 +810,6 @@ def format_function_call_results(
810810
_fc_message.tool_call_id = tool_call_ids[_fc_message_index]
811811
messages.append(_fc_message)
812812

813-
def process_response_stream(
814-
self,
815-
messages: List[Message],
816-
assistant_message: Message,
817-
stream_data: MessageData,
818-
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
819-
tools: Optional[List[Dict[str, Any]]] = None,
820-
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
821-
run_response: Optional[RunOutput] = None,
822-
) -> Iterator[ModelResponse]:
823-
"""Process the synchronous response stream."""
824-
for model_response_delta in self.invoke_stream(
825-
messages=messages,
826-
assistant_message=assistant_message,
827-
tools=tools,
828-
response_format=response_format,
829-
tool_choice=tool_choice,
830-
run_response=run_response,
831-
):
832-
yield from self._populate_stream_data_and_assistant_message(
833-
stream_data=stream_data,
834-
assistant_message=assistant_message,
835-
model_response_delta=model_response_delta,
836-
)
837-
838-
# Add final metrics to assistant message
839-
self._populate_assistant_message(assistant_message=assistant_message, provider_response=model_response_delta)
840-
841-
async def aprocess_response_stream(
842-
self,
843-
messages: List[Message],
844-
assistant_message: Message,
845-
stream_data: MessageData,
846-
response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
847-
tools: Optional[List[Dict[str, Any]]] = None,
848-
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
849-
run_response: Optional[RunOutput] = None,
850-
) -> AsyncIterator[ModelResponse]:
851-
"""Process the asynchronous response stream."""
852-
async for model_response_delta in self.ainvoke_stream(
853-
messages=messages,
854-
assistant_message=assistant_message,
855-
tools=tools,
856-
response_format=response_format,
857-
tool_choice=tool_choice,
858-
run_response=run_response,
859-
):
860-
for model_response in self._populate_stream_data_and_assistant_message(
861-
stream_data=stream_data,
862-
assistant_message=assistant_message,
863-
model_response_delta=model_response_delta,
864-
):
865-
yield model_response
866-
867-
# Add final metrics to assistant message
868-
self._populate_assistant_message(assistant_message=assistant_message, provider_response=model_response_delta)
869-
870813
def _parse_provider_response(self, response: Response, **kwargs) -> ModelResponse:
871814
"""
872815
Parse the OpenAI response into a ModelResponse.

libs/agno/agno/run/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def to_dict(self) -> Dict[str, Any]:
110110

111111
if hasattr(self, "tools") and self.tools is not None:
112112
from agno.models.response import ToolExecution
113-
113+
114114
_dict["tools"] = []
115115
for tool in self.tools:
116116
if isinstance(tool, ToolExecution):
@@ -120,7 +120,7 @@ def to_dict(self) -> Dict[str, Any]:
120120

121121
if hasattr(self, "tool") and self.tool is not None:
122122
from agno.models.response import ToolExecution
123-
123+
124124
if isinstance(self.tool, ToolExecution):
125125
_dict["tool"] = self.tool.to_dict()
126126
else:
@@ -155,7 +155,7 @@ def from_dict(cls, data: Dict[str, Any]):
155155
tool = data.pop("tool", None)
156156
if tool:
157157
from agno.models.response import ToolExecution
158-
158+
159159
data["tool"] = ToolExecution.from_dict(tool)
160160

161161
images = data.pop("images", None)

libs/agno/tests/integration/models/openai/chat/test_basic.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,25 @@ def test_basic(openai_model):
4040
_assert_metrics(response)
4141

4242

43-
def test_basic_stream(openai_model):
44-
agent = Agent(model=openai_model, markdown=True, telemetry=False)
43+
def test_basic_stream(openai_model, shared_db):
44+
agent = Agent(model=openai_model, db=shared_db, markdown=True, telemetry=False)
4545

4646
run_stream = agent.run("Say 'hi'", stream=True)
4747
for chunk in run_stream:
4848
assert chunk.content is not None
4949

50+
run_output = agent.get_last_run_output()
51+
52+
assert run_output.content is not None
53+
assert run_output.messages is not None
54+
assert len(run_output.messages) == 3
55+
assert [m.role for m in run_output.messages] == ["system", "user", "assistant"]
56+
assert run_output.messages[2].content is not None
57+
assert run_output.messages[2].role == "assistant"
58+
assert run_output.messages[2].metrics.input_tokens is not None
59+
assert run_output.messages[2].metrics.output_tokens is not None
60+
assert run_output.messages[2].metrics.total_tokens is not None
61+
5062

5163
@pytest.mark.asyncio
5264
async def test_async_basic(openai_model):
@@ -62,12 +74,24 @@ async def test_async_basic(openai_model):
6274

6375

6476
@pytest.mark.asyncio
65-
async def test_async_basic_stream(openai_model):
66-
agent = Agent(model=openai_model, markdown=True, telemetry=False)
77+
async def test_async_basic_stream(openai_model, shared_db):
78+
agent = Agent(model=openai_model, db=shared_db, markdown=True, telemetry=False)
6779

6880
async for response in agent.arun("Share a 2 sentence horror story", stream=True):
6981
assert response.content is not None
7082

83+
run_output = agent.get_last_run_output()
84+
85+
assert run_output.content is not None
86+
assert run_output.messages is not None
87+
assert len(run_output.messages) == 3
88+
assert [m.role for m in run_output.messages] == ["system", "user", "assistant"]
89+
assert run_output.messages[2].content is not None
90+
assert run_output.messages[2].role == "assistant"
91+
assert run_output.messages[2].metrics.input_tokens is not None
92+
assert run_output.messages[2].metrics.output_tokens is not None
93+
assert run_output.messages[2].metrics.total_tokens is not None
94+
7195

7296
def test_exception_handling():
7397
agent = Agent(model=OpenAIChat(id="gpt-100"), markdown=True, telemetry=False)

0 commit comments

Comments
 (0)