79 changes: 42 additions & 37 deletions src/smolagents/agents.py
@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:


@dataclass
class FinalOutput:
output: Any | None
class ActionOutput:
Copilot AI commented on Jun 18, 2025:

The ActionOutput class is missing the @dataclass decorator, so instances won't accept constructor arguments. Add @dataclass above this class definition.
Collaborator Author replied:

No it isn't, go home copilot you're drunk

output: Any
is_final_answer: bool


class PlanningPromptTemplate(TypedDict):
@@ -212,6 +213,7 @@ class MultiStepAgent(ABC):
tools (`list[Tool]`): [`Tool`]s that the agent can use.
model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
instructions (`str`, *optional*): Custom instructions for the agent, will be inserted in the system prompt.
max_steps (`int`, default `20`): Maximum number of steps the agent can take to solve the task.
add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
@@ -236,6 +238,7 @@ def __init__(
tools: list[Tool],
model: Model,
prompt_templates: PromptTemplates | None = None,
instructions: str | None = None,
Collaborator Author:

This is from #1442; it will not appear after PR #1442 is merged.
max_steps: int = 20,
add_base_tools: bool = False,
verbosity_level: LogLevel = LogLevel.INFO,
@@ -278,9 +281,9 @@ def __init__(
self.name = self._validate_name(name)
self.description = description
self.provide_run_summary = provide_run_summary
self.final_answer_checks = final_answer_checks
self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
self.return_full_result = return_full_result

self.instructions = instructions
self._setup_managed_agents(managed_agents)
self._setup_tools(tools, add_base_tools)
self._validate_tools_and_managed_agents(tools, managed_agents)
@@ -449,9 +452,9 @@ def run(
def _run_stream(
self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
final_answer = None
self.step_number = 1
while final_answer is None and self.step_number <= max_steps:
returned_final_answer = False
while not returned_final_answer and self.step_number <= max_steps:
if self.interrupt_switch:
raise AgentError("Agent interrupted.", self.logger)

@@ -462,7 +465,7 @@ def _run_stream(
planning_start_time = time.time()
planning_step = None
for element in self._generate_planning_step(
task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
task, is_first_step=(self.step_number == 1), step=self.step_number
):
yield element
planning_step = element
@@ -481,10 +484,19 @@ def _run_stream(
timing=Timing(start_time=action_step_start_time),
observations_images=images,
)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
try:
for el in self._execute_step(action_step):
yield el
final_answer = el
for output in self._step_stream(action_step):
# Yield streaming deltas
if not isinstance(output, ActionOutput):
yield output

if isinstance(output, ActionOutput) and output.is_final_answer:
if self.final_answer_checks:
self._validate_final_answer(output.output)
returned_final_answer = True
action_step.is_final_answer = True
final_answer = output.output
except AgentGenerationError as e:
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
raise e
@@ -497,23 +509,11 @@ def _run_stream(
yield action_step
self.step_number += 1

if final_answer is None and self.step_number == max_steps + 1:
if not returned_final_answer and self.step_number == max_steps + 1:
final_answer = self._handle_max_steps_reached(task, images)
yield action_step
yield FinalAnswerStep(handle_agent_output_types(final_answer))

def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
Collaborator Author:

I think porting this logic into a separate function obfuscates more than it clarifies.

self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
for el in self._step_stream(memory_step):
final_answer = el
if isinstance(el, ChatMessageStreamDelta):
yield el
elif isinstance(el, FinalOutput):
final_answer = el.output
if self.final_answer_checks:
self._validate_final_answer(final_answer)
yield final_answer
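For readers comparing the two versions: the inlined loop in _run_stream forwards streaming deltas and uses ActionOutput.is_final_answer to decide when to stop. A simplified, self-contained sketch of that pattern (ActionOutput is copied from this diff; the driver function is a stand-in, not the library's API):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool

def run_stream(step_outputs):
    # Forward anything that is not an ActionOutput (e.g. a ChatMessageStreamDelta),
    # and stop once a step reports a final answer.
    for output in step_outputs:
        if not isinstance(output, ActionOutput):
            yield output
        elif output.is_final_answer:
            yield output.output
            return

events = ["Thinking...", ActionOutput(output=4, is_final_answer=True)]
print(list(run_stream(events)))  # ['Thinking...', 4]
```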

def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
try:
@@ -617,8 +617,7 @@ def _generate_planning_step(
}
],
}
# remove last message from memory_messages because it is the current task
input_messages = [plan_update_pre] + memory_messages[:-1] + [plan_update_post]
input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
if self.stream_outputs and hasattr(self.model, "generate_stream"):
plan_message_content = ""
input_tokens, output_tokens = 0, 0
@@ -672,7 +671,7 @@ def interrupt(self):

def write_memory_to_messages(
self,
summary_mode: bool | None = False,
summary_mode: bool = False,
) -> list[Message]:
"""
Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
@@ -684,7 +683,7 @@ def write_memory_to_messages(
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
return messages

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1112,7 +1111,7 @@ def push_to_hub(
token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
create_pr (`bool`, *optional*, defaults to `False`)
Whether to create a PR with the uploaded files or directly commit.
"""
repo_url = create_repo(
@@ -1193,11 +1192,15 @@ def __init__(
def initialize_system_prompt(self) -> str:
system_prompt = populate_template(
self.prompt_templates["system_prompt"],
variables={"tools": self.tools, "managed_agents": self.managed_agents},
variables={
"tools": self.tools,
"managed_agents": self.managed_agents,
"custom_instructions": self.instructions,
},
)
return system_prompt
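populate_template renders Jinja2 templates, so the new custom_instructions variable only reaches the prompt if the system_prompt template has a matching placeholder. A minimal sketch under that assumption (the template text here is illustrative, not the library's actual prompt):

```python
from jinja2 import StrictUndefined, Template

template_text = (
    "You are an expert agent with access to: {{ tools | join(', ') }}.\n"
    "{% if custom_instructions %}{{ custom_instructions }}{% endif %}"
)
rendered = Template(template_text, undefined=StrictUndefined).render(
    tools=["web_search", "final_answer"],
    custom_instructions="Always cite your sources.",
)
print(rendered)
```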

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1285,14 +1288,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
memory_step (`ActionStep)`: Memory ActionStep to update with results.

Yields:
`FinalOutput`: The final output of tool execution.
`ActionOutput`: The final output of tool execution.
"""
model_outputs = []
tool_calls = []
observations = []

final_answer_call = None
parallel_calls = []
assert chat_message.tool_calls is not None
for tool_call in chat_message.tool_calls:
tool_name = tool_call.function.name
tool_arguments = tool_call.function.arguments
@@ -1332,19 +1336,19 @@ def process_single_tool_call(call_info):
)
return observation

# Process non-final-answer tool calls in parallel
# Process tool calls in parallel
if parallel_calls:
if len(parallel_calls) == 1:
# If there's only one call, process it directly
observations.append(process_single_tool_call(parallel_calls[0]))
yield FinalOutput(output=None)
yield ActionOutput(output=None, is_final_answer=False)
else:
# If multiple tool calls, process them in parallel
with ThreadPoolExecutor(self.max_tool_threads) as executor:
futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
for future in as_completed(futures):
observations.append(future.result())
yield FinalOutput(output=None)
yield ActionOutput(output=None, is_final_answer=False)

# Process final_answer call if present
if final_answer_call:
@@ -1374,7 +1378,7 @@ def process_single_tool_call(call_info):
level=LogLevel.INFO,
)
memory_step.action_output = final_answer
yield FinalOutput(output=final_answer)
yield ActionOutput(output=final_answer, is_final_answer=True)

# Update memory step with all results
if model_outputs:
@@ -1561,11 +1565,12 @@ def initialize_system_prompt(self) -> str:
if "*" in self.authorized_imports
else str(self.authorized_imports)
),
"custom_instructions": self.instructions,
},
)
return system_prompt

def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1695,7 +1700,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta
]
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
memory_step.action_output = output
yield FinalOutput(output=output if is_final_answer else None)
yield ActionOutput(output=output, is_final_answer=is_final_answer)

def to_dict(self) -> dict[str, Any]:
"""Convert the agent to a dictionary representation.
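A simplified, self-contained sketch of the parallel tool-call pattern above: a single call runs inline, while multiple calls fan out through a thread pool. execute_tool and the call format are stand-ins for the PR's process_single_tool_call:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def execute_tool(call):
    # Stand-in for process_single_tool_call: run one tool, return its observation.
    name, arguments = call
    return f"{name} returned with {arguments}"

calls = [("web_search", {"query": "weather"}), ("python_interpreter", {"code": "2 + 2"})]
observations = []
if len(calls) == 1:
    # One call: no need to pay the thread-pool overhead.
    observations.append(execute_tool(calls[0]))
else:
    # Several calls: run them concurrently, collecting results as they complete.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(execute_tool, call) for call in calls]
        for future in as_completed(futures):
            observations.append(future.result())
print(observations)
```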
11 changes: 8 additions & 3 deletions src/smolagents/memory.py
@@ -61,20 +61,25 @@ class ActionStep(MemoryStep):
observations_images: list["PIL.Image.Image"] | None = None
action_output: Any = None
token_usage: TokenUsage | None = None
is_final_answer: bool = False

def dict(self):
# We overwrite the method to parse the tool_calls and action_output manually
return {
"step_number": self.step_number,
"timing": self.timing.dict(),
"model_input_messages": self.model_input_messages,
"tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
"timing": self.timing.dict(),
"token_usage": asdict(self.token_usage) if self.token_usage else None,
"step": self.step_number,
"error": self.error.dict() if self.error else None,
"model_output_message": self.model_output_message.dict() if self.model_output_message else None,
"model_output": self.model_output,
"observations": self.observations,
"observations_images": [image.tobytes() for image in self.observations_images]
if self.observations_images
else None,
"action_output": make_json_serializable(self.action_output),
"token_usage": asdict(self.token_usage) if self.token_usage else None,
"is_final_answer": self.is_final_answer,
}

def to_messages(self, summary_mode: bool = False) -> list[Message]:
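A hypothetical round-trip check for the new is_final_answer field in ActionStep.dict(); constructor arguments and import paths are inferred from this diff and may not match the library exactly:

```python
from smolagents.memory import ActionStep   # module path per this diff
from smolagents.monitoring import Timing   # import path assumed

step = ActionStep(step_number=1, timing=Timing(start_time=0.0), is_final_answer=True)
payload = step.dict()
assert payload["is_final_answer"] is True
assert payload["step_number"] == 1
```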
11 changes: 11 additions & 0 deletions tests/test_agents.py
@@ -631,6 +631,17 @@ def check_always_fails(final_answer, agent_memory):
agent.run("Dummy task.")
assert "Error raised in check" in str(agent.write_memory_to_messages())

agent = CodeAgent(
model=FakeCodeModel(),
tools=[],
final_answer_checks=[lambda x, y: x == 7.2904],
verbosity_level=1000,
)
output = agent.run("Dummy task.")
assert output == 7.2904 # Check that output is correct
assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
assert "Error raised in check" not in str(agent.write_memory_to_messages())
albertvillanova (Member) commented on Jun 19, 2025:

I think this test passes before the fixes introduced in this PR.

Collaborator Author replied:

You're right, just fixed it so it no longer passes before the fix!
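For reference, a final-answer check is a callable taking (final_answer, agent_memory) and returning a truthy value or raising, as the lambda in the test above does. An illustrative (hypothetical) example:

```python
def answer_is_numeric(final_answer, agent_memory):
    # Raise (or return falsy) to reject the candidate final answer.
    try:
        float(final_answer)
    except (TypeError, ValueError):
        raise AssertionError(f"Expected a numeric final answer, got {final_answer!r}")
    return True

# agent = CodeAgent(model=..., tools=[], final_answer_checks=[answer_is_numeric])
```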


def test_generation_errors_are_raised(self):
class FakeCodeModel(Model):
def generate(self, messages, stop_sequences=None):