Fix and refactor final answer checks #1448
Changes from 6 commits
@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:


 @dataclass
-class FinalOutput:
-    output: Any | None
+class ActionOutput:
+    output: Any
+    is_final_answer: bool


 class PlanningPromptTemplate(TypedDict):
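For orientation, here is a minimal sketch of the renamed dataclass as it stands after this hunk, plus how a caller can branch on it. The imports and the toy values are illustrative, not part of the diff.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


# Every step now yields the same type; only the flag distinguishes
# an intermediate result from the final answer.
intermediate = ActionOutput(output=None, is_final_answer=False)
final = ActionOutput(output="Paris", is_final_answer=True)
assert not intermediate.is_final_answer
assert final.is_final_answer and final.output == "Paris"
```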
@@ -212,6 +213,7 @@ class MultiStepAgent(ABC):
         tools (`list[Tool]`): [`Tool`]s that the agent can use.
         model (`Callable[[list[dict[str, str]]], ChatMessage]`): Model that will generate the agent's actions.
         prompt_templates ([`~agents.PromptTemplates`], *optional*): Prompt templates.
+        instructions (`str`, *optional*): Custom instructions for the agent, will be inserted in the system prompt.
         max_steps (`int`, default `20`): Maximum number of steps the agent can take to solve the task.
         add_base_tools (`bool`, default `False`): Whether to add the base tools to the agent's tools.
         verbosity_level (`LogLevel`, default `LogLevel.INFO`): Level of verbosity of the agent's logs.
@@ -236,6 +238,7 @@ def __init__(
         tools: list[Tool],
         model: Model,
         prompt_templates: PromptTemplates | None = None,
+        instructions: str | None = None,
         max_steps: int = 20,
         add_base_tools: bool = False,
         verbosity_level: LogLevel = LogLevel.INFO,

Collaborator (author), on the added `instructions` parameter: This is from #1442; it will not appear after PR 1442 is merged.
@@ -278,9 +281,9 @@ def __init__(
         self.name = self._validate_name(name)
         self.description = description
         self.provide_run_summary = provide_run_summary
-        self.final_answer_checks = final_answer_checks
+        self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
         self.return_full_result = return_full_result
+        self.instructions = instructions
         self._setup_managed_agents(managed_agents)
         self._setup_tools(tools, add_base_tools)
         self._validate_tools_and_managed_agents(tools, managed_agents)
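A small illustration of why the constructor now normalizes `final_answer_checks` to an empty list: downstream code truthiness-tests the attribute (`if self.final_answer_checks:`) and iterates over it, and a plain list keeps both simple. The check signature `(final_answer, agent_memory)` matches the tests further down; `looks_numeric` and `normalize` are hypothetical helpers.

```python
from typing import Any, Callable

# Hypothetical stand-in: a check receives the candidate final answer and the
# agent's memory, and returns a truthy value on success.
def looks_numeric(final_answer: Any, agent_memory: Any) -> bool:
    return isinstance(final_answer, (int, float))

def normalize(checks: list[Callable] | None) -> list[Callable]:
    # Same normalization as the constructor change above: never store None.
    return checks if checks is not None else []

assert normalize(None) == []                  # falsy when empty, safe to iterate
assert normalize([looks_numeric]) == [looks_numeric]
if normalize([looks_numeric]):                # the truthiness test used by the run loop
    assert all(check(7.2904, None) for check in normalize([looks_numeric]))
```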
@@ -449,9 +452,9 @@ def run(
     def _run_stream(
         self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
     ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
-        final_answer = None
         self.step_number = 1
-        while final_answer is None and self.step_number <= max_steps:
+        returned_final_answer = False
+        while not returned_final_answer and self.step_number <= max_steps:
             if self.interrupt_switch:
                 raise AgentError("Agent interrupted.", self.logger)
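The explicit flag matters because a `None` sentinel cannot distinguish "no final answer yet" from "the final answer is legitimately None (or otherwise falsy)". A self-contained toy demonstration of the two loop conditions, with made-up step values rather than the agent's real data:

```python
# Old style: a None sentinel. If the final answer is legitimately None,
# the loop keeps running until max_steps.
steps = [None, None, None]   # pretend the agent produced its answer (None) at step 1
final_answer = None
step_number = 1
while final_answer is None and step_number <= 3:
    final_answer = steps[step_number - 1]
    step_number += 1
assert step_number == 4      # ran all steps even though an answer arrived early

# New style: an explicit flag, independent of the answer's value.
returned_final_answer = False
final_answer = None
step_number = 1
while not returned_final_answer and step_number <= 3:
    final_answer = steps[step_number - 1]
    returned_final_answer = True   # set when is_final_answer is True
    step_number += 1
assert step_number == 2      # stopped right after the final answer
```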
@@ -462,7 +465,7 @@ def _run_stream(
                 planning_start_time = time.time()
                 planning_step = None
                 for element in self._generate_planning_step(
-                    task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
+                    task, is_first_step=(self.step_number == 1), step=self.step_number
                 ):
                     yield element
                 planning_step = element
@@ -481,10 +484,19 @@ def _run_stream(
                 timing=Timing(start_time=action_step_start_time),
                 observations_images=images,
             )
+            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
             try:
-                for el in self._execute_step(action_step):
-                    yield el
-                final_answer = el
+                for output in self._step_stream(action_step):
+                    # Yield streaming deltas
+                    if not isinstance(output, ActionOutput):
+                        yield output
+
+                    if isinstance(output, ActionOutput) and output.is_final_answer:
+                        if self.final_answer_checks:
+                            self._validate_final_answer(output.output)
+                        returned_final_answer = True
+                        action_step.is_final_answer = True
+                        final_answer = output.output
             except AgentGenerationError as e:
                 # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
                 raise e
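The new loop consumes a single generator and treats `ActionOutput` as an in-band control signal while passing everything else through to the caller. A stripped-down sketch of that consumption pattern; `fake_step_stream` stands in for `_step_stream`, and the `print` stands in for the real `yield output`:

```python
from dataclasses import dataclass
from typing import Any, Iterator


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


def fake_step_stream() -> Iterator[Any]:
    yield "thinking..."                                       # e.g. a streaming delta
    yield ActionOutput(output=None, is_final_answer=False)    # an ordinary tool call
    yield ActionOutput(output=42, is_final_answer=True)       # final_answer was called


returned_final_answer = False
final_answer = None
for output in fake_step_stream():
    if not isinstance(output, ActionOutput):
        print(output)  # streamed through to the caller in the real code
    if isinstance(output, ActionOutput) and output.is_final_answer:
        returned_final_answer = True
        final_answer = output.output

assert returned_final_answer and final_answer == 42
```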
@@ -497,23 +509,11 @@ def _run_stream(
             yield action_step
             self.step_number += 1

-        if final_answer is None and self.step_number == max_steps + 1:
+        if not returned_final_answer and self.step_number == max_steps + 1:
             final_answer = self._handle_max_steps_reached(task, images)
             yield action_step
         yield FinalAnswerStep(handle_agent_output_types(final_answer))

-    def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
-        self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
-        for el in self._step_stream(memory_step):
-            final_answer = el
-            if isinstance(el, ChatMessageStreamDelta):
-                yield el
-            elif isinstance(el, FinalOutput):
-                final_answer = el.output
-                if self.final_answer_checks:
-                    self._validate_final_answer(final_answer)
-        yield final_answer
-
     def _validate_final_answer(self, final_answer: Any):
         for check_function in self.final_answer_checks:
             try:

Collaborator (author), on the removed `_execute_step`: I think porting this logic into a separate function obfuscates the logic more than it clarifies it.
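On the `_validate_final_answer` side, only the first two lines of the body are visible as context here. For reference, a hedged sketch of the general shape: run every registered check against the candidate answer and surface the first failure. The real method uses the agent's memory and its own error type and logger, so treat this purely as the shape of the logic, not the library's implementation.

```python
from typing import Any, Callable

def validate_final_answer(final_answer: Any, checks: list[Callable[[Any, Any], bool]], memory: Any = None) -> None:
    # Illustrative only: each check gets (final_answer, memory) and must be truthy.
    for check_function in checks:
        try:
            assert check_function(final_answer, memory)
        except Exception as e:
            raise ValueError(f"Check {check_function.__name__} failed with error: {e}")

# Passes silently: the single check returns True.
validate_final_answer(7.2904, [lambda final_answer, memory: final_answer == 7.2904])
```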
@@ -617,8 +617,7 @@ def _generate_planning_step(
                     }
                 ],
             }
-            # remove last message from memory_messages because it is the current task
-            input_messages = [plan_update_pre] + memory_messages[:-1] + [plan_update_post]
+            input_messages = [plan_update_pre] + memory_messages + [plan_update_post]
             if self.stream_outputs and hasattr(self.model, "generate_stream"):
                 plan_message_content = ""
                 input_tokens, output_tokens = 0, 0
@@ -672,7 +671,7 @@ def interrupt(self):

     def write_memory_to_messages(
         self,
-        summary_mode: bool | None = False,
+        summary_mode: bool = False,
     ) -> list[Message]:
         """
         Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
@@ -684,7 +683,7 @@ def write_memory_to_messages(
             messages.extend(memory_step.to_messages(summary_mode=summary_mode))
         return messages

-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1112,7 +1111,7 @@ def push_to_hub(
             token (`bool` or `str`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                 when running `huggingface-cli login` (stored in `~/.huggingface`).
-            create_pr (`bool`, *optional*, defaults to `False`):
+            create_pr (`bool`, *optional*, defaults to `False`)
                 Whether to create a PR with the uploaded files or directly commit.
        """
        repo_url = create_repo(
@@ -1193,11 +1192,15 @@ def __init__(
     def initialize_system_prompt(self) -> str:
         system_prompt = populate_template(
             self.prompt_templates["system_prompt"],
-            variables={"tools": self.tools, "managed_agents": self.managed_agents},
+            variables={
+                "tools": self.tools,
+                "managed_agents": self.managed_agents,
+                "custom_instructions": self.instructions,
+            },
         )
         return system_prompt

-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
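The new `custom_instructions` variable only has an effect if the system-prompt template references it. A minimal sketch of that wiring, assuming a Jinja-style template like the ones the prompt files use; the template text, the `populate` helper, and the tool names here are made up for illustration:

```python
from jinja2 import StrictUndefined, Template

# Made-up template snippet: renders the extra block only when instructions are set.
template_text = (
    "You are an agent with tools: {{ tools | join(', ') }}.\n"
    "{% if custom_instructions %}{{ custom_instructions }}{% endif %}"
)

def populate(template: str, variables: dict) -> str:
    # Rough stand-in for populate_template(template, variables) from the hunk above.
    return Template(template, undefined=StrictUndefined).render(**variables)

print(populate(template_text, {"tools": ["search", "final_answer"],
                               "custom_instructions": "Always answer in French."}))
print(populate(template_text, {"tools": ["search", "final_answer"],
                               "custom_instructions": None}))
```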
@@ -1285,14 +1288,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
             memory_step (`ActionStep)`: Memory ActionStep to update with results.

         Yields:
-            `FinalOutput`: The final output of tool execution.
+            `ActionOutput`: The final output of tool execution.
         """
         model_outputs = []
         tool_calls = []
         observations = []

         final_answer_call = None
         parallel_calls = []
+        assert chat_message.tool_calls is not None
         for tool_call in chat_message.tool_calls:
             tool_name = tool_call.function.name
             tool_arguments = tool_call.function.arguments
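The surrounding code splits the model's tool calls into one optional `final_answer` call and a list of ordinary calls to run first. A toy sketch of that dispatch using a stand-in tool-call structure; the real objects come from `chat_message.tool_calls`, and `make_call` is a hypothetical helper:

```python
from types import SimpleNamespace

def make_call(name: str) -> SimpleNamespace:
    # Stand-in with the same .function.name shape used in the loop above.
    return SimpleNamespace(function=SimpleNamespace(name=name, arguments={}))

tool_calls = [make_call("web_search"), make_call("calculator"), make_call("final_answer")]

final_answer_call = None
parallel_calls = []
for tool_call in tool_calls:
    if tool_call.function.name == "final_answer":
        final_answer_call = tool_call   # handled last, after the other results are in
    else:
        parallel_calls.append(tool_call)

assert final_answer_call is not None and len(parallel_calls) == 2
```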
@@ -1332,19 +1336,19 @@ def process_single_tool_call(call_info):
                 )
             return observation

-        # Process non-final-answer tool calls in parallel
+        # Process tool calls in parallel
         if parallel_calls:
             if len(parallel_calls) == 1:
                 # If there's only one call, process it directly
                 observations.append(process_single_tool_call(parallel_calls[0]))
-                yield FinalOutput(output=None)
+                yield ActionOutput(output=None, is_final_answer=False)
             else:
                 # If multiple tool calls, process them in parallel
                 with ThreadPoolExecutor(self.max_tool_threads) as executor:
                     futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
                     for future in as_completed(futures):
                         observations.append(future.result())
-                yield FinalOutput(output=None)
+                yield ActionOutput(output=None, is_final_answer=False)

         # Process final_answer call if present
         if final_answer_call:
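The parallel branch is a standard fan-out/fan-in with `ThreadPoolExecutor` and `as_completed`; results are appended as they finish, so observation order can differ from call order. A self-contained sketch of the same pattern, where the worker just echoes its input instead of executing a real tool:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_single_tool_call(call_info: str) -> str:
    # Stand-in worker; the real one runs the tool and formats its observation.
    return f"Observation from {call_info}"

parallel_calls = ["web_search", "calculator", "image_tool"]
observations = []

if len(parallel_calls) == 1:
    # Single call: no thread pool needed.
    observations.append(process_single_tool_call(parallel_calls[0]))
else:
    # Multiple calls: submit all, collect results as they complete.
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(process_single_tool_call, call) for call in parallel_calls]
        for future in as_completed(futures):
            observations.append(future.result())

assert len(observations) == 3
```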
@@ -1374,7 +1378,7 @@ def process_single_tool_call(call_info):
                 level=LogLevel.INFO,
             )
             memory_step.action_output = final_answer
-            yield FinalOutput(output=final_answer)
+            yield ActionOutput(output=final_answer, is_final_answer=True)

         # Update memory step with all results
         if model_outputs:
@@ -1561,11 +1565,12 @@ def initialize_system_prompt(self) -> str:
                     if "*" in self.authorized_imports
                     else str(self.authorized_imports)
                 ),
+                "custom_instructions": self.instructions,
             },
         )
         return system_prompt

-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1695,7 +1700,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
             ]
             self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
             memory_step.action_output = output
-            yield FinalOutput(output=output if is_final_answer else None)
+            yield ActionOutput(output=output, is_final_answer=is_final_answer)

     def to_dict(self) -> dict[str, Any]:
         """Convert the agent to a dictionary representation.
Second changed file in the diff (the agent tests):
@@ -631,6 +631,17 @@ def check_always_fails(final_answer, agent_memory):
             agent.run("Dummy task.")
         assert "Error raised in check" in str(agent.write_memory_to_messages())

+        agent = CodeAgent(
+            model=FakeCodeModel(),
+            tools=[],
+            final_answer_checks=[lambda x, y: x == 7.2904],
+            verbosity_level=1000,
+        )
+        output = agent.run("Dummy task.")
+        assert output == 7.2904  # Check that output is correct
+        assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
+        assert "Error raised in check" not in str(agent.write_memory_to_messages())
+
     def test_generation_errors_are_raised(self):
         class FakeCodeModel(Model):
             def generate(self, messages, stop_sequences=None):
Review comment (Copilot): The `ActionOutput` class is missing the `@dataclass` decorator, so instances won't accept constructor arguments. Add `@dataclass` above this class definition.

Author reply: No it isn't, go home Copilot, you're drunk.