Fix and refactor final answer checks #1448
Changes from 9 commits
Commits: 3a92725, 10a0764, cd143f8, 03f17e8, 3092d62, 4441aec, fc7cc89, 54dbde6, 7b39442, ca5a539
Changes to the agent implementation:

```diff
@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:
 
 
 @dataclass
-class FinalOutput:
-    output: Any | None
+class ActionOutput:
+    output: Any
+    is_final_answer: bool
 
 
 class PlanningPromptTemplate(TypedDict):
```
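As I read the diff, the rename is not cosmetic: `FinalOutput` overloaded `output=None` to mean "no final answer", so intermediate outputs were dropped and a legitimate `None` answer was indistinguishable from "keep going". `ActionOutput` keeps the step output and makes finality an explicit flag. A minimal standalone sketch of the distinction (not library code):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


# Intermediate step outputs survive instead of collapsing to None:
intermediate = ActionOutput(output="tool observation", is_final_answer=False)
# Finality is explicit, so even a None answer is representable:
final = ActionOutput(output=None, is_final_answer=True)
assert not intermediate.is_final_answer and final.is_final_answer
```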
```diff
@@ -280,7 +281,7 @@ def __init__(
         self.name = self._validate_name(name)
         self.description = description
         self.provide_run_summary = provide_run_summary
-        self.final_answer_checks = final_answer_checks
+        self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
         self.return_full_result = return_full_result
         self.instructions = instructions
         self._setup_managed_agents(managed_agents)
```
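Defaulting to an empty list lets the rest of the code iterate `self.final_answer_checks` without a `None` guard. Judging from the tests further down, a check is a callable taking the candidate final answer and the agent memory, returning a truthy value to accept it. A hypothetical example (the check function is invented; the constructor call mirrors the tests):

```python
def answer_is_number(final_answer, agent_memory):
    # Truthy return accepts the answer; raising or returning falsy fails the check.
    return isinstance(final_answer, (int, float))


agent = CodeAgent(
    model=model,  # assumes a model instance is already in scope
    tools=[],
    final_answer_checks=[answer_is_number],
)
```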
```diff
@@ -451,9 +452,9 @@ def run(
     def _run_stream(
         self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
     ) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
-        final_answer = None
         self.step_number = 1
-        while final_answer is None and self.step_number <= max_steps:
+        returned_final_answer = False
+        while not returned_final_answer and self.step_number <= max_steps:
             if self.interrupt_switch:
                 raise AgentError("Agent interrupted.", self.logger)
 
```
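My reading of why the `None` sentinel became a boolean flag: `while final_answer is None` cannot terminate when the agent's legitimate final answer is itself `None`, whereas a flag decouples loop termination from the answer's value. A toy illustration of the hazard (not library code):

```python
# Sentinel-based termination conflates "no answer yet" with "answer is None":
final_answer = None
answers = iter([None, None, None])  # the agent legitimately answers None each step
steps = 0
while final_answer is None and steps < 3:
    final_answer = next(answers)
    steps += 1
assert steps == 3  # ran to the step limit even though step 1 produced an answer

# With the new scheme, finality is tracked separately from the value:
returned_final_answer = True
final_answer = None  # None is now a perfectly valid final answer
```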
```diff
@@ -464,8 +465,8 @@ def _run_stream(
                 planning_start_time = time.time()
                 planning_step = None
                 for element in self._generate_planning_step(
-                    task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
-                ):
+                    task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
+                ):  # Don't use the attribute step_number here, because there can be steps from previous runs
                     yield element
                     planning_step = element
                 assert isinstance(planning_step, PlanningStep)  # Last yielded element should be a PlanningStep
```
```diff
@@ -483,10 +484,19 @@ def _run_stream(
                 timing=Timing(start_time=action_step_start_time),
                 observations_images=images,
             )
+            self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
             try:
-                for el in self._execute_step(action_step):
-                    yield el
-                final_answer = el
+                for output in self._step_stream(action_step):
+                    # Yield streaming deltas
+                    if not isinstance(output, ActionOutput):
+                        yield output
+
+                    if isinstance(output, ActionOutput) and output.is_final_answer:
+                        if self.final_answer_checks:
+                            self._validate_final_answer(output.output)
+                        returned_final_answer = True
+                        action_step.is_final_answer = True
+                        final_answer = output.output
             except AgentGenerationError as e:
                 # Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
                 raise e
```
```diff
@@ -499,23 +509,11 @@ def _run_stream(
             yield action_step
             self.step_number += 1
 
-        if final_answer is None and self.step_number == max_steps + 1:
+        if not returned_final_answer and self.step_number == max_steps + 1:
             final_answer = self._handle_max_steps_reached(task, images)
             yield action_step
         yield FinalAnswerStep(handle_agent_output_types(final_answer))
 
-    def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
-        self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
-        for el in self._step_stream(memory_step):
-            final_answer = el
-            if isinstance(el, ChatMessageStreamDelta):
-                yield el
-            elif isinstance(el, FinalOutput):
-                final_answer = el.output
-                if self.final_answer_checks:
-                    self._validate_final_answer(final_answer)
-        yield final_answer
-
     def _validate_final_answer(self, final_answer: Any):
         for check_function in self.final_answer_checks:
             try:
```

Review comment from the author on the removed `_execute_step`: "I think porting this logic into a separate function obfuscates the logic more than it clarifies."
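The hunk cuts off inside `_validate_final_answer` after the `try:`. Based on the tests (a check that raises puts "Error raised in check" into memory, and `lambda x, y: x == 7.2904` passes when it returns `True`), a plausible completion could look like this; the exception type and message are my guesses, not the source:

```python
def _validate_final_answer(self, final_answer: Any):
    for check_function in self.final_answer_checks:
        try:
            # Checks receive the candidate answer and the agent memory;
            # a falsy return or a raised exception counts as failure.
            assert check_function(final_answer, self.memory)
        except Exception as e:  # hypothetical handling, not shown in the diff
            raise AgentError(f"Check {check_function.__name__} failed: {e}", self.logger)
```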
```diff
@@ -674,7 +672,7 @@ def interrupt(self):
 
     def write_memory_to_messages(
         self,
-        summary_mode: bool | None = False,
+        summary_mode: bool = False,
     ) -> list[Message]:
         """
         Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
```
```diff
@@ -686,7 +684,7 @@ def write_memory_to_messages(
             messages.extend(memory_step.to_messages(summary_mode=summary_mode))
         return messages
 
-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
```
```diff
@@ -1203,7 +1201,7 @@ def initialize_system_prompt(self) -> str:
         )
         return system_prompt
 
-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
```
```diff
@@ -1291,14 +1289,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
             memory_step (`ActionStep)`: Memory ActionStep to update with results.
 
         Yields:
-            `FinalOutput`: The final output of tool execution.
+            `ActionOutput`: The final output of tool execution.
         """
         model_outputs = []
         tool_calls = []
         observations = []
 
+        final_answer_call = None
         parallel_calls = []
         assert chat_message.tool_calls is not None
         for tool_call in chat_message.tool_calls:
             tool_name = tool_call.function.name
             tool_arguments = tool_call.function.arguments
```
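The loop body that fills these variables is elided by the hunk, but initializing `final_answer_call` next to `parallel_calls`, together with the "Process final_answer call if present" branch further down, suggests the loop routes a `final_answer` call out of the parallel batch so it runs only after every other tool call. A hypothetical sketch of that routing (the variable names come from the diff; the loop body itself does not appear in the PR):

```python
final_answer_call = None
parallel_calls = []
for tool_call in chat_message.tool_calls:
    if tool_call.function.name == "final_answer":
        final_answer_call = tool_call  # defer: the final answer goes last
    else:
        parallel_calls.append(tool_call)
```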
```diff
@@ -1338,19 +1337,19 @@ def process_single_tool_call(call_info):
             )
             return observation
 
-        # Process non-final-answer tool calls in parallel
+        # Process tool calls in parallel
         if parallel_calls:
             if len(parallel_calls) == 1:
                 # If there's only one call, process it directly
                 observations.append(process_single_tool_call(parallel_calls[0]))
-                yield FinalOutput(output=None)
+                yield ActionOutput(output=None, is_final_answer=False)
             else:
                 # If multiple tool calls, process them in parallel
                 with ThreadPoolExecutor(self.max_tool_threads) as executor:
                     futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
                     for future in as_completed(futures):
                         observations.append(future.result())
-                        yield FinalOutput(output=None)
+                        yield ActionOutput(output=None, is_final_answer=False)
 
         # Process final_answer call if present
         if final_answer_call:
```
```diff
@@ -1380,7 +1379,7 @@ def process_single_tool_call(call_info):
                 level=LogLevel.INFO,
             )
             memory_step.action_output = final_answer
-            yield FinalOutput(output=final_answer)
+            yield ActionOutput(output=final_answer, is_final_answer=True)
 
         # Update memory step with all results
         if model_outputs:
```
```diff
@@ -1572,7 +1571,7 @@ def initialize_system_prompt(self) -> str:
         )
         return system_prompt
 
-    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
+    def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
         """
         Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
         Yields ChatMessageStreamDelta during the run if streaming is enabled.
```
```diff
@@ -1702,7 +1701,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
         ]
         self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
         memory_step.action_output = output
-        yield FinalOutput(output=output if is_final_answer else None)
+        yield ActionOutput(output=output, is_final_answer=is_final_answer)
 
     def to_dict(self) -> dict[str, Any]:
         """Convert the agent to a dictionary representation.
```
Changes to the tests:

```diff
@@ -619,6 +619,17 @@ def check_always_fails(final_answer, agent_memory):
         agent.run("Dummy task.")
         assert "Error raised in check" in str(agent.write_memory_to_messages())
 
+        agent = CodeAgent(
+            model=FakeCodeModel(),
+            tools=[],
+            final_answer_checks=[lambda x, y: x == 7.2904],
+            verbosity_level=1000,
+        )
+        output = agent.run("Dummy task.")
+        assert output == 7.2904  # Check that output is correct
+        assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
+        assert "Error raised in check" not in str(agent.write_memory_to_messages())
+
     def test_generation_errors_are_raised(self):
         class FakeCodeModel(Model):
             def generate(self, messages, stop_sequences=None):
```
```diff
@@ -640,19 +651,21 @@ def test_planning_step_with_injected_memory(self):
         agent.memory.steps.append(previous_step)
 
         # Run the agent
-        agent.run(task, reset=False)
+        agent.run(task, reset=False, max_steps=2)
 
         # Verify that the planning step used update plan prompts
         planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
         assert len(planning_steps) > 0
 
         # Check that the planning step's model input messages contain the injected memory
-        planning_step = planning_steps[0]
-        assert len(planning_step.model_input_messages) == 3  # system message + memory messages + user message
-        assert planning_step.model_input_messages[0]["role"] == "system"
-        assert task in planning_step.model_input_messages[0]["content"][0]["text"]
-        assert planning_step.model_input_messages[1]["role"] == "user"
-        assert "Previous user request" in planning_step.model_input_messages[1]["content"][0]["text"]
+        update_plan_step = planning_steps[0]
+        assert (
+            len(update_plan_step.model_input_messages) == 3
+        )  # system message + memory messages (1 task message, the latest one is removed) + user message
+        assert update_plan_step.model_input_messages[0]["role"] == "system"
+        assert task in update_plan_step.model_input_messages[0]["content"][0]["text"]
+        assert update_plan_step.model_input_messages[1]["role"] == "user"
+        assert "Previous user request" in update_plan_step.model_input_messages[1]["content"][0]["text"]
 
 
 class CustomFinalAnswerTool(FinalAnswerTool):
```
Review comment from Copilot on the `ActionOutput` definition: "The `ActionOutput` class is missing the `@dataclass` decorator, so instances won't accept constructor arguments. Add `@dataclass` above this class definition."

Reply from the author: "No it isn't, go home copilot you're drunk"
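For what it's worth, the first hunk shows `@dataclass` as unchanged context directly above the renamed class, so the reply checks out. A quick standalone check of the disputed behavior:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ActionOutput:
    output: Any
    is_final_answer: bool


# @dataclass generates __init__, so constructor arguments are accepted:
ActionOutput(output=1, is_final_answer=True)

# Without the decorator, the same call would raise:
# TypeError: ActionOutput() takes no arguments
```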