Skip to content

Commit 45d6158

Browse files
feat: hooks (#330)
ideas here: 1. Support updating headers and tool args via the `before_tool_call` hook 2. Support returning a new `CallToolResult` in the `after_tool_call` hook (can update tool result content) 3. Hooks add support for accessing server info, tool info, and config/runtime when available
1 parent 219b60c commit 45d6158

File tree

9 files changed

+463
-21
lines changed

9 files changed

+463
-21
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ client = MultiServerMCPClient(
126126
},
127127
"weather": {
128128
# Make sure you start your weather server on port 8000
129-
"url": "http://localhost:8000/mcp/",
129+
"url": "http://localhost:8000/mcp",
130130
"transport": "streamable_http",
131131
}
132132
}
@@ -172,7 +172,7 @@ from mcp.client.streamable_http import streamablehttp_client
172172
from langgraph.prebuilt import create_react_agent
173173
from langchain_mcp_adapters.tools import load_mcp_tools
174174
175-
async with streamablehttp_client("http://localhost:3000/mcp/") as (read, write, _):
175+
async with streamablehttp_client("http://localhost:3000/mcp") as (read, write, _):
176176
async with ClientSession(read, write) as session:
177177
# Initialize the connection
178178
await session.initialize()
@@ -194,7 +194,7 @@ client = MultiServerMCPClient(
194194
{
195195
"math": {
196196
"transport": "streamable_http",
197-
"url": "http://localhost:3000/mcp/"
197+
"url": "http://localhost:3000/mcp"
198198
},
199199
}
200200
)
@@ -255,7 +255,7 @@ client = MultiServerMCPClient(
255255
},
256256
"weather": {
257257
# make sure you start your weather server on port 8000
258-
"url": "http://localhost:8000/mcp/",
258+
"url": "http://localhost:8000/mcp",
259259
"transport": "streamable_http",
260260
}
261261
}
@@ -304,7 +304,7 @@ async def make_graph():
304304
},
305305
"weather": {
306306
# make sure you start your weather server on port 8000
307-
"url": "http://localhost:8000/mcp/",
307+
"url": "http://localhost:8000/mcp",
308308
"transport": "streamable_http",
309309
}
310310
}

langchain_mcp_adapters/callbacks.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
33
from dataclasses import dataclass
44
from typing import Protocol
55

6-
from mcp.client.session import LoggingFnT
7-
from mcp.shared.session import ProgressFnT
8-
from mcp.types import LoggingMessageNotificationParams
6+
from mcp.client.session import LoggingFnT as MCPLoggingFnT
7+
from mcp.shared.session import ProgressFnT as MCPProgressFnT
8+
from mcp.types import (
9+
LoggingMessageNotificationParams as MCPLoggingMessageNotificationParams,
10+
)
11+
12+
# Type aliases to avoid direct MCP type dependencies
13+
LoggingFnT = MCPLoggingFnT
14+
ProgressFnT = MCPProgressFnT
15+
LoggingMessageNotificationParams = MCPLoggingMessageNotificationParams
916

1017

1118
@dataclass

langchain_mcp_adapters/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mcp import ClientSession
1717

