Skip to content

Commit d391481

Browse files
feat: add handle_tool_error and handle_validation_error to load_mcp_tools
1 parent 8b02c53 commit d391481

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

langchain_mcp_adapters/client.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
"""
66

77
import asyncio
8-
from collections.abc import AsyncIterator
8+
from collections.abc import AsyncIterator, Callable
99
from contextlib import asynccontextmanager
1010
from types import TracebackType
1111
from typing import Any
1212

1313
from langchain_core.documents.base import Blob
1414
from langchain_core.messages import AIMessage, HumanMessage
15-
from langchain_core.tools import BaseTool
15+
from langchain_core.tools import BaseTool, ToolException
1616
from mcp import ClientSession
17+
from pydantic import ValidationError
1718

1819
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
1920
from langchain_mcp_adapters.interceptors import ToolCallInterceptor
@@ -143,12 +144,22 @@ async def session(
143144
await session.initialize()
144145
yield session
145146

146-
async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
147+
async def get_tools(
148+
self,
149+
*,
150+
server_name: str | None = None,
151+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
152+
handle_validation_error: (
153+
bool | str | Callable[[ValidationError], str] | None
154+
) = False,
155+
) -> list[BaseTool]:
147156
"""Get a list of all tools from all connected servers.
148157
149158
Args:
150159
server_name: Optional name of the server to get tools from.
151160
If `None`, all tools from all servers will be returned.
161+
handle_tool_error: Optional error handler for tool execution errors.
162+
handle_validation_error: Optional error handler for validation errors.
152163
153164
!!! note
154165
@@ -171,6 +182,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
171182
callbacks=self.callbacks,
172183
server_name=server_name,
173184
tool_interceptors=self.tool_interceptors,
185+
handle_tool_error=handle_tool_error,
186+
handle_validation_error=handle_validation_error,
174187
)
175188

176189
all_tools: list[BaseTool] = []
@@ -183,6 +196,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
183196
callbacks=self.callbacks,
184197
server_name=name,
185198
tool_interceptors=self.tool_interceptors,
199+
handle_tool_error=handle_tool_error,
200+
handle_validation_error=handle_validation_error,
186201
)
187202
)
188203
load_mcp_tool_tasks.append(load_mcp_tool_task)

langchain_mcp_adapters/tools.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
TextContent,
2626
)
2727
from mcp.types import Tool as MCPTool
28-
from pydantic import BaseModel, create_model
28+
from pydantic import BaseModel, ValidationError, create_model
2929

3030
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
3131
from langchain_mcp_adapters.interceptors import (
@@ -161,6 +161,10 @@ def convert_mcp_tool_to_langchain_tool(
161161
callbacks: Callbacks | None = None,
162162
tool_interceptors: list[ToolCallInterceptor] | None = None,
163163
server_name: str | None = None,
164+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
165+
handle_validation_error: (
166+
bool | str | Callable[[ValidationError], str] | None
167+
) = False,
164168
) -> BaseTool:
165169
"""Convert an MCP tool to a LangChain tool.
166170
@@ -174,6 +178,8 @@ def convert_mcp_tool_to_langchain_tool(
174178
callbacks: Optional callbacks for handling notifications and events
175179
tool_interceptors: Optional list of interceptors for tool call processing
176180
server_name: Name of the server this tool belongs to
181+
handle_tool_error: Optional error handler for tool execution errors.
182+
handle_validation_error: Optional error handler for validation errors.
177183
178184
Returns:
179185
a LangChain tool
@@ -303,6 +309,8 @@ async def execute_tool(request: MCPToolCallRequest) -> MCPToolCallResult:
303309
coroutine=call_tool,
304310
response_format="content_and_artifact",
305311
metadata=metadata,
312+
handle_tool_error=handle_tool_error,
313+
handle_validation_error=handle_validation_error,
306314
)
307315

308316

@@ -313,6 +321,10 @@ async def load_mcp_tools(
313321
callbacks: Callbacks | None = None,
314322
tool_interceptors: list[ToolCallInterceptor] | None = None,
315323
server_name: str | None = None,
324+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
325+
handle_validation_error: (
326+
bool | str | Callable[[ValidationError], str] | None
327+
) = False,
316328
) -> list[BaseTool]:
317329
"""Load all available MCP tools and convert them to LangChain [tools](https://docs.langchain.com/oss/python/langchain/tools).
318330
@@ -322,6 +334,8 @@ async def load_mcp_tools(
322334
callbacks: Optional `Callbacks` for handling notifications and events.
323335
tool_interceptors: Optional list of interceptors for tool call processing.
324336
server_name: Name of the server these tools belong to.
337+
handle_tool_error: Optional error handler for tool execution errors.
338+
handle_validation_error: Optional error handler for validation errors.
325339
326340
Returns:
327341
List of LangChain [tools](https://docs.langchain.com/oss/python/langchain/tools).
@@ -361,6 +375,8 @@ async def load_mcp_tools(
361375
callbacks=callbacks,
362376
tool_interceptors=tool_interceptors,
363377
server_name=server_name,
378+
handle_tool_error=handle_tool_error,
379+
handle_validation_error=handle_validation_error,
364380
)
365381
for tool in tools
366382
]

0 commit comments

Comments
 (0)