diff --git a/examples/agent_from_any_llm.py b/examples/agent_from_any_llm.py index d5e33f0a1..593c61b2d 100644 --- a/examples/agent_from_any_llm.py +++ b/examples/agent_from_any_llm.py @@ -1,19 +1,23 @@ -from smolagents import InferenceClientModel, LiteLLMModel, OpenAIServerModel, TransformersModel, tool -from smolagents.agents import CodeAgent, ToolCallingAgent +from smolagents import ( + CodeAgent, + InferenceClientModel, + LiteLLMModel, + OpenAIServerModel, + ToolCallingAgent, + TransformersModel, + tool, +) # Choose which inference type to use! -available_inferences = ["hf_api", "hf_api_provider", "transformers", "ollama", "litellm", "openai"] -chosen_inference = "hf_api_provider" +available_inferences = ["inference_client", "transformers", "ollama", "litellm", "openai"] +chosen_inference = "inference_client" print(f"Chose model: '{chosen_inference}'") -if chosen_inference == "hf_api": - model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct") - -elif chosen_inference == "hf_api_provider": - model = InferenceClientModel(provider="together") +if chosen_inference == "inference_client": + model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="nebius") elif chosen_inference == "transformers": model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000) diff --git a/examples/gradio_ui.py b/examples/gradio_ui.py index 87f532689..fb69481ec 100644 --- a/examples/gradio_ui.py +++ b/examples/gradio_ui.py @@ -9,7 +9,8 @@ name="example_agent", description="This is an example agent.", step_callbacks=[], - stream_outputs=False, + stream_outputs=True, + return_full_result=True, ) GradioUI(agent, file_upload_folder="./data").launch() diff --git a/examples/inspect_multiagent_run.py b/examples/inspect_multiagent_run.py index c68dccb75..75bc0a07c 100644 --- a/examples/inspect_multiagent_run.py +++ b/examples/inspect_multiagent_run.py @@ -16,18 +16,24 @@ # Then we run the agentic part! -model = InferenceClientModel() +model = InferenceClientModel(provider="nebius") search_agent = ToolCallingAgent( tools=[WebSearchTool(), VisitWebpageTool()], model=model, name="search_agent", description="This is an agent that can do web search.", + return_full_result=True, ) manager_agent = CodeAgent( tools=[], model=model, managed_agents=[search_agent], + return_full_result=True, ) -manager_agent.run("If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?") +run_result = manager_agent.run( + "If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?" 
+) +print("Here is the token usage for the manager agent", run_result.token_usage) +print("Here are the timing informations for the manager agent:", run_result.timing) diff --git a/examples/multi_llm_agent.py b/examples/multi_llm_agent.py index 6f44ff8b4..5c002e1ea 100644 --- a/examples/multi_llm_agent.py +++ b/examples/multi_llm_agent.py @@ -39,6 +39,8 @@ model_list=llm_loadbalancer_model_list, client_kwargs={"routing_strategy": "simple-shuffle"}, ) -agent = CodeAgent(tools=[WebSearchTool()], model=model, stream_outputs=True) +agent = CodeAgent(tools=[WebSearchTool()], model=model, stream_outputs=True, return_full_results=True) -agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?") +full_result = agent.run("How many seconds would it take for a leopard at full speed to run through Pont des Arts?") + +print(full_result) diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py index 6b8c8d35b..9006db22d 100644 --- a/src/smolagents/agents.py +++ b/src/smolagents/agents.py @@ -24,9 +24,10 @@ import time from abc import ABC, abstractmethod from collections.abc import Callable, Generator +from dataclasses import dataclass from logging import getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, Any, Literal, TypedDict import jinja2 import yaml @@ -54,6 +55,8 @@ PlanningStep, SystemPromptStep, TaskStep, + Timing, + TokenUsage, ToolCall, ) from .models import ChatMessage, ChatMessageStreamDelta, MessageRole, Model, parse_json_if_needed @@ -96,6 +99,11 @@ def populate_template(template: str, variables: dict[str, Any]) -> str: raise Exception(f"Error during jinja template rendering: {type(e).__name__}: {e}") +@dataclass +class FinalOutput: + output: Any | None + + class PlanningPromptTemplate(TypedDict): """ Prompt templates for the planning step. @@ -166,6 +174,25 @@ class PromptTemplates(TypedDict): ) +@dataclass +class RunResult: + """Holds extended information about an agent run. + + Attributes: + output (Any | None): The final output of the agent run, if available. + state (Literal["success", "max_steps_error"]): The final state of the agent after the run. + messages (list[dict]): The agent's memory, as a list of messages. + token_usage (TokenUsage | None): Count of tokens used during the run. + timing (Timing): Timing details of the agent run: start time, end time, duration. + """ + + output: Any | None + state: Literal["success", "max_steps_error"] + messages: list[dict] + token_usage: TokenUsage | None + timing: Timing + + class MultiStepAgent(ABC): """ Agent class that solves the given task step by step, using the ReAct framework: @@ -204,6 +231,7 @@ def __init__( description: str | None = None, provide_run_summary: bool = False, final_answer_checks: list[Callable] | None = None, + return_full_result: bool = False, logger: AgentLogger | None = None, ): self.agent_name = self.__class__.__name__ @@ -230,6 +258,7 @@ def __init__( self.description = description self.provide_run_summary = provide_run_summary self.final_answer_checks = final_answer_checks + self.return_full_result = return_full_result self._setup_managed_agents(managed_agents) self._setup_tools(tools, add_base_tools) @@ -347,28 +376,81 @@ def run( if stream: # The steps are returned as they are executed through a generator to iterate on. return self._run_stream(task=self.task, max_steps=max_steps, images=images) + run_start_time = time.time() # Outputs are returned only at the end. 
We only look at the last step. - return list(self._run_stream(task=self.task, max_steps=max_steps, images=images))[-1].final_answer + + steps = list(self._run_stream(task=self.task, max_steps=max_steps, images=images)) + assert isinstance(steps[-1], FinalAnswerStep) + output = steps[-1].output + + if self.return_full_result: + total_input_tokens = 0 + total_output_tokens = 0 + correct_token_usage = True + for step in self.memory.steps: + if isinstance(step, (ActionStep, PlanningStep)): + if step.token_usage is None: + correct_token_usage = False + break + else: + total_input_tokens += step.token_usage.input_tokens + total_output_tokens += step.token_usage.output_tokens + if correct_token_usage: + token_usage = TokenUsage(input_tokens=total_input_tokens, output_tokens=total_output_tokens) + else: + token_usage = None + + if self.memory.steps and isinstance(getattr(self.memory.steps[-1], "error", None), AgentMaxStepsError): + state = "max_steps_error" + else: + state = "success" + + messages = self.memory.get_full_steps() + + return RunResult( + output=output, + token_usage=token_usage, + messages=messages, + timing=Timing(start_time=run_start_time, end_time=time.time()), + state=state, + ) + + return output def _run_stream( self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None - ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep]: + ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]: final_answer = None self.step_number = 1 while final_answer is None and self.step_number <= max_steps: if self.interrupt_switch: raise AgentError("Agent interrupted.", self.logger) - step_start_time = time.time() + + # Run a planning step if scheduled if self.planning_interval is not None and ( self.step_number == 1 or (self.step_number - 1) % self.planning_interval == 0 ): + planning_start_time = time.time() + planning_step = None for element in self._generate_planning_step( task, is_first_step=(self.step_number == 1), step=self.step_number ): yield element - self.memory.steps.append(element) + planning_step = element + assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep + self.memory.steps.append(planning_step) + planning_end_time = time.time() + planning_step.timing = Timing( + start_time=planning_start_time, + end_time=planning_end_time, + ) + + # Start action step! + action_step_start_time = time.time() action_step = ActionStep( - step_number=self.step_number, start_time=step_start_time, observations_images=images + step_number=self.step_number, + timing=Timing(start_time=action_step_start_time), + observations_images=images, ) try: for el in self._execute_step(action_step): @@ -381,25 +463,27 @@ def _run_stream( # Other AgentError types are caused by the Model, so we should log them and iterate. 
action_step.error = e finally: - self._finalize_step(action_step, step_start_time) + self._finalize_step(action_step) self.memory.steps.append(action_step) yield action_step self.step_number += 1 if final_answer is None and self.step_number == max_steps + 1: - final_answer = self._handle_max_steps_reached(task, images, step_start_time) + final_answer = self._handle_max_steps_reached(task, images) yield action_step yield FinalAnswerStep(handle_agent_output_types(final_answer)) - def _execute_step(self, memory_step: ActionStep) -> Generator[Any]: + def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]: self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO) - final_answer = None for el in self._step_stream(memory_step): final_answer = el - yield el - if final_answer is not None and self.final_answer_checks: - self._validate_final_answer(final_answer) - yield final_answer + if isinstance(el, ChatMessageStreamDelta): + yield el + elif isinstance(el, FinalOutput): + final_answer = el.output + if self.final_answer_checks: + self._validate_final_answer(final_answer) + yield final_answer def _validate_final_answer(self, final_answer: Any): for check_function in self.final_answer_checks: @@ -408,33 +492,32 @@ def _validate_final_answer(self, final_answer: Any): except Exception as e: raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger) - def _finalize_step(self, memory_step: ActionStep, step_start_time: float): - memory_step.end_time = time.time() - memory_step.duration = memory_step.end_time - step_start_time + def _finalize_step(self, memory_step: ActionStep): + memory_step.timing.end_time = time.time() for callback in self.step_callbacks: # For compatibility with old callbacks that don't take the agent as an argument callback(memory_step) if len(inspect.signature(callback).parameters) == 1 else callback( memory_step, agent=self ) - def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"], step_start_time: float) -> Any: + def _handle_max_steps_reached(self, task: str, images: list["PIL.Image.Image"]) -> Any: + action_step_start_time = time.time() final_answer = self.provide_final_answer(task, images) final_memory_step = ActionStep( - step_number=self.step_number, error=AgentMaxStepsError("Reached max steps.", self.logger) + step_number=self.step_number, + error=AgentMaxStepsError("Reached max steps.", self.logger), + timing=Timing(start_time=action_step_start_time, end_time=time.time()), + token_usage=final_answer.token_usage, ) - final_memory_step.action_output = final_answer - final_memory_step.end_time = time.time() - final_memory_step.duration = final_memory_step.end_time - step_start_time + final_memory_step.action_output = final_answer.content + self._finalize_step(final_memory_step) self.memory.steps.append(final_memory_step) - for callback in self.step_callbacks: - callback(final_memory_step) if len(inspect.signature(callback).parameters) == 1 else callback( - final_memory_step, agent=self - ) - return final_answer + return final_answer.content def _generate_planning_step( self, task, is_first_step: bool, step: int - ) -> Generator[ChatMessageStreamDelta, PlanningStep]: + ) -> Generator[ChatMessageStreamDelta | PlanningStep]: + start_time = time.time() if is_first_step: input_messages = [ { @@ -453,14 +536,23 @@ def _generate_planning_step( if self.stream_outputs and hasattr(self.model, "generate_stream"): plan_message_content = "" output_stream = 
self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"])  # type: ignore
+                input_tokens, output_tokens = 0, 0
                 with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                     for event in output_stream:
                         if event.content is not None:
                             plan_message_content += event.content
                             live.update(Markdown(plan_message_content))
+                            if event.token_usage:
+                                output_tokens += event.token_usage.output_tokens
+                                input_tokens = event.token_usage.input_tokens
                         yield event
             else:
-                plan_message_content = self.model.generate(input_messages, stop_sequences=["<end_plan>"]).content
+                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
+                plan_message_content = plan_message.content
+                input_tokens, output_tokens = (
+                    plan_message.token_usage.input_tokens,
+                    plan_message.token_usage.output_tokens,
+                )
             plan = textwrap.dedent(
                 f"""Here are the facts I know and the plan of action that I will follow to solve the task:\n```\n{plan_message_content}\n```"""
             )
@@ -499,11 +591,26 @@
             input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
             if self.stream_outputs and hasattr(self.model, "generate_stream"):
                 plan_message_content = ""
-                for completion_delta in self.model.generate_stream(input_messages, stop_sequences=["<end_plan>"]):  # type: ignore
-                    plan_message_content += completion_delta.content
-                    yield completion_delta
+                input_tokens, output_tokens = 0, 0
+                with Live("", console=self.logger.console, vertical_overflow="visible") as live:
+                    for event in self.model.generate_stream(
+                        input_messages,
+                        stop_sequences=["<end_plan>"],
+                    ):  # type: ignore
+                        if event.content is not None:
+                            plan_message_content += event.content
+                            live.update(Markdown(plan_message_content))
+                            if event.token_usage:
+                                output_tokens += event.token_usage.output_tokens
+                                input_tokens = event.token_usage.input_tokens
+                        yield event
             else:
-                plan_message_content = self.model.generate(input_messages, stop_sequences=["<end_plan>"]).content
+                plan_message = self.model.generate(input_messages, stop_sequences=["<end_plan>"])
+                plan_message_content = plan_message.content
+                input_tokens, output_tokens = (
+                    plan_message.token_usage.input_tokens,
+                    plan_message.token_usage.output_tokens,
+                )
             plan = textwrap.dedent(
                 f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere are the facts I know and my new/updated plan of action to solve the task:\n```\n{plan_message_content}\n```"""
             )
@@ -513,6 +620,8 @@
             model_input_messages=input_messages,
             plan=plan,
             model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content=plan_message_content),
+            token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
+            timing=Timing(start_time=start_time, end_time=time.time()),
         )
 
     @property
@@ -545,7 +654,7 @@ def write_memory_to_messages(
             messages.extend(memory_step.to_messages(summary_mode=summary_mode))
         return messages
 
-    def _step_stream(self, memory_step: ActionStep) -> Generator[Any]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields either None if the step is not final, or the final answer.
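Since ActionStep and PlanningStep now both carry token_usage and timing, the aggregation that run() performs when building a RunResult can also be reproduced by hand over the agent's memory. A minimal sketch, assuming `agent` has already finished a run (variable names are illustrative, not part of the patch):

from smolagents.memory import ActionStep, PlanningStep

total_input, total_output = 0, 0
for step in agent.memory.steps:
    # Steps expose TokenUsage and Timing directly; token_usage may be None
    # if the model did not report usage for that step.
    if isinstance(step, (ActionStep, PlanningStep)) and step.token_usage is not None:
        total_input += step.token_usage.input_tokens
        total_output += step.token_usage.output_tokens
        print(f"Step took {step.timing.duration:.2f}s")
print("Total tokens:", total_input + total_output)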
@@ -580,7 +689,7 @@ def extract_action(self, model_output: str, split_token: str) -> tuple[str, str] ) return rationale.strip(), action.strip() - def provide_final_answer(self, task: str, images: list["PIL.Image.Image"] | None = None) -> str: + def provide_final_answer(self, task: str, images: list["PIL.Image.Image"] | None = None) -> ChatMessage: """ Provide the final answer to the task, based on the logs of the agent's interactions. @@ -619,8 +728,8 @@ def provide_final_answer(self, task: str, images: list["PIL.Image.Image"] | None } ] try: - chat_message: ChatMessage = self.model(messages) - return chat_message.content + chat_message: ChatMessage = self.model.generate(messages) + return chat_message except Exception as e: return f"Error in generating final LLM output:\n{e}" @@ -645,7 +754,11 @@ def __call__(self, task: str, **kwargs): self.prompt_templates["managed_agent"]["task"], variables=dict(name=self.name, task=task), ) - report = self.run(full_task, **kwargs) + result = self.run(full_task, **kwargs) + if isinstance(result, RunResult): + report = result.output + else: + report = result answer = populate_template( self.prompt_templates["managed_agent"]["report"], variables=dict(name=self.name, final_answer=report) ) @@ -1039,7 +1152,7 @@ def initialize_system_prompt(self) -> str: ) return system_prompt - def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: + def _step_stream(self, memory_step: ActionStep) -> Generator[FinalOutput]: """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. Yields either None if the step is not final, or the final answer. @@ -1052,7 +1165,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: memory_step.model_input_messages = input_messages try: - chat_message: ChatMessage = self.model( + chat_message: ChatMessage = self.model.generate( input_messages, stop_sequences=["Observation:", "Calling tools:"], tools_to_call_from=list(self.tools.values()), @@ -1083,6 +1196,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: tool_arguments = tool_call.function.arguments memory_step.model_output = str(f"Called Tool: '{tool_name}' with arguments: {tool_arguments}") memory_step.tool_calls = [ToolCall(name=tool_name, arguments=tool_arguments, id=tool_call_id)] + memory_step.token_usage = chat_message.token_usage # Execute self.logger.log( @@ -1113,7 +1227,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: ) memory_step.action_output = final_answer - yield final_answer + yield FinalOutput(output=final_answer) else: if tool_arguments is None: tool_arguments = {} @@ -1135,7 +1249,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: level=LogLevel.INFO, ) memory_step.observations = updated_information - yield None + yield FinalOutput(output=None) def _substitute_state_variables(self, arguments: dict[str, str] | str) -> dict[str, Any] | str: """Replace string values in arguments with their corresponding state values if they exist.""" @@ -1303,10 +1417,11 @@ def initialize_system_prompt(self) -> str: ) return system_prompt - def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: + def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]: """ Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. - Yields either None if the step is not final, or the final answer. + Yields ChatMessageStreamDelta during the run if streaming is enabled. 
+        At the end, yields either None if the step is not final, or the final answer.
         """
         memory_messages = self.write_memory_to_messages()
@@ -1322,15 +1437,24 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]:
                     **additional_args,
                 )
                 output_text = ""
+                input_tokens, output_tokens = 0, 0
                 with Live("", console=self.logger.console, vertical_overflow="visible") as live:
                     for event in output_stream:
                         if event.content is not None:
                             output_text += event.content
                             live.update(Markdown(output_text))
+                            if event.token_usage:
+                                output_tokens += event.token_usage.output_tokens
+                                input_tokens = event.token_usage.input_tokens
+                        assert isinstance(event, ChatMessageStreamDelta)
                         yield event
                 model_output = output_text
-                chat_message = ChatMessage(role="assistant", content=model_output)
+                chat_message = ChatMessage(
+                    role="assistant",
+                    content=model_output,
+                    token_usage=TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens),
+                )
                 memory_step.model_output_message = chat_message
                 model_output = chat_message.content
             else:
@@ -1353,6 +1477,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]:
                 model_output += "<end_code>"
                 memory_step.model_output_message.content = model_output
 
+            memory_step.token_usage = chat_message.token_usage
             memory_step.model_output = model_output
         except Exception as e:
             raise AgentGenerationError(f"Error in generating model output:\n{e}", self.logger) from e
@@ -1414,7 +1539,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]:
             ]
         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
         memory_step.action_output = output
-        yield output if is_final_answer else None
+        yield FinalOutput(output=output if is_final_answer else None)
 
     def to_dict(self) -> dict[str, Any]:
         """Convert the agent to a dictionary representation.
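With these agents.py changes in place, a non-streaming run can return a RunResult instead of only the final answer. A minimal sketch of the new call pattern (the model choice and task below are illustrative, not part of the diff):

from smolagents import CodeAgent, InferenceClientModel

model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", provider="nebius")
agent = CodeAgent(tools=[], model=model, return_full_result=True)

run_result = agent.run("What is 7 * 6?")
print(run_result.output)           # the final answer only
print(run_result.state)            # "success" or "max_steps_error"
print(run_result.token_usage)      # TokenUsage totals, or None if the model reports no usage
print(run_result.timing.duration)  # wall-clock duration of the run in seconds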
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py index 384f50bc4..e5e6ede1c 100644 --- a/src/smolagents/gradio_ui.py +++ b/src/smolagents/gradio_ui.py @@ -21,20 +21,17 @@ from smolagents.agent_types import AgentAudio, AgentImage, AgentText from smolagents.agents import MultiStepAgent, PlanningStep -from smolagents.memory import ActionStep, FinalAnswerStep, MemoryStep +from smolagents.memory import ActionStep, FinalAnswerStep from smolagents.models import ChatMessageStreamDelta from smolagents.utils import _is_package_available -def get_step_footnote_content(step_log: MemoryStep, step_name: str) -> str: +def get_step_footnote_content(step_log: ActionStep | PlanningStep, step_name: str) -> str: """Get a footnote string for a step log with duration and token information""" step_footnote = f"**{step_name}**" - if hasattr(step_log, "input_token_count") and hasattr(step_log, "output_token_count"): - token_str = f" | Input tokens: {step_log.input_token_count:,} | Output tokens: {step_log.output_token_count:,}" - step_footnote += token_str - if hasattr(step_log, "duration"): - step_duration = f" | Duration: {round(float(step_log.duration), 2)}" if step_log.duration else None - step_footnote += step_duration + if step_log.token_usage is not None: + step_footnote += f" | Input tokens: {step_log.token_usage.input_tokens:,} | Output tokens: {step_log.token_usage.output_tokens:,}" + step_footnote += f" | Duration: {round(float(step_log.timing.duration), 2)}s" if step_log.timing.duration else "" step_footnote_content = f"""{step_footnote} """ return step_footnote_content @@ -94,7 +91,7 @@ def _process_action_step(step_log: ActionStep, skip_model_outputs: bool = False) import gradio as gr # Output the step number - step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else "Step" + step_number = f"Step {step_log.step_number}" if not skip_model_outputs: yield gr.ChatMessage(role="assistant", content=f"**{step_number}**", metadata={"status": "done"}) @@ -197,7 +194,7 @@ def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator: """ import gradio as gr - final_answer = step_log.final_answer + final_answer = step_log.output if isinstance(final_answer, AgentText): yield gr.ChatMessage( role="assistant", @@ -222,8 +219,8 @@ def _process_final_answer_step(step_log: FinalAnswerStep) -> Generator: ) -def pull_messages_from_step(step_log: MemoryStep, skip_model_outputs: bool = False): - """Extract ChatMessage objects from agent steps with proper nesting. +def pull_messages_from_step(step_log: ActionStep | PlanningStep | FinalAnswerStep, skip_model_outputs: bool = False): + """Extract Gradio ChatMessage objects from agent steps with proper nesting. Args: step_log: The step log to display as gr.ChatMessage objects. 
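For reference, the reworked footnote helper now reads token and timing data from the step itself rather than from attributes patched in by stream_to_gradio. A minimal sketch of how it can be exercised (the values are illustrative):

from smolagents.gradio_ui import get_step_footnote_content
from smolagents.memory import ActionStep
from smolagents.monitoring import Timing, TokenUsage

step = ActionStep(
    step_number=1,
    timing=Timing(start_time=0.0, end_time=2.5),
    token_usage=TokenUsage(input_tokens=100, output_tokens=50),
)
# The returned footnote text contains
# "**Step 1** | Input tokens: 100 | Output tokens: 50 | Duration: 2.5s"
print(get_step_footnote_content(step, "Step 1"))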
@@ -250,32 +247,27 @@ def stream_to_gradio( task_images: list | None = None, reset_agent_memory: bool = False, additional_args: dict | None = None, -): +) -> Generator: """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" if not _is_package_available("gradio"): raise ModuleNotFoundError( "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`" ) intermediate_text = "" - for step_log in agent.run( + + for event in agent.run( task, images=task_images, stream=True, reset=reset_agent_memory, additional_args=additional_args ): - # Track tokens if model provides them - if getattr(agent.model, "last_input_token_count", None) is not None: - if isinstance(step_log, (ActionStep, PlanningStep)): - step_log.input_token_count = agent.model.last_input_token_count - step_log.output_token_count = agent.model.last_output_token_count - - if isinstance(step_log, MemoryStep): + if isinstance(event, ActionStep | PlanningStep | FinalAnswerStep): intermediate_text = "" for message in pull_messages_from_step( - step_log, + event, # If we're streaming model outputs, no need to display them twice skip_model_outputs=getattr(agent, "stream_outputs", False), ): yield message - elif isinstance(step_log, ChatMessageStreamDelta): - intermediate_text += step_log.content or "" + elif isinstance(event, ChatMessageStreamDelta): + intermediate_text += event.content or "" yield intermediate_text @@ -308,24 +300,20 @@ def interact_with_agent(self, prompt, messages, session_state): for msg in stream_to_gradio(session_state["agent"], task=prompt, reset_agent_memory=False): if isinstance(msg, gr.ChatMessage): + messages[-1].metadata["status"] = "done" messages.append(msg) elif isinstance(msg, str): # Then it's only a completion delta - try: - if messages[-1].metadata["status"] == "pending": - messages[-1].content = msg - else: - messages.append( - gr.ChatMessage(role="assistant", content=msg, metadata={"status": "pending"}) - ) - except Exception as e: - raise e + msg = msg.replace("<", r"\<").replace(">", r"\>") # HTML tags seem to break Gradio Chatbot + if messages[-1].metadata["status"] == "pending": + messages[-1].content = msg + else: + messages.append(gr.ChatMessage(role="assistant", content=msg, metadata={"status": "pending"})) yield messages yield messages except Exception as e: - print(f"Error in interaction: {str(e)}") - messages.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}")) yield messages + raise gr.Error(f"Error in interaction: {str(e)}") def upload_file(self, file, file_uploads_log, allowed_file_types=None): """ diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 38fa9e1e9..912bad6e9 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, TypedDict from smolagents.models import ChatMessage, MessageRole -from smolagents.monitoring import AgentLogger, LogLevel +from smolagents.monitoring import AgentLogger, LogLevel, Timing, TokenUsage from smolagents.utils import AgentError, make_json_serializable @@ -50,30 +50,28 @@ def to_messages(self, summary_mode: bool = False) -> list[Message]: @dataclass class ActionStep(MemoryStep): + step_number: int + timing: Timing model_input_messages: list[Message] | None = None tool_calls: list[ToolCall] | None = None - start_time: float | None = None - end_time: float | None = None - step_number: int | None = None error: AgentError | None = None - duration: float | None = None model_output_message: 
ChatMessage | None = None model_output: str | None = None observations: str | None = None observations_images: list["PIL.Image.Image"] | None = None action_output: Any = None + token_usage: TokenUsage | None = None def dict(self): # We overwrite the method to parse the tool_calls and action_output manually return { "model_input_messages": self.model_input_messages, "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [], - "start_time": self.start_time, - "end_time": self.end_time, + "timing": self.timing.dict(), + "token_usage": asdict(self.token_usage) if self.token_usage else None, "step": self.step_number, "error": self.error.dict() if self.error else None, - "duration": self.duration, - "model_output_message": self.model_output_message, + "model_output_message": self.model_output_message.dict() if self.model_output_message else None, "model_output": self.model_output, "observations": self.observations, "action_output": make_json_serializable(self.action_output), @@ -145,6 +143,8 @@ class PlanningStep(MemoryStep): model_input_messages: list[Message] model_output_message: ChatMessage plan: str + timing: Timing + token_usage: TokenUsage | None = None def to_messages(self, summary_mode: bool = False) -> list[Message]: if summary_mode: @@ -182,7 +182,7 @@ def to_messages(self, summary_mode: bool = False) -> list[Message]: @dataclass class FinalAnswerStep(MemoryStep): - final_answer: Any + output: Any class AgentMemory: diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 664788164..48a702b65 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -24,6 +24,7 @@ from threading import Thread from typing import TYPE_CHECKING, Any +from .monitoring import TokenUsage from .tools import Tool from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob @@ -99,12 +100,13 @@ class ChatMessage: content: str | None = None tool_calls: list[ChatMessageToolCall] | None = None raw: Any | None = None # Stores the raw output from the API + token_usage: TokenUsage | None = None def model_dump_json(self): return json.dumps(get_dict_from_nested_dataclasses(self, ignore_key="raw")) @classmethod - def from_dict(cls, data: dict, raw: Any | None = None) -> "ChatMessage": + def from_dict(cls, data: dict, raw: Any | None = None, token_usage: TokenUsage | None = None) -> "ChatMessage": if data.get("tool_calls"): tool_calls = [ ChatMessageToolCall( @@ -113,10 +115,16 @@ def from_dict(cls, data: dict, raw: Any | None = None) -> "ChatMessage": for tc in data["tool_calls"] ] data["tool_calls"] = tool_calls - return cls(role=data["role"], content=data.get("content"), tool_calls=data.get("tool_calls"), raw=raw) + return cls( + role=data["role"], + content=data.get("content"), + tool_calls=data.get("tool_calls"), + raw=raw, + token_usage=token_usage, + ) def dict(self): - return json.dumps(get_dict_from_nested_dataclasses(self)) + return get_dict_from_nested_dataclasses(self) @classmethod def from_hf_api(cls, message: "ChatCompletionOutputMessage", raw) -> "ChatMessage": @@ -142,6 +150,7 @@ def parse_json_if_needed(arguments: str | dict) -> str | dict: class ChatMessageStreamDelta: content: str | None = None tool_calls: list[ChatMessageToolCall] | None = None + token_usage: TokenUsage | None = None class MessageRole(str, Enum): @@ -301,10 +310,28 @@ def __init__( self.tool_name_key = tool_name_key self.tool_arguments_key = tool_arguments_key self.kwargs = kwargs - self.last_input_token_count: int | None = None - 
self.last_output_token_count: int | None = None + self._last_input_token_count: int | None = None + self._last_output_token_count: int | None = None self.model_id: str | None = model_id + @property + def last_input_token_count(self) -> int | None: + warnings.warn( + "Attribute last_input_token_count is deprecated and will be removed in version 1.20. " + "Please use TokenUsage.input_tokens instead.", + FutureWarning, + ) + return self._last_input_token_count + + @property + def last_output_token_count(self) -> int | None: + warnings.warn( + "Attribute last_output_token_count is deprecated and will be removed in version 1.20. " + "Please use TokenUsage.output_tokens instead.", + FutureWarning, + ) + return self._last_output_token_count + def _prepare_completion_kwargs( self, messages: list[dict[str, str | list[dict]]], @@ -359,17 +386,9 @@ def _prepare_completion_kwargs( return completion_kwargs - def get_token_counts(self) -> dict[str, int]: - if self.last_input_token_count is None or self.last_output_token_count is None: - raise ValueError("Token counts are not available") - return { - "input_token_count": self.last_input_token_count, - "output_token_count": self.last_output_token_count, - } - def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -378,7 +397,7 @@ def generate( """Process the input messages and return the model's response. Parameters: - messages (`list[dict[str, str]]`): + messages (`list[dict[str, str | list[dict]]] | list[ChatMessage]`): A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`. stop_sequences (`List[str]`, *optional*): A list of strings that will stop the generation if encountered in the model's output. 
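Because ChatMessage now carries its own TokenUsage, callers no longer need the deprecated last_input_token_count / last_output_token_count attributes and can read usage straight off the returned message. A minimal sketch (the model configuration is illustrative):

from smolagents import InferenceClientModel

model = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", provider="nebius")
message = model.generate([{"role": "user", "content": "Say hello."}])

if message.token_usage is not None:
    # total_tokens is derived in TokenUsage.__post_init__
    print(message.token_usage.input_tokens, message.token_usage.output_tokens, message.token_usage.total_tokens)
# Accessing model.last_input_token_count still works for now but emits a FutureWarning.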
@@ -416,8 +435,6 @@ def to_dict(self) -> dict: """ model_dictionary = { **self.kwargs, - "last_input_token_count": self.last_input_token_count, - "last_output_token_count": self.last_output_token_count, "model_id": self.model_id, } for attribute in [ @@ -446,16 +463,7 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, model_dictionary: dict[str, Any]) -> "Model": - model_instance = cls( - **{ - k: v - for k, v in model_dictionary.items() - if k not in ["last_input_token_count", "last_output_token_count"] - } - ) - model_instance.last_input_token_count = model_dictionary.pop("last_input_token_count", None) - model_instance.last_output_token_count = model_dictionary.pop("last_output_token_count", None) - return model_instance + return cls(**{k: v for k, v in model_dictionary.items()}) class VLLMModel(Model): @@ -508,7 +516,7 @@ def cleanup(self): def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -554,12 +562,16 @@ def generate( sampling_params=sampling_params, ) output_text = out[0].outputs[0].text - self.last_input_token_count = len(out[0].prompt_token_ids) - self.last_output_token_count = len(out[0].outputs[0].token_ids) + self._last_input_token_count = len(out[0].prompt_token_ids) + self._last_output_token_count = len(out[0].outputs[0].token_ids) return ChatMessage( role=MessageRole.ASSISTANT, content=output_text, raw={"out": output_text, "completion_kwargs": completion_kwargs}, + token_usage=TokenUsage( + input_tokens=len(out[0].prompt_token_ids), + output_tokens=len(out[0].outputs[0].token_ids), + ), ) @@ -627,7 +639,7 @@ def __init__( def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -651,18 +663,25 @@ def generate( add_generation_prompt=True, ) - self.last_input_token_count = len(prompt_ids) - self.last_output_token_count = 0 + output_tokens = 0 text = "" for response in self.stream_generate(self.model, self.tokenizer, prompt=prompt_ids, **completion_kwargs): - self.last_output_token_count += 1 + output_tokens += 1 text += response.text if any((stop_index := text.rfind(stop)) != -1 for stop in stops): text = text[:stop_index] break + self._last_input_token_count = len(prompt_ids) + self._last_output_token_count = output_tokens return ChatMessage( - role=MessageRole.ASSISTANT, content=text, raw={"out": text, "completion_kwargs": completion_kwargs} + role=MessageRole.ASSISTANT, + content=text, + raw={"out": text, "completion_kwargs": completion_kwargs}, + token_usage=TokenUsage( + input_tokens=len(prompt_ids), + output_tokens=output_tokens, + ), ) @@ -848,7 +867,7 @@ def _prepare_completion_args( def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -870,12 +889,12 @@ def generate( output_text = self.processor.decode(generated_tokens, skip_special_tokens=True) else: output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) - self.last_input_token_count = count_prompt_tokens - self.last_output_token_count = len(generated_tokens) if stop_sequences is not None: output_text = 
remove_stop_sequences(output_text, stop_sequences) + self._last_input_token_count = count_prompt_tokens + self._last_output_token_count = len(generated_tokens) return ChatMessage( role=MessageRole.ASSISTANT, content=output_text, @@ -883,6 +902,10 @@ def generate( "out": output_text, "completion_kwargs": {key: value for key, value in generation_kwargs.items() if key != "inputs"}, }, + token_usage=TokenUsage( + input_tokens=count_prompt_tokens, + output_tokens=len(generated_tokens), + ), ) def generate_stream( @@ -905,14 +928,15 @@ def generate_stream( thread = Thread(target=self.model.generate, kwargs={"streamer": self.streamer, **generation_kwargs}) thread.start() - self.last_output_token_count = 0 - # Generate with streaming for new_text in self.streamer: - yield ChatMessageStreamDelta(content=new_text, tool_calls=None) - self.last_output_token_count += 1 - - self.last_input_token_count = count_prompt_tokens + self._last_input_token_count = count_prompt_tokens + self._last_output_token_count = 1 + yield ChatMessageStreamDelta( + content=new_text, + tool_calls=None, + token_usage=TokenUsage(input_tokens=count_prompt_tokens, output_tokens=1), + ) thread.join() @@ -1009,7 +1033,7 @@ def create_client(self): def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -1030,11 +1054,15 @@ def generate( response = self.client.completion(**completion_kwargs) - self.last_input_token_count = response.usage.prompt_tokens - self.last_output_token_count = response.usage.completion_tokens + self._last_input_token_count = response.usage.prompt_tokens + self._last_output_token_count = response.usage.completion_tokens return ChatMessage.from_dict( response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}), raw=response, + token_usage=TokenUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ), ) def generate_stream( @@ -1069,8 +1097,15 @@ def generate_stream( content=event.choices[0].delta.content, ) if getattr(event, "usage", None): - self.last_input_token_count = event.usage.prompt_tokens - self.last_output_token_count = event.usage.completion_tokens + self._last_input_token_count = event.usage.prompt_tokens + self._last_output_token_count = event.usage.completion_tokens + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage.prompt_tokens, + output_tokens=event.usage.completion_tokens, + ), + ) class LiteLLMRouterModel(LiteLLMModel): @@ -1213,7 +1248,7 @@ class InferenceClientModel(ApiModel): ```python >>> engine = InferenceClientModel( ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct", - ... provider="together", + ... provider="nebius", ... token="your_hf_token_here", ... max_tokens=5000, ... 
) @@ -1265,7 +1300,7 @@ def create_client(self): def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -1282,9 +1317,16 @@ def generate( ) response = self.client.chat_completion(**completion_kwargs) - self.last_input_token_count = response.usage.prompt_tokens - self.last_output_token_count = response.usage.completion_tokens - return ChatMessage.from_dict(asdict(response.choices[0].message), raw=response) + self._last_input_token_count = response.usage.prompt_tokens + self._last_output_token_count = response.usage.completion_tokens + return ChatMessage.from_dict( + asdict(response.choices[0].message), + raw=response, + token_usage=TokenUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ), + ) def generate_stream( self, @@ -1318,8 +1360,15 @@ def generate_stream( content=event.choices[0].delta.content, ) if getattr(event, "usage", None): - self.last_input_token_count = event.usage.prompt_tokens - self.last_output_token_count = event.usage.completion_tokens + self._last_input_token_count = event.usage.prompt_tokens + self._last_output_token_count = event.usage.completion_tokens + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage.prompt_tokens, + output_tokens=event.usage.completion_tokens, + ), + ) class HfApiModel(InferenceClientModel): @@ -1420,16 +1469,21 @@ def generate_stream( if not getattr(event.choices[0], "finish_reason", None): raise ValueError(f"No content or tool calls in event: {event}") else: - yield ChatMessageStreamDelta( - content=event.choices[0].delta.content, - ) - if getattr(event, "usage", None): - self.last_input_token_count = event.usage.prompt_tokens - self.last_output_token_count = event.usage.completion_tokens + yield ChatMessageStreamDelta(content=event.choices[0].delta.content) + if event.usage: + self._last_input_token_count = event.usage.prompt_tokens + self._last_output_token_count = event.usage.completion_tokens + yield ChatMessageStreamDelta( + content="", + token_usage=TokenUsage( + input_tokens=event.usage.prompt_tokens, + output_tokens=event.usage.completion_tokens, + ), + ) def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -1446,12 +1500,16 @@ def generate( **kwargs, ) response = self.client.chat.completions.create(**completion_kwargs) - self.last_input_token_count = response.usage.prompt_tokens - self.last_output_token_count = response.usage.completion_tokens + self._last_input_token_count = response.usage.prompt_tokens + self._last_output_token_count = response.usage.completion_tokens return ChatMessage.from_dict( response.choices[0].message.model_dump(include={"role", "content", "tool_calls"}), raw=response, + token_usage=TokenUsage( + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ), ) @@ -1656,7 +1714,7 @@ def create_client(self): def generate( self, - messages: list[dict[str, str | list[dict]]], + messages: list[dict[str, str | list[dict]] | ChatMessage], stop_sequences: list[str] | None = None, grammar: str | None = None, tools_to_call_from: list[Tool] | None = None, @@ -1673,13 +1731,19 @@ def generate( # self.client is 
created in ApiModel class response = self.client.converse(**completion_kwargs) - # Get usage - self.last_input_token_count = response["usage"]["inputTokens"] - self.last_output_token_count = response["usage"]["outputTokens"] - # Get first message response["output"]["message"]["content"] = response["output"]["message"]["content"][0]["text"] - return ChatMessage.from_dict(response["output"]["message"], raw=response) + + self._last_input_token_count = response["usage"]["inputTokens"] + self._last_output_token_count = response["usage"]["outputTokens"] + return ChatMessage.from_dict( + response["output"]["message"], + raw=response, + token_usage=TokenUsage( + input_tokens=response["usage"]["inputTokens"], + output_tokens=response["usage"]["outputTokens"], + ), + ) __all__ = [ diff --git a/src/smolagents/monitoring.py b/src/smolagents/monitoring.py index 0d827a95e..5f5c174da 100644 --- a/src/smolagents/monitoring.py +++ b/src/smolagents/monitoring.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from dataclasses import dataclass, field from enum import IntEnum from rich import box @@ -29,7 +30,45 @@ from smolagents.utils import escape_code_brackets -__all__ = ["AgentLogger", "LogLevel", "Monitor"] +__all__ = ["AgentLogger", "LogLevel", "Monitor", "TokenUsage", "Timing"] + + +@dataclass +class TokenUsage: + """ + Contains the token usage information for a given step or run. + """ + + input_tokens: int + output_tokens: int + total_tokens: int = field(init=False) + + def __post_init__(self): + self.total_tokens = self.input_tokens + self.output_tokens + + +@dataclass +class Timing: + """ + Contains the timing information for a given step or run. + """ + + start_time: float + end_time: float | None = None + + @property + def duration(self): + return None if self.end_time is None else self.end_time - self.start_time + + def dict(self): + return { + "start_time": self.start_time, + "end_time": self.end_time, + "duration": self.duration, + } + + def __repr__(self) -> str: + return f"Timing(start_time={self.start_time}, end_time={self.end_time}, duration={self.duration})" class Monitor: @@ -37,15 +76,14 @@ def __init__(self, tracked_model, logger): self.step_durations = [] self.tracked_model = tracked_model self.logger = logger - if getattr(self.tracked_model, "last_input_token_count", "Not found") != "Not found": - self.total_input_token_count = 0 - self.total_output_token_count = 0 + self.total_input_token_count = 0 + self.total_output_token_count = 0 - def get_total_token_counts(self): - return { - "input": self.total_input_token_count, - "output": self.total_output_token_count, - } + def get_total_token_counts(self) -> TokenUsage: + return TokenUsage( + input_tokens=self.total_input_token_count, + output_tokens=self.total_output_token_count, + ) def reset(self): self.step_durations = [] @@ -58,13 +96,13 @@ def update_metrics(self, step_log): Args: step_log ([`MemoryStep`]): Step log to update the monitor with. 
""" - step_duration = step_log.duration + step_duration = step_log.timing.duration self.step_durations.append(step_duration) console_outputs = f"[Step {len(self.step_durations)}: Duration {step_duration:.2f} seconds" - if getattr(self.tracked_model, "last_input_token_count", None) is not None: - self.total_input_token_count += self.tracked_model.last_input_token_count - self.total_output_token_count += self.tracked_model.last_output_token_count + if step_log.token_usage is not None: + self.total_input_token_count += step_log.token_usage.input_tokens + self.total_output_token_count += step_log.token_usage.output_tokens console_outputs += ( f"| Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}" ) diff --git a/tests/test_agents.py b/tests/test_agents.py index a350dbff4..d6ace4471 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -571,9 +571,11 @@ def weather_api(location: str, celsius: bool = False) -> str: assert agent.memory.steps[0].task == task assert agent.memory.steps[1].tool_calls[0].name == "weather_api" step_memory_dict = agent.memory.get_succinct_steps()[1] - assert step_memory_dict["model_output_message"].tool_calls[0].function.name == "weather_api" - assert step_memory_dict["model_output_message"].raw["completion_kwargs"]["max_new_tokens"] == 100 + assert step_memory_dict["model_output_message"]["tool_calls"][0]["function"]["name"] == "weather_api" + assert step_memory_dict["model_output_message"]["raw"]["completion_kwargs"]["max_new_tokens"] == 100 assert "model_input_messages" in agent.memory.get_full_steps()[1] + assert step_memory_dict["token_usage"]["total_tokens"] > 100 + assert step_memory_dict["timing"]["duration"] > 0.1 def test_final_answer_checks(self): def check_always_fails(final_answer, agent_memory): @@ -678,8 +680,13 @@ def generate(self, messages, stop_sequences=None): def test_step_number(self): fake_model = MagicMock() - fake_model.last_input_token_count = 10 - fake_model.last_output_token_count = 20 + fake_model.generate.return_value = ChatMessage( + role="assistant", + content="Model output.", + tool_calls=None, + raw="Model output.", + token_usage=None, + ) max_steps = 2 agent = CodeAgent(tools=[], model=fake_model, max_steps=max_steps) assert hasattr(agent, "step_number"), "step_number attribute should be defined" @@ -812,13 +819,19 @@ def test_planning_step(self, step, expected_messages_list): ) def test_provide_final_answer(self, images, expected_messages_list): fake_model = MagicMock() - fake_model.return_value.content = "Final answer." + fake_model.generate.return_value = ChatMessage( + role="assistant", + content="Final answer.", + tool_calls=None, + raw="Final answer.", + token_usage=None, + ) agent = CodeAgent( tools=[], model=fake_model, ) task = "Test task" - final_answer = agent.provide_final_answer(task, images=images) + final_answer = agent.provide_final_answer(task, images=images).content expected_message_texts = { "FINAL_ANSWER_SYSTEM_PROMPT": agent.prompt_templates["final_answer"]["pre_messages"], "FINAL_ANSWER_USER_PROMPT": populate_template( @@ -832,8 +845,8 @@ def test_provide_final_answer(self, images, expected_messages_list): expected_content["text"] = expected_message_texts[expected_content["text"]] assert final_answer == "Final answer." 
# Test calls to model - assert len(fake_model.call_args_list) == 1 - for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list): + assert len(fake_model.generate.call_args_list) == 1 + for call_args, expected_messages in zip(fake_model.generate.call_args_list, expected_messages_list): assert len(call_args.args) == 1 messages = call_args.args[0] assert isinstance(messages, list) @@ -851,8 +864,13 @@ def test_provide_final_answer(self, images, expected_messages_list): def test_interrupt(self): fake_model = MagicMock() - fake_model.return_value.content = "Model output." - fake_model.last_input_token_count = None + fake_model.generate.return_value = ChatMessage( + role="assistant", + content="Model output.", + tool_calls=None, + raw="Model output.", + token_usage=None, + ) def interrupt_callback(memory_step, agent): agent.interrupt() @@ -1202,8 +1220,13 @@ def generate(self, messages, stop_sequences=None, grammar=None): def test_local_python_executor_with_custom_functions(self): model = MagicMock() - model.last_input_token_count = 10 - model.last_output_token_count = 5 + model.generate.return_value = ChatMessage( + role="assistant", + content="", + tool_calls=None, + raw="", + token_usage=None, + ) agent = CodeAgent(tools=[], model=model, executor_kwargs={"additional_functions": {"open": open}}) agent.run("Test run") assert "open" in agent.python_executor.static_tools diff --git a/tests/test_gradio_ui.py b/tests/test_gradio_ui.py index f748bf1e0..1e39a7343 100644 --- a/tests/test_gradio_ui.py +++ b/tests/test_gradio_ui.py @@ -25,6 +25,7 @@ from smolagents.gradio_ui import GradioUI, pull_messages_from_step, stream_to_gradio from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, ToolCall from smolagents.models import ChatMessageStreamDelta +from smolagents.monitoring import Timing, TokenUsage class GradioUITester(unittest.TestCase): @@ -221,11 +222,9 @@ def test_action_step_basic( model_output="This is the model output", observations="Some execution logs", error=None, - duration=2.5, + timing=Timing(start_time=1.0, end_time=3.5), + token_usage=TokenUsage(input_tokens=100, output_tokens=50), ) - # Set in stream_to_gradio: - step.input_token_count = 100 - step.output_token_count = 50 messages = list(pull_messages_from_step(step)) assert len(messages) == 5 # step number, model_output, logs, footnote, divider for message, expected_content in zip( @@ -246,7 +245,8 @@ def test_action_step_with_tool_calls(self): step_number=2, tool_calls=[ToolCall(name="test_tool", arguments={"answer": "Test answer"}, id="tool_call_1")], observations="Tool execution logs", - duration=1.5, + timing=Timing(start_time=1.0, end_time=2.5), + token_usage=TokenUsage(input_tokens=100, output_tokens=50), ) messages = list(pull_messages_from_step(step)) assert len(messages) == 5 # step, tool call, logs, footnote, divider @@ -266,7 +266,12 @@ def test_action_step_tool_call_formats(self, tool_name, args, expected): tool_call = Mock() tool_call.name = tool_name tool_call.arguments = args - step = ActionStep(step_number=1, tool_calls=[tool_call], duration=1.5) + step = ActionStep( + step_number=1, + tool_calls=[tool_call], + timing=Timing(start_time=1.0, end_time=2.5), + token_usage=TokenUsage(input_tokens=100, output_tokens=50), + ) messages = list(pull_messages_from_step(step)) tool_message = next( msg @@ -281,7 +286,12 @@ def test_action_step_tool_call_formats(self, tool_name, args, expected): def test_action_step_with_error(self): """Test ActionStep with error.""" - step = 
ActionStep(step_number=3, error="This is an error message", duration=1.0)
+        step = ActionStep(
+            step_number=3,
+            error="This is an error message",
+            timing=Timing(start_time=1.0, end_time=2.0),
+            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
+        )
         messages = list(pull_messages_from_step(step))
         error_message = next((m for m in messages if "error" in str(m.content).lower()), None)
         assert error_message is not None
@@ -289,7 +299,12 @@ def test_action_step_with_error(self):
 
     def test_action_step_with_images(self):
         """Test ActionStep with observation images."""
-        step = ActionStep(step_number=4, observations_images=["image1.png", "image2.jpg"], duration=1.0)
+        step = ActionStep(
+            step_number=4,
+            observations_images=["image1.png", "image2.jpg"],
+            token_usage=TokenUsage(input_tokens=100, output_tokens=200),
+            timing=Timing(start_time=1.0, end_time=2.0),
+        )
         with patch("smolagents.gradio_ui.AgentImage") as mock_agent_image:
             mock_agent_image.return_value.to_string.side_effect = lambda: "path/to/image.png"
             messages = list(pull_messages_from_step(step))
@@ -297,26 +312,33 @@ def test_action_step_with_images(self):
         assert len(image_messages) == 2
         assert "path/to/image.png" in str(image_messages[0])
 
-    @pytest.mark.parametrize("skip_model_outputs, expected_messages_length", [(False, 4), (True, 2)])
-    def test_planning_step(self, skip_model_outputs, expected_messages_length):
+    @pytest.mark.parametrize(
+        "skip_model_outputs, expected_messages_length, token_usage",
+        [(False, 4, TokenUsage(input_tokens=80, output_tokens=30)), (True, 2, None)],
+    )
+    def test_planning_step(self, skip_model_outputs, expected_messages_length, token_usage):
         """Test PlanningStep processing."""
         step = PlanningStep(
-            plan="1. First step\n2. Second step", model_input_messages=Mock(), model_output_message=Mock()
+            plan="1. First step\n2. Second step",
+            model_input_messages=Mock(),
+            model_output_message=Mock(),
+            token_usage=token_usage,
+            timing=Timing(start_time=1.0, end_time=2.0),
         )
-        # Set in stream_to_gradio:
-        step.input_token_count = 80
-        step.output_token_count = 30
         messages = list(pull_messages_from_step(step, skip_model_outputs=skip_model_outputs))
         assert len(messages) == expected_messages_length  # [header, plan,] footnote, divider
         expected_contents = [
             "**Planning step**",
             "1. First step\n2. Second step",
-            "Input tokens: 80 | Output tokens: 30",
+            "Input tokens: 80 | Output tokens: 30" if token_usage else "",
             "-----",
         ]
         for message, expected_content in zip(messages, expected_contents[-expected_messages_length:]):
             assert expected_content in message.content
+        if not token_usage:
+            assert "Input tokens: 80 | Output tokens: 30" not in message.content
+
     @pytest.mark.parametrize(
         "answer_type, answer_value, expected_content",
         [
@@ -331,7 +353,9 @@ def test_final_answer_step(self, answer_type, answer_value, expected_content):
         except TypeError:
             with patch.object(answer_type, "to_string", return_value=answer_value):
                 final_answer = answer_type(answer_value)
-        step = FinalAnswerStep(final_answer=final_answer)
+        step = FinalAnswerStep(
+            output=final_answer,
+        )
         messages = list(pull_messages_from_step(step))
         assert len(messages) == 1
         assert messages[0].content == expected_content
@@ -339,7 +363,7 @@ def test_final_answer_step(self, answer_type, answer_value, expected_content):
     def test_final_answer_step_image(self):
         """Test FinalAnswerStep with image answer."""
         with patch.object(AgentImage, "to_string", return_value="path/to/image.png"):
-            step = FinalAnswerStep(final_answer=AgentImage("path/to/image.png"))
+            step = FinalAnswerStep(output=AgentImage("path/to/image.png"))
             messages = list(pull_messages_from_step(step))
             assert len(messages) == 1
             assert messages[0].content["path"] == "path/to/image.png"
@@ -348,7 +372,7 @@ def test_final_answer_step_image(self):
     def test_final_answer_step_audio(self):
         """Test FinalAnswerStep with audio answer."""
         with patch.object(AgentAudio, "to_string", return_value="path/to/audio.wav"):
-            step = FinalAnswerStep(final_answer=AgentAudio("path/to/audio.wav"))
+            step = FinalAnswerStep(output=AgentAudio("path/to/audio.wav"))
             messages = list(pull_messages_from_step(step))
             assert len(messages) == 1
             assert messages[0].content["path"] == "path/to/audio.wav"
diff --git a/tests/test_memory.py b/tests/test_memory.py
index 04c6b7f47..4bf4fbab7 100644
--- a/tests/test_memory.py
+++ b/tests/test_memory.py
@@ -12,6 +12,7 @@
     SystemPromptStep,
     TaskStep,
 )
+from smolagents.monitoring import Timing, TokenUsage
 
 
 class TestAgentMemory:
@@ -37,22 +38,84 @@ def test_to_messages(self):
             step.to_messages()
 
 
+def test_action_step_dict():
+    action_step = ActionStep(
+        model_input_messages=[Message(role=MessageRole.USER, content="Hello")],
+        tool_calls=[
+            ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
+        ],
+        timing=Timing(start_time=0.0, end_time=1.0),
+        step_number=1,
+        error=None,
+        model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
+        model_output="Hi",
+        observations="This is a nice observation",
+        observations_images=["image1.png"],
+        action_output="Output",
+        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
+    )
+    action_step_dict = action_step.dict()
+    # Check each key individually for better test failure messages
+    assert "model_input_messages" in action_step_dict
+    assert action_step_dict["model_input_messages"] == [Message(role=MessageRole.USER, content="Hello")]
+
+    assert "tool_calls" in action_step_dict
+    assert len(action_step_dict["tool_calls"]) == 1
+    assert action_step_dict["tool_calls"][0] == {
+        "id": "id",
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "arguments": {"location": "Paris"},
+        },
+    }
+
+    assert "timing" in action_step_dict
+    assert action_step_dict["timing"] == {"start_time": 0.0, "end_time": 1.0, "duration": 1.0}
+
+    assert "token_usage" in action_step_dict
+    assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
+
+    assert "step" in action_step_dict
+    assert action_step_dict["step"] == 1
+
+    assert "error" in action_step_dict
+    assert action_step_dict["error"] is None
+
+    assert "model_output_message" in action_step_dict
+    assert action_step_dict["model_output_message"] == {
+        "role": "assistant",
+        "content": "Hi",
+        "tool_calls": None,
+        "raw": None,
+        "token_usage": None,
+    }
+
+    assert "model_output" in action_step_dict
+    assert action_step_dict["model_output"] == "Hi"
+
+    assert "observations" in action_step_dict
+    assert action_step_dict["observations"] == "This is a nice observation"
+
+    assert "action_output" in action_step_dict
+    assert action_step_dict["action_output"] == "Output"
+
+
 def test_action_step_to_messages():
     action_step = ActionStep(
         model_input_messages=[Message(role=MessageRole.USER, content="Hello")],
         tool_calls=[
             ToolCall(id="id", name="get_weather", arguments={"location": "Paris"}),
         ],
-        start_time=0.0,
-        end_time=1.0,
+        timing=Timing(start_time=0.0, end_time=1.0),
         step_number=1,
         error=None,
-        duration=1.0,
         model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
         model_output="Hi",
         observations="This is a nice observation",
         observations_images=["image1.png"],
         action_output="Output",
+        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
     )
     messages = action_step.to_messages()
     assert len(messages) == 4
@@ -93,16 +156,15 @@ def test_action_step_to_messages_no_tool_calls_with_observations():
     action_step = ActionStep(
         model_input_messages=None,
         tool_calls=None,
-        start_time=None,
-        end_time=None,
-        step_number=None,
+        timing=Timing(start_time=0.0, end_time=1.0),
+        step_number=1,
         error=None,
-        duration=None,
         model_output_message=None,
         model_output=None,
         observations="This is an observation.",
         observations_images=None,
         action_output=None,
+        token_usage=TokenUsage(input_tokens=10, output_tokens=20),
     )
     messages = action_step.to_messages()
     assert len(messages) == 1
@@ -116,6 +178,7 @@ def test_planning_step_to_messages():
         model_input_messages=[Message(role=MessageRole.USER, content="Hello")],
         model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Plan"),
         plan="This is a plan.",
+        timing=Timing(start_time=0.0, end_time=1.0),
     )
     messages = planning_step.to_messages(summary_mode=False)
     assert len(messages) == 2
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index c7f6b9a64..6ebfa2982 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -20,6 +20,7 @@
 from smolagents import (
     AgentImage,
     CodeAgent,
+    RunResult,
     ToolCallingAgent,
     stream_to_gradio,
 )
@@ -28,13 +29,13 @@
     ChatMessageToolCall,
     ChatMessageToolCallDefinition,
     Model,
+    TokenUsage,
 )
 
 
 class FakeLLMModel(Model):
-    def __init__(self):
-        self.last_input_token_count = 10
-        self.last_output_token_count = 20
+    def __init__(self, give_token_usage: bool = True):
+        self.give_token_usage = give_token_usage
 
     def generate(self, prompt, tools_to_call_from=None, **kwargs):
         if tools_to_call_from is not None:
@@ -48,6 +49,7 @@ def generate(self, prompt, tools_to_call_from=None, **kwargs):
                         function=ChatMessageToolCallDefinition(name="final_answer", arguments={"answer": "image"}),
                     )
                 ],
+                token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
             )
         else:
             return ChatMessage(
@@ -57,6 +59,7 @@ def generate(self, prompt, tools_to_call_from=None, **kwargs):
                 role="assistant",
                 content="""
 ```py
 final_answer('This is the final answer.')
 ```""",
+                token_usage=TokenUsage(input_tokens=10, output_tokens=20) if self.give_token_usage else None,
             )
@@ -86,12 +89,12 @@ def test_toolcalling_agent_metrics(self):
 
     def test_code_agent_metrics_max_steps(self):
         class FakeLLMModelMalformedAnswer(Model):
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
             def generate(self, prompt, **kwargs):
-                return ChatMessage(role="assistant", content="Malformed answer")
+                return ChatMessage(
+                    role="assistant",
+                    content="Malformed answer",
+                    token_usage=TokenUsage(input_tokens=10, output_tokens=20),
+                )
 
         agent = CodeAgent(
             tools=[],
@@ -106,13 +109,7 @@ def generate(self, prompt, **kwargs):
 
     def test_code_agent_metrics_generation_error(self):
         class FakeLLMModelGenerationException(Model):
-            def __init__(self):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 20
-
             def generate(self, prompt, **kwargs):
-                self.last_input_token_count = 10
-                self.last_output_token_count = 0
                 raise Exception("Cannot generate")
 
         agent = CodeAgent(
@@ -120,11 +117,9 @@ def generate(self, prompt, **kwargs):
             model=FakeLLMModelGenerationException(),
             max_steps=1,
         )
-        with pytest.raises(Exception):
+        with pytest.raises(Exception) as e:
             agent.run("Fake task")
-
-        self.assertEqual(agent.monitor.total_input_token_count, 10)  # Should have done one monitoring callbacks
-        self.assertEqual(agent.monitor.total_output_token_count, 0)
+        assert "Cannot generate" in str(e.value)
 
     def test_streaming_agent_text_output(self):
         agent = CodeAgent(
@@ -186,3 +181,73 @@ def generate(self, prompt, **kwargs):
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("Malformed call", final_message.content)
+
+    def test_run_return_full_result(self):
+        agent = CodeAgent(
+            tools=[],
+            model=FakeLLMModel(),
+            max_steps=1,
+            return_full_result=True,
+        )
+
+        result = agent.run("Fake task")
+
+        self.assertIsInstance(result, RunResult)
+        self.assertEqual(result.output, "This is the final answer.")
+        self.assertEqual(result.state, "success")
+        self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
+        self.assertIsInstance(result.messages, list)
+        self.assertGreater(result.timing.duration, 0)
+
+        agent = ToolCallingAgent(
+            tools=[],
+            model=FakeLLMModel(),
+            max_steps=1,
+            return_full_result=True,
+        )
+
+        result = agent.run("Fake task")
+
+        self.assertIsInstance(result, RunResult)
+        self.assertEqual(result.output, "image")
+        self.assertEqual(result.state, "success")
+        self.assertEqual(result.token_usage, TokenUsage(input_tokens=10, output_tokens=20))
+        self.assertIsInstance(result.messages, list)
+        self.assertGreater(result.timing.duration, 0)
+
+        # Below 2 lines should be removed when the attributes are removed
+        assert agent.monitor.total_input_token_count == 10
+        assert agent.monitor.total_output_token_count == 20
+
+    def test_run_result_no_token_usage(self):
+        agent = CodeAgent(
+            tools=[],
+            model=FakeLLMModel(give_token_usage=False),
+            max_steps=1,
+            return_full_result=True,
+        )
+
+        result = agent.run("Fake task")
+
+        self.assertIsInstance(result, RunResult)
+        self.assertEqual(result.output, "This is the final answer.")
+        self.assertEqual(result.state, "success")
+        self.assertIsNone(result.token_usage)
+        self.assertIsInstance(result.messages, list)
+        self.assertGreater(result.timing.duration, 0)
+
+        agent = ToolCallingAgent(
+            tools=[],
+            model=FakeLLMModel(give_token_usage=False),
+            max_steps=1,
+            return_full_result=True,
+        )
+
+        result = agent.run("Fake task")
+
+        self.assertIsInstance(result, RunResult)
+        self.assertEqual(result.output, "image")
+        self.assertEqual(result.state, "success")
+        self.assertIsNone(result.token_usage)
+        self.assertIsInstance(result.messages, list)
+        self.assertGreater(result.timing.duration, 0)
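For reference, a minimal sketch of consuming the `RunResult` fields exercised by `test_run_return_full_result` and `test_run_result_no_token_usage` above; it is not part of the patch, and the model choice and task string are illustrative placeholders only.

```py
# Hypothetical usage sketch: any configured smolagents Model and task would do.
from smolagents import CodeAgent, InferenceClientModel

agent = CodeAgent(
    tools=[],
    model=InferenceClientModel(),  # illustrative; swap in your own model
    return_full_result=True,  # run() then returns a RunResult instead of the bare output
)

result = agent.run("What is 2 + 2?")

print(result.output)  # final answer produced by the agent
print(result.state)  # "success" on a completed run
print(result.timing.duration)  # wall-clock duration of the run, in seconds
if result.token_usage is not None:  # None when the model reports no usage
    print(result.token_usage.total_tokens)
print(len(result.messages))  # the agent's memory as a list of messages
```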