1818
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
19+
from langchain_mcp_adapters.hooks import Hooks
1920
from langchain_mcp_adapters.prompts import load_mcp_prompt
2021
from langchain_mcp_adapters.resources import load_mcp_resources
2122
from langchain_mcp_adapters.sessions import (
@@ -52,13 +53,15 @@ def __init__(
5253
connections: dict[str, Connection] | None = None,
5354
*,
5455
callbacks: Callbacks | None = None,
56+
hooks: Hooks | None = None,
5557
) -> None:
5658
"""Initialize a MultiServerMCPClient with MCP servers connections.
5759
5860
Args:
5961
connections: A dictionary mapping server names to connection configurations.
6062
If None, no initial connections are established.
6163
callbacks: Optional callbacks for handling notifications and events.
64+
hooks: Optional hooks for before/after tool call processing.
6265
6366
Example: basic usage (starting a new session on each tool call)
6467
@@ -99,6 +102,7 @@ def __init__(
99102
connections if connections is not None else {}
100103
)
101104
self.callbacks = callbacks or Callbacks()
105+
self.hooks = hooks
102106

103107
@asynccontextmanager
104108
async def session(
@@ -163,6 +167,7 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
163167
connection=self.connections[server_name],
164168
callbacks=self.callbacks,
165169
server_name=server_name,
170+
hooks=self.hooks,
166171
)
167172

168173
all_tools: list[BaseTool] = []
@@ -174,6 +179,7 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
174179
connection=connection,
175180
callbacks=self.callbacks,
176181
server_name=name,
182+
hooks=self.hooks,
177183
)
178184
)
179185
load_mcp_tool_tasks.append(load_mcp_tool_task)

langchain_mcp_adapters/hooks.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Hook interfaces and types for MCP client lifecycle management.
2+
3+
This module provides hook interfaces for intercepting and extending
4+
MCP client behavior before and after tool calls.
5+
6+
In the future, we might add more hooks for other parts of the
7+
request / result lifecycle, for example to support elicitation.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from dataclasses import dataclass
13+
from typing import TYPE_CHECKING, Any, Protocol
14+
15+
from mcp.types import CallToolResult as MCPCallToolResult
16+
from typing_extensions import NotRequired, TypedDict
17+
18+
if TYPE_CHECKING:
19+
from langchain_core.runnables import RunnableConfig
20+
21+
# Type aliases to avoid direct MCP type dependencies
22+
CallToolResult = MCPCallToolResult
23+
24+
25+
@dataclass
26+
class ToolHookContext:
27+
"""Context object passed to hooks containing state and server information."""
28+
29+
server_name: str
30+
tool_name: str
31+
32+
# we'll add state eventually when we have a context manager like get_state()
33+
# state: object | None = None
34+
config: RunnableConfig | None = None
35+
runtime: object | None = None
36+
37+
38+
class CallToolRequestSpec(TypedDict, total=False):
39+
"""Result of before tool call hook."""
40+
41+
name: NotRequired[str]
42+
args: NotRequired[dict[str, Any]]
43+
headers: NotRequired[dict[str, Any]]
44+
45+
46+
class BeforeToolCallHook(Protocol):
47+
"""Protocol for before_tool_call hook functions.
48+
49+
Allows modification of tool call arguments and headers before execution.
50+
Return None to proceed with original request.
51+
"""
52+
53+
async def __call__(
54+
self,
55+
request: CallToolRequestSpec,
56+
context: ToolHookContext,
57+
) -> CallToolRequestSpec | None:
58+
"""Execute before tool call.
59+
60+
Args:
61+
request: The original tool call request
62+
context: Hook context with server/tool info and shared state
63+
64+
Returns:
65+
Modified CallToolRequest or None to use original request
66+
"""
67+
...
68+
69+
70+
class AfterToolCallHook(Protocol):
71+
"""Protocol for after_tool_call hook functions.
72+
73+
Allows modification of tool call results after execution.
74+
Return None to proceed with original result processing.
75+
Return CallToolResult to use the modified result.
76+
"""
77+
78+
async def __call__(
79+
self,
80+
result: CallToolResult,
81+
context: ToolHookContext,
82+
) -> CallToolResult | None:
83+
"""Execute after tool call.
84+
85+
Args:
86+
result: The original tool call result
87+
context: Hook context with server/tool info and shared state
88+
89+
Returns:
90+
- CallToolResult to use the modified result
91+
- None to use original result
92+
"""
93+
...
94+
95+
96+
class Hooks:
97+
"""Container for MCP client hook functions."""
98+
99+
def __init__(
100+
self,
101+
*,
102+
before_tool_call: BeforeToolCallHook | None = None,
103+
after_tool_call: AfterToolCallHook | None = None,
104+
) -> None:
105+
"""Initialize hooks.
106+
107+
Args:
108+
before_tool_call: Hook called before tool execution
109+
after_tool_call: Hook called after tool execution
110+
"""
111+
self.before_tool_call = before_tool_call
112+
self.after_tool_call = after_tool_call

langchain_mcp_adapters/tools.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,23 @@
2828
from pydantic import BaseModel, create_model
2929

3030
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
31+
from langchain_mcp_adapters.hooks import CallToolRequestSpec, Hooks, ToolHookContext
3132
from langchain_mcp_adapters.sessions import Connection, create_session
3233

34+
try:
35+
from langgraph.config import get_config
36+
from langgraph.runtime import get_runtime
37+
except ImportError:
38+
39+
def get_config() -> dict:
40+
"""no-op config getter."""
41+
return {}
42+
43+
def get_runtime() -> None:
44+
"""no-op runtime getter."""
45+
return
46+
47+
3348
NonTextContent = ImageContent | AudioContent | ResourceLink | EmbeddedResource
3449
MAX_ITERATIONS = 1000
3550

@@ -111,6 +126,7 @@ def convert_mcp_tool_to_langchain_tool(
111126
*,
112127
connection: Connection | None = None,
113128
callbacks: Callbacks | None = None,
129+
hooks: Hooks | None = None,
114130
server_name: str | None = None,
115131
) -> BaseTool:
116132
"""Convert an MCP tool to a LangChain tool.
@@ -123,6 +139,7 @@ def convert_mcp_tool_to_langchain_tool(
123139
connection: Optional connection config to use to create a new session
124140
if a `session` is not provided
125141
callbacks: Optional callbacks for handling notifications and events
142+
hooks: Optional hooks for before/after tool call processing
126143
server_name: Name of the server this tool belongs to
127144
128145
Returns:
@@ -144,27 +161,73 @@ async def call_tool(
144161
else _MCPCallbacks()
145162
)
146163

164+
tool_name = tool.name
165+
tool_args = arguments
166+
effective_connection = connection
167+
168+
# try to get config and runtime if we're in a langgraph context
169+
try:
170+
config = get_config()
171+
runtime = get_runtime()
172+
except Exception: # noqa: BLE001
173+
config = {}
174+
runtime = None
175+
176+
hook_context = ToolHookContext(
177+
server_name=server_name or "unknown",
178+
tool_name=tool.name,
179+
config=config,
180+
runtime=runtime,
181+
)
182+
183+
if hooks and hooks.before_tool_call:
184+
tool_request_spec = CallToolRequestSpec(
185+
name=tool.name,
186+
args=arguments,
187+
)
188+
189+
modified_request = await hooks.before_tool_call(
190+
tool_request_spec, hook_context
191+
)
192+
if modified_request is not None:
193+
tool_name = modified_request.get("name") or tool_name
194+
tool_args = modified_request.get("args") or tool_args
195+
196+
# If headers were modified, create a new connection with updated headers
197+
modified_headers = modified_request.get("headers")
198+
if modified_headers is not None and connection is not None:
199+
# Create a new connection config with updated headers
200+
updated_connection = dict(connection)
201+
if connection["transport"] in ("sse", "streamable_http"):
202+
existing_headers = connection.get("headers", {})
203+
updated_connection["headers"] = {
204+
**existing_headers,
205+
**modified_headers,
206+
}
207+
effective_connection = updated_connection
208+
147209
# Execute the tool call
148210
call_tool_result = None
211+
149212
if session is None:
150213
# If a session is not provided, we will create one on the fly
151-
if connection is None:
214+
if effective_connection is None:
152215
msg = "Either session or connection must be provided"
153216
raise ValueError(msg)
154217

155218
async with create_session(
156-
connection, mcp_callbacks=mcp_callbacks
219+
effective_connection, mcp_callbacks=mcp_callbacks
157220
) as tool_session:
158221
await tool_session.initialize()
159222
call_tool_result = await cast("ClientSession", tool_session).call_tool(
160-
tool.name,
161-
arguments,
223+
tool_name,
224+
tool_args,
162225
progress_callback=mcp_callbacks.progress_callback,
163226
)
164227
else:
165228
call_tool_result = await session.call_tool(
166-
tool.name,
167-
arguments,
229+
tool_name,
230+
tool_args,
168231
progress_callback=mcp_callbacks.progress_callback,
169232
)
170233

@@ -177,10 +240,14 @@ async def call_tool(
177240
)
178241
raise RuntimeError(msg)
179242

180-
return _convert_call_tool_result(call_tool_result)
243+
if hooks and hooks.after_tool_call:
244+
hook_result = await hooks.after_tool_call(call_tool_result, hook_context)
245+
if hook_result is not None:
246+
call_tool_result = hook_result
181247

182-
meta = tool.meta if hasattr(tool, "meta") else None
248+
return _convert_call_tool_result(call_tool_result)
183249

250+
meta = getattr(tool, "meta", None)
184251
base = tool.annotations.model_dump() if tool.annotations is not None else {}
185252
meta = {"_meta": meta} if meta is not None else {}
186253
metadata = {**base, **meta} or None
@@ -200,6 +267,7 @@ async def load_mcp_tools(
200267
*,
201268
connection: Connection | None = None,
202269
callbacks: Callbacks | None = None,
270+
hooks: Hooks | None = None,
203271
server_name: str | None = None,
204272
) -> list[BaseTool]:
205273
"""Load all available MCP tools and convert them to LangChain tools.
@@ -208,6 +276,7 @@ async def load_mcp_tools(
208276
session: The MCP client session. If None, connection must be provided.
209277
connection: Connection config to create a new session if session is None.
210278
callbacks: Optional callbacks for handling notifications and events.
279+
hooks: Optional hooks for before/after tool call processing.
211280
server_name: Name of the server these tools belong to.
212281
213282
Returns:
@@ -246,6 +315,7 @@ async def load_mcp_tools(
246315
tool,
247316
connection=connection,
248317
callbacks=callbacks,
318+
hooks=hooks,
249319
server_name=server_name,
250320
)
251321
for tool in tools

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ ignore = [
5555
"C901", # Too complex
5656
"PLR0913", # Too many arguments in function definition
5757
"PLR0912", # Too many branches
58+
"PLR0915", # Too many statements
59+
"ERA001", # Commented out code should be removed
5860
]
5961

6062

0 commit comments

Comments
 (0)