Skip to content

Commit 219b60c

Browse files
feat: implementing support for callbacks (#328)
Introduces a callback system for monitoring MCP server events during tool execution. Callbacks contain context with server name and tool name and are applied to all servers for a given client. **New API:** ```python from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_mcp_adapters.callbacks import Callbacks async def on_progress(progress, total, message, context): print(f"[{context.server_name}:{context.tool_name}] {progress}/{total}") client = MultiServerMCPClient(connections={...}, callbacks=Callbacks(on_progress=on_progress)) ``` `Callbacks` has the potential to be extended quite broadly to support a variety of notification hooks. We simply match signatures from `mcp` + add a final `context` arg which is easily extensible as well. **What's Added:** - `CallbackContext` - provides server/tool context to callbacks - `Callbacks` - main configuration class - `LoggingMessageCallback` & `ProgressCallback` - callback protocols w/ context - Optional `callbacks` parameter on `MultiServerMCPClient`
1 parent 9bd9f74 commit 219b60c

File tree

11 files changed

+417
-37
lines changed

11 files changed

+417
-37
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Types for callbacks."""
2+
3+
from dataclasses import dataclass
4+
from typing import Protocol
5+
6+
from mcp.client.session import LoggingFnT
7+
from mcp.shared.session import ProgressFnT
8+
from mcp.types import LoggingMessageNotificationParams
9+
10+
11+
@dataclass
12+
class CallbackContext:
13+
"""LangChain MCP client callback context."""
14+
15+
server_name: str
16+
tool_name: str | None = None
17+
18+
19+
class LoggingMessageCallback(Protocol):
20+
"""Light wrapper around the mcp.client.session.LoggingFnT.
21+
22+
Injects callback context as the last argument.
23+
"""
24+
25+
async def __call__(
26+
self,
27+
params: LoggingMessageNotificationParams,
28+
context: CallbackContext,
29+
) -> None:
30+
"""Execute callback on logging message notification."""
31+
...
32+
33+
34+
class ProgressCallback(Protocol):
35+
"""Light wrapper around the mcp.shared.session.ProgressFnT.
36+
37+
Injects callback context as the last argument.
38+
"""
39+
40+
async def __call__(
41+
self,
42+
progress: float,
43+
total: float | None,
44+
message: str | None,
45+
context: CallbackContext,
46+
) -> None:
47+
"""Execute callback on progress notification."""
48+
...
49+
50+
51+
@dataclass
52+
class _MCPCallbacks:
53+
"""Callbacks compatible with the MCP SDK. For internal use only."""
54+
55+
logging_callback: LoggingFnT | None = None
56+
progress_callback: ProgressFnT | None = None
57+
58+
59+
@dataclass
60+
class Callbacks:
61+
"""Callbacks for the LangChain MCP client."""
62+
63+
on_logging_message: LoggingMessageCallback | None = None
64+
on_progress: ProgressCallback | None = None
65+
66+
def to_mcp_format(self, *, context: CallbackContext) -> _MCPCallbacks:
67+
"""Convert the LangChain MCP client callbacks to MCP SDK callbacks.
68+
69+
Injects the LangChain CallbackContext as the last argument.
70+
"""
71+
if (on_logging_message := self.on_logging_message) is not None:
72+
73+
async def mcp_logging_callback(
74+
params: LoggingMessageNotificationParams,
75+
) -> None:
76+
await on_logging_message(params, context)
77+
else:
78+
mcp_logging_callback = None
79+
80+
if (on_progress := self.on_progress) is not None:
81+
82+
async def mcp_progress_callback(
83+
progress: float, total: float | None, message: str | None
84+
) -> None:
85+
await on_progress(progress, total, message, context)
86+
else:
87+
mcp_progress_callback = None
88+
89+
return _MCPCallbacks(
90+
logging_callback=mcp_logging_callback,
91+
progress_callback=mcp_progress_callback,
92+
)

langchain_mcp_adapters/client.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from langchain_core.tools import BaseTool
1616
from mcp import ClientSession
1717

18+
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
1819
from langchain_mcp_adapters.prompts import load_mcp_prompt
1920
from langchain_mcp_adapters.resources import load_mcp_resources
2021
from langchain_mcp_adapters.sessions import (
@@ -46,12 +47,18 @@ class MultiServerMCPClient:
4647
Loads LangChain-compatible tools, prompts and resources from MCP servers.
4748
"""
4849

49-
def __init__(self, connections: dict[str, Connection] | None = None) -> None:
50+
def __init__(
51+
self,
52+
connections: dict[str, Connection] | None = None,
53+
*,
54+
callbacks: Callbacks | None = None,
55+
) -> None:
5056
"""Initialize a MultiServerMCPClient with MCP servers connections.
5157
5258
Args:
5359
connections: A dictionary mapping server names to connection configurations.
5460
If None, no initial connections are established.
61+
callbacks: Optional callbacks for handling notifications and events.
5562
5663
Example: basic usage (starting a new session on each tool call)
5764
@@ -87,11 +94,11 @@ def __init__(self, connections: dict[str, Connection] | None = None) -> None:
8794
async with client.session("math") as session:
8895
tools = await load_mcp_tools(session)
8996
```
90-
9197
"""
9298
self.connections: dict[str, Connection] = (
9399
connections if connections is not None else {}
94100
)
101+
self.callbacks = callbacks or Callbacks()
95102

96103
@asynccontextmanager
97104
async def session(
@@ -120,7 +127,13 @@ async def session(
120127
)
121128
raise ValueError(msg)
122129

123-
async with create_session(self.connections[server_name]) as session:
130+
mcp_callbacks = self.callbacks.to_mcp_format(
131+
context=CallbackContext(server_name=server_name)
132+
)
133+
134+
async with create_session(
135+
self.connections[server_name], mcp_callbacks=mcp_callbacks
136+
) as session:
124137
if auto_initialize:
125138
await session.initialize()
126139
yield session
@@ -145,13 +158,23 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
145158
f"expected one of '{list(self.connections.keys())}'"
146159
)
147160
raise ValueError(msg)
148-
return await load_mcp_tools(None, connection=self.connections[server_name])
161+
return await load_mcp_tools(
162+
None,
163+
connection=self.connections[server_name],
164+
callbacks=self.callbacks,
165+
server_name=server_name,
166+
)
149167

150168
all_tools: list[BaseTool] = []
151169
load_mcp_tool_tasks = []
152-
for connection in self.connections.values():
170+
for name, connection in self.connections.items():
153171
load_mcp_tool_task = asyncio.create_task(
154-
load_mcp_tools(None, connection=connection)
172+
load_mcp_tools(
173+
None,
174+
connection=connection,
175+
callbacks=self.callbacks,
176+
server_name=name,
177+
)
155178
)
156179
load_mcp_tool_tasks.append(load_mcp_tool_task)
157180
tools_list = await asyncio.gather(*load_mcp_tool_tasks)
@@ -218,6 +241,7 @@ def __aexit__(
218241

219242

220243
__all__ = [
244+
"Callbacks",
221245
"McpHttpClientFactory",
222246
"MultiServerMCPClient",
223247
"SSEConnection",

langchain_mcp_adapters/sessions.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import httpx
2424

25+
from langchain_mcp_adapters.callbacks import _MCPCallbacks
26+
2527
EncodingErrorHandler = Literal["strict", "ignore", "replace"]
2628

2729
DEFAULT_ENCODING = "utf-8"
@@ -187,7 +189,7 @@ class WebsocketConnection(TypedDict):
187189

188190

189191
@asynccontextmanager
190-
async def _create_stdio_session( # noqa: PLR0913
192+
async def _create_stdio_session(
191193
*,
192194
command: str,
193195
args: list[str],
@@ -233,7 +235,7 @@ async def _create_stdio_session( # noqa: PLR0913
233235

234236

235237
@asynccontextmanager
236-
async def _create_sse_session( # noqa: PLR0913
238+
async def _create_sse_session(
237239
*,
238240
url: str,
239241
headers: dict[str, Any] | None = None,
@@ -273,7 +275,7 @@ async def _create_sse_session( # noqa: PLR0913
273275

274276

275277
@asynccontextmanager
276-
async def _create_streamable_http_session( # noqa: PLR0913
278+
async def _create_streamable_http_session(
277279
*,
278280
url: str,
279281
headers: dict[str, Any] | None = None,
@@ -356,11 +358,14 @@ async def _create_websocket_session(
356358

357359

358360
@asynccontextmanager
359-
async def create_session(connection: Connection) -> AsyncIterator[ClientSession]: # noqa: C901
361+
async def create_session(
362+
connection: Connection, *, mcp_callbacks: _MCPCallbacks | None = None
363+
) -> AsyncIterator[ClientSession]:
360364
"""Create a new session to an MCP server.
361365
362366
Args:
363367
connection: Connection config to use to connect to the server
368+
mcp_callbacks: mcp sdk compatible callbacks to use for the ClientSession
364369
365370
Raises:
366371
ValueError: If transport is not recognized
@@ -381,6 +386,16 @@ async def create_session(connection: Connection) -> AsyncIterator[ClientSession]
381386
transport = connection["transport"]
382387
params = {k: v for k, v in connection.items() if k != "transport"}
383388

389+
if mcp_callbacks is not None:
390+
params["session_kwargs"] = params.get("session_kwargs", {})
391+
# right now the only callback supported on the ClientSession
392+
# is the logging callback, but long term we'll also want to
393+
# support sampling, elicitation, list roots, etc.
394+
if mcp_callbacks.logging_callback is not None:
395+
params["session_kwargs"]["logging_callback"] = (
396+
mcp_callbacks.logging_callback
397+
)
398+
384399
if transport == "sse":
385400
if "url" not in params:
386401
msg = "'url' parameter is required for SSE connection"

langchain_mcp_adapters/tools.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,21 @@
1616
from mcp import ClientSession
1717
from mcp.server.fastmcp.tools import Tool as FastMCPTool
1818
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
19-
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, TextContent
19+
from mcp.types import (
20+
AudioContent,
21+
CallToolResult,
22+
EmbeddedResource,
23+
ImageContent,
24+
ResourceLink,
25+
TextContent,
26+
)
2027
from mcp.types import Tool as MCPTool
2128
from pydantic import BaseModel, create_model
2229

30+
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
2331
from langchain_mcp_adapters.sessions import Connection, create_session
2432

25-
NonTextContent = ImageContent | EmbeddedResource
33+
NonTextContent = ImageContent | AudioContent | ResourceLink | EmbeddedResource
2634
MAX_ITERATIONS = 1000
2735

2836

@@ -102,6 +110,8 @@ def convert_mcp_tool_to_langchain_tool(
102110
tool: MCPTool,
103111
*,
104112
connection: Connection | None = None,
113+
callbacks: Callbacks | None = None,
114+
server_name: str | None = None,
105115
) -> BaseTool:
106116
"""Convert an MCP tool to a LangChain tool.
107117
@@ -112,6 +122,8 @@ def convert_mcp_tool_to_langchain_tool(
112122
tool: MCP tool to convert
113123
connection: Optional connection config to use to create a new session
114124
if a `session` is not provided
125+
callbacks: Optional callbacks for handling notifications and events
126+
server_name: Name of the server this tool belongs to
115127
116128
Returns:
117129
a LangChain tool
@@ -124,17 +136,37 @@ def convert_mcp_tool_to_langchain_tool(
124136
async def call_tool(
125137
**arguments: dict[str, Any],
126138
) -> tuple[str | list[str], list[NonTextContent] | None]:
139+
mcp_callbacks = (
140+
callbacks.to_mcp_format(
141+
context=CallbackContext(server_name=server_name, tool_name=tool.name)
142+
)
143+
if callbacks is not None
144+
else _MCPCallbacks()
145+
)
146+
147+
# Execute the tool call
127148
call_tool_result = None
128149
if session is None:
129150
# If a session is not provided, we will create one on the fly
130-
async with create_session(connection) as tool_session:
151+
if connection is None:
152+
msg = "Either session or connection must be provided"
153+
raise ValueError(msg)
154+
155+
async with create_session(
156+
connection, mcp_callbacks=mcp_callbacks
157+
) as tool_session:
131158
await tool_session.initialize()
132159
call_tool_result = await cast("ClientSession", tool_session).call_tool(
133160
tool.name,
134161
arguments,
162+
progress_callback=mcp_callbacks.progress_callback,
135163
)
136164
else:
137-
call_tool_result = await session.call_tool(tool.name, arguments)
165+
call_tool_result = await session.call_tool(
166+
tool.name,
167+
arguments,
168+
progress_callback=mcp_callbacks.progress_callback,
169+
)
138170

139171
if call_tool_result is None:
140172
msg = (
@@ -148,6 +180,7 @@ async def call_tool(
148180
return _convert_call_tool_result(call_tool_result)
149181

150182
meta = tool.meta if hasattr(tool, "meta") else None
183+
151184
base = tool.annotations.model_dump() if tool.annotations is not None else {}
152185
meta = {"_meta": meta} if meta is not None else {}
153186
metadata = {**base, **meta} or None
@@ -166,12 +199,16 @@ async def load_mcp_tools(
166199
session: ClientSession | None,
167200
*,
168201
connection: Connection | None = None,
202+
callbacks: Callbacks | None = None,
203+
server_name: str | None = None,
169204
) -> list[BaseTool]:
170205
"""Load all available MCP tools and convert them to LangChain tools.
171206
172207
Args:
173208
session: The MCP client session. If None, connection must be provided.
174209
connection: Connection config to create a new session if session is None.
210+
callbacks: Optional callbacks for handling notifications and events.
211+
server_name: Name of the server these tools belong to.
175212
176213
Returns:
177214
List of LangChain tools. Tool annotations are returned as part
@@ -184,16 +221,33 @@ async def load_mcp_tools(
184221
msg = "Either a session or a connection config must be provided"
185222
raise ValueError(msg)
186223

224+
mcp_callbacks = (
225+
callbacks.to_mcp_format(context=CallbackContext(server_name=server_name))
226+
if callbacks is not None
227+
else _MCPCallbacks()
228+
)
229+
187230
if session is None:
188231
# If a session is not provided, we will create one on the fly
189-
async with create_session(connection) as tool_session:
232+
if connection is None:
233+
msg = "Either session or connection must be provided"
234+
raise ValueError(msg)
235+
async with create_session(
236+
connection, mcp_callbacks=mcp_callbacks
237+
) as tool_session:
190238
await tool_session.initialize()
191239
tools = await _list_all_tools(tool_session)
192240
else:
193241
tools = await _list_all_tools(session)
194242

195243
return [
196-
convert_mcp_tool_to_langchain_tool(session, tool, connection=connection)
244+
convert_mcp_tool_to_langchain_tool(
245+
session,
246+
tool,
247+
connection=connection,
248+
callbacks=callbacks,
249+
server_name=server_name,
250+
)
197251
for tool in tools
198252
]
199253

0 commit comments

Comments
 (0)