65 changes: 32 additions & 33 deletions src/smolagents/agents.py
@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:


@dataclass
class FinalOutput:
output: Any | None
class ActionOutput:
Copilot AI (Jun 18, 2025):
The ActionOutput class is missing the @dataclass decorator, so instances won't accept constructor arguments. Add @dataclass above this class definition.

Collaborator (Author):
No it isn't, go home copilot, you're drunk.

output: Any
is_final_answer: bool
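
A quick aside on the thread above: a minimal standalone sketch (assuming only the two fields shown in this diff) confirms the author's reply; the @dataclass decorator is retained from the renamed FinalOutput, so keyword construction works.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


# Constructor arguments are accepted because the decorator is present.
result = ActionOutput(output="Paris", is_final_answer=True)
assert result.output == "Paris" and result.is_final_answer
```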


class PlanningPromptTemplate(TypedDict):
@@ -280,7 +281,7 @@ def __init__(
self.name = self._validate_name(name)
self.description = description
self.provide_run_summary = provide_run_summary
self.final_answer_checks = final_answer_checks
self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
self.return_full_result = return_full_result
self.instructions = instructions
self._setup_managed_agents(managed_agents)
@@ -451,9 +452,9 @@ def run(
def _run_stream(
self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
final_answer = None
self.step_number = 1
while final_answer is None and self.step_number <= max_steps:
returned_final_answer = False
while not returned_final_answer and self.step_number <= max_steps:
if self.interrupt_switch:
raise AgentError("Agent interrupted.", self.logger)

@@ -464,8 +465,8 @@ def _run_stream(
planning_start_time = time.time()
planning_step = None
for element in self._generate_planning_step(
task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
):
task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
): # Don't use the attribute step_number here, because there can be steps from previous runs
yield element
planning_step = element
assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep
@@ -483,10 +484,19 @@ def _run_stream(
timing=Timing(start_time=action_step_start_time),
observations_images=images,
)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
try:
for el in self._execute_step(action_step):
yield el
final_answer = el
for output in self._step_stream(action_step):
# Yield streaming deltas
if not isinstance(output, ActionOutput):
yield output

if isinstance(output, ActionOutput) and output.is_final_answer:
if self.final_answer_checks:
self._validate_final_answer(output.output)
returned_final_answer = True
action_step.is_final_answer = True
final_answer = output.output
except AgentGenerationError as e:
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
raise e
@@ -499,23 +509,11 @@ def _run_stream(
yield action_step
self.step_number += 1

if final_answer is None and self.step_number == max_steps + 1:
if not returned_final_answer and self.step_number == max_steps + 1:
final_answer = self._handle_max_steps_reached(task, images)
yield action_step
yield FinalAnswerStep(handle_agent_output_types(final_answer))

def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
Collaborator (Author):
I think porting this logic into a separate function obfuscates the logic more than it clarifies.

self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
for el in self._step_stream(memory_step):
final_answer = el
if isinstance(el, ChatMessageStreamDelta):
yield el
elif isinstance(el, FinalOutput):
final_answer = el.output
if self.final_answer_checks:
self._validate_final_answer(final_answer)
yield final_answer
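
To make the trade-off discussed in the comment above concrete, here is a toy, self-contained model of the inlined control flow (simplified names and logic, not the PR's actual code): stream deltas are forwarded to the caller, while an ActionOutput flagged as final must pass the configured checks before it ends the loop.

```python
from dataclasses import dataclass
from typing import Any, Generator


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


def run_loop(step_stream, final_answer_checks) -> Generator[Any, None, None]:
    """Toy version of the inlined loop: forward deltas, validate final answers."""
    final_answer = None
    for output in step_stream:
        if not isinstance(output, ActionOutput):
            yield output  # a streaming delta, passed straight through
        elif output.is_final_answer:
            # In the PR, a failed check raises; the error appears to be recorded
            # on the step so the agent retries (per the test added below).
            for check in final_answer_checks:
                assert check(output.output), "final answer check failed"
            final_answer = output.output
    yield final_answer


stream = iter(["delta", ActionOutput(None, False), ActionOutput(7.2904, True)])
print(list(run_loop(stream, [lambda x: x == 7.2904])))  # ['delta', 7.2904]
```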

def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
try:
@@ -674,7 +672,7 @@ def interrupt(self):

def write_memory_to_messages(
self,
summary_mode: bool | None = False,
summary_mode: bool = False,
) -> list[Message]:
"""
Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
@@ -686,7 +684,7 @@ def write_memory_to_messages(
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1203,7 +1201,7 @@ def initialize_system_prompt(self) -> str:
)
return system_prompt

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1291,14 +1289,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
memory_step (`ActionStep`): Memory ActionStep to update with results.

Yields:
`FinalOutput`: The final output of tool execution.
`ActionOutput`: The final output of tool execution.
"""
model_outputs = []
tool_calls = []
observations = []

final_answer_call = None
parallel_calls = []
assert chat_message.tool_calls is not None
for tool_call in chat_message.tool_calls:
tool_name = tool_call.function.name
tool_arguments = tool_call.function.arguments
@@ -1338,19 +1337,19 @@ def process_single_tool_call(call_info):
)
return observation

# Process non-final-answer tool calls in parallel
# Process tool calls in parallel
if parallel_calls:
if len(parallel_calls) == 1:
# If there's only one call, process it directly
observations.append(process_single_tool_call(parallel_calls[0]))
yield FinalOutput(output=None)
yield ActionOutput(output=None, is_final_answer=False)
else:
# If multiple tool calls, process them in parallel
with ThreadPoolExecutor(self.max_tool_threads) as executor:
futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
for future in as_completed(futures):
observations.append(future.result())
yield FinalOutput(output=None)
yield ActionOutput(output=None, is_final_answer=False)

# Process final_answer call if present
if final_answer_call:
@@ -1380,7 +1379,7 @@ def process_single_tool_call(call_info):
level=LogLevel.INFO,
)
memory_step.action_output = final_answer
yield FinalOutput(output=final_answer)
yield ActionOutput(output=final_answer, is_final_answer=True)

# Update memory step with all results
if model_outputs:
@@ -1572,7 +1571,7 @@ def initialize_system_prompt(self) -> str:
)
return system_prompt

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1702,7 +1701,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
]
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output
yield FinalOutput(output=output if is_final_answer else None)
yield ActionOutput(output=output, is_final_answer=is_final_answer)

def to_dict(self) -> dict[str, Any]:
"""Convert the agent to a dictionary representation.
11 changes: 8 additions & 3 deletions src/smolagents/memory.py
@@ -61,20 +61,25 @@ class ActionStep(MemoryStep):
observations_images: list["PIL.Image.Image"] | None = None
action_output: Any = None
token_usage: TokenUsage | None = None
is_final_answer: bool = False

def dict(self):
# We overwrite the method to parse the tool_calls and action_output manually
return {
"step_number": self.step_number,
"timing": self.timing.dict(),
"model_input_messages": self.model_input_messages,
"tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
"timing": self.timing.dict(),
"token_usage": asdict(self.token_usage) if self.token_usage else None,
"step": self.step_number,
"error": self.error.dict() if self.error else None,
"model_output_message": self.model_output_message.dict() if self.model_output_message else None,
"model_output": self.model_output,
"observations": self.observations,
"observations_images": [image.tobytes() for image in self.observations_images]
if self.observations_images
else None,
"action_output": make_json_serializable(self.action_output),
"token_usage": asdict(self.token_usage) if self.token_usage else None,
"is_final_answer": self.is_final_answer,
}

def to_messages(self, summary_mode: bool = False) -> list[Message]:
27 changes: 20 additions & 7 deletions tests/test_agents.py
@@ -619,6 +619,17 @@ def check_always_fails(final_answer, agent_memory):
agent.run("Dummy task.")
assert "Error raised in check" in str(agent.write_memory_to_messages())

agent = CodeAgent(
model=FakeCodeModel(),
tools=[],
final_answer_checks=[lambda x, y: x == 7.2904],
verbosity_level=1000,
)
output = agent.run("Dummy task.")
assert output == 7.2904 # Check that output is correct
assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
assert "Error raised in check" not in str(agent.write_memory_to_messages())
@albertvillanova (Member, Jun 19, 2025):
I think this test passes before the fixes introduced in this PR.

Collaborator (Author):
You're right, just fixed it to not be passing before!
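
A hedged note on the contract this test exercises: each entry in final_answer_checks is called with (final_answer, agent_memory), and a falsy return or a raised exception rejects the candidate answer, after which the agent takes another step; that is why the run above is expected to end at 7.2904 after exactly two ActionSteps. A standalone illustration (hypothetical names, not code from the PR):

```python
# Each check sees the candidate answer plus the agent memory and must pass.
def matches_expected(final_answer, agent_memory):
    return final_answer == 7.2904


checks = [matches_expected]
for candidate in (3.14, 7.2904):
    accepted = all(check(candidate, agent_memory=None) for check in checks)
    print(candidate, "accepted" if accepted else "rejected, take another step")
```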


def test_generation_errors_are_raised(self):
class FakeCodeModel(Model):
def generate(self, messages, stop_sequences=None):
@@ -640,19 +651,21 @@ def test_planning_step_with_injected_memory(self):
agent.memory.steps.append(previous_step)

# Run the agent
agent.run(task, reset=False)
agent.run(task, reset=False, max_steps=2)

# Verify that the planning step used update plan prompts
planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
assert len(planning_steps) > 0

# Check that the planning step's model input messages contain the injected memory
planning_step = planning_steps[0]
assert len(planning_step.model_input_messages) == 3 # system message + memory messages + user message
assert planning_step.model_input_messages[0]["role"] == "system"
assert task in planning_step.model_input_messages[0]["content"][0]["text"]
assert planning_step.model_input_messages[1]["role"] == "user"
assert "Previous user request" in planning_step.model_input_messages[1]["content"][0]["text"]
update_plan_step = planning_steps[0]
assert (
len(update_plan_step.model_input_messages) == 3
) # system message + memory messages (1 task message, the latest one is removed) + user message
assert update_plan_step.model_input_messages[0]["role"] == "system"
assert task in update_plan_step.model_input_messages[0]["content"][0]["text"]
assert update_plan_step.model_input_messages[1]["role"] == "user"
assert "Previous user request" in update_plan_step.model_input_messages[1]["content"][0]["text"]


class CustomFinalAnswerTool(FinalAnswerTool):
13 changes: 8 additions & 5 deletions tests/test_memory.py
@@ -1,4 +1,5 @@
import pytest
from PIL import Image

from smolagents.agents import ToolCall
from smolagents.memory import (
@@ -50,7 +51,7 @@ def test_action_step_dict():
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
model_output="Hi",
observations="This is a nice observation",
observations_images=["image1.png"],
observations_images=[Image.new("RGB", (100, 100))],
action_output="Output",
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
)
@@ -76,8 +77,8 @@ def test_action_step_dict():
assert "token_usage" in action_step_dict
assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}

assert "step" in action_step_dict
assert action_step_dict["step"] == 1
assert "step_number" in action_step_dict
assert action_step_dict["step_number"] == 1

assert "error" in action_step_dict
assert action_step_dict["error"] is None
@@ -97,6 +98,8 @@ def test_action_step_dict():
assert "observations" in action_step_dict
assert action_step_dict["observations"] == "This is a nice observation"

assert "observations_images" in action_step_dict

assert "action_output" in action_step_dict
assert action_step_dict["action_output"] == "Output"

@@ -113,7 +116,7 @@ def test_action_step_to_messages():
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
model_output="Hi",
observations="This is a nice observation",
observations_images=["image1.png"],
observations_images=[Image.new("RGB", (100, 100))],
action_output="Output",
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
)
@@ -197,7 +200,7 @@ def test_planning_step_to_messages():


def test_task_step_to_messages():
task_step = TaskStep(task="This is a task.", task_images=["task_image1.png"])
task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])
messages = task_step.to_messages(summary_mode=False)
assert len(messages) == 1
for message in messages: