diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py index 1f6208dce9..55a4d4abc6 100644 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ b/verl/experimental/agent_loop/tool_agent_loop.py @@ -67,6 +67,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu ), ) response_mask = [] + tools_kwargs = kwargs.get("tools_kwargs", {}) user_turns, assistant_turns = 0, 0 while True: @@ -98,7 +99,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu # call tools tasks = [] for tool_call in tool_calls[: self.max_parallel_calls]: - tasks.append(self._call_tool(tool_call)) + tasks.append(self._call_tool(tool_call, tools_kwargs)) with simple_timer("tool_calls", metrics): tool_responses = await asyncio.gather(*tasks) if any(isinstance(item, Exception) for item in tool_responses): @@ -134,7 +135,7 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu ) return output - async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: + async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]) -> dict[str, str]: """Call tool and return tool response.""" tool, instance_id = None, None try: @@ -142,8 +143,8 @@ async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: tool_name = tool_call.name tool_args = json.loads(tool_call.arguments) tool = self.tools[tool_name] - - instance_id = await tool.create() + kwargs = tools_kwargs.get(tool_name, {}) + instance_id = await tool.create(create_kwargs=kwargs.get("create_kwargs", {})) tool_response, _, _ = await tool.execute(instance_id, tool_args) except Exception as e: logger.exception(f"Error when executing tool: {e}") diff --git a/verl/tools/gsm8k_tool.py b/verl/tools/gsm8k_tool.py index fd11eeb6ba..e9048fcb95 100644 --- a/verl/tools/gsm8k_tool.py +++ b/verl/tools/gsm8k_tool.py @@ -67,6 +67,8 @@ def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: if instance_id is None: instance_id = str(uuid4()) + if ground_truth is None: + ground_truth = kwargs.get("create_kwargs", {}).get("ground_truth", None) self._instance_dict[instance_id] = { "response": "", "ground_truth": ground_truth,