Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions langchain_mcp_adapters/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Types for callbacks."""

from dataclasses import dataclass
from typing import Protocol

from mcp.client.session import LoggingFnT
from mcp.shared.session import ProgressFnT
from mcp.types import LoggingMessageNotificationParams


@dataclass
class CallbackContext:
"""LangChain MCP client callback context."""

server_name: str
tool_name: str | None = None


class LoggingMessageCallback(Protocol):
"""Light wrapper around the mcp.client.session.LoggingFnT that injects callback context."""

async def __call__(
self,
params: LoggingMessageNotificationParams,
context: CallbackContext,
) -> None:
"""Execute callback on logging message notification."""
...


class ProgressCallback(Protocol):
"""Light wrapper around the mcp.shared.session.ProgressFnT that injects callback context."""

async def __call__(
self,
progress: float,
total: float | None,
message: str | None,
context: CallbackContext,
) -> None:
"""Execute callback on progress notification."""
...


class _MCPCallbacks:
"""Callbacks compatible with the MCP SDK. For internal use only."""

logging_callback: LoggingFnT | None = None
progress_callback: ProgressFnT | None = None


@dataclass
class Callbacks:
"""Callbacks for the LangChain MCP client."""

on_logging_message: LoggingMessageCallback | None = None
on_progress: ProgressCallback | None = None

def to_mcp_format(self, *, context: CallbackContext) -> _MCPCallbacks:
"""Convert the LangChain MCP client callbacks to MCP SDK callbacks.

Injects the LangChain CallbackContext as the last argument.
"""
if (on_logging_message := self.on_logging_message) is not None:

def mcp_logging_callback(params: LoggingMessageNotificationParams) -> None:
on_logging_message(params, context)
else:
mcp_logging_callback = None

if (on_progress := self.on_progress) is not None:

def mcp_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
on_progress(progress, total, message, context)
else:
mcp_progress_callback = None

return _MCPCallbacks(
logging_callback=mcp_logging_callback,
progress_callback=mcp_progress_callback,
)
43 changes: 37 additions & 6 deletions langchain_mcp_adapters/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from langchain_core.tools import BaseTool
from mcp import ClientSession

from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
from langchain_mcp_adapters.hooks import Hooks
from langchain_mcp_adapters.prompts import load_mcp_prompt
from langchain_mcp_adapters.resources import load_mcp_resources
from langchain_mcp_adapters.sessions import (
Expand Down Expand Up @@ -46,12 +48,20 @@ class MultiServerMCPClient:
Loads LangChain-compatible tools, prompts and resources from MCP servers.
"""

def __init__(self, connections: dict[str, Connection] | None = None) -> None:
def __init__(
self,
connections: dict[str, Connection] | None = None,
*,
callbacks: Callbacks | None = None,
hooks: Hooks | None = None,
) -> None:
"""Initialize a MultiServerMCPClient with MCP servers connections.

Args:
connections: A dictionary mapping server names to connection configurations.
If None, no initial connections are established.
callbacks: Optional callbacks for handling notifications and events.
hooks: Optional hooks for intercepting and modifying tool calls.

Example: basic usage (starting a new session on each tool call)

Expand Down Expand Up @@ -87,11 +97,12 @@ def __init__(self, connections: dict[str, Connection] | None = None) -> None:
async with client.session("math") as session:
tools = await load_mcp_tools(session)
```

"""
self.connections: dict[str, Connection] = (
connections if connections is not None else {}
)
self.callbacks = callbacks or Callbacks()
self.hooks = hooks or Hooks()

@asynccontextmanager
async def session(
Expand Down Expand Up @@ -120,7 +131,13 @@ async def session(
)
raise ValueError(msg)

async with create_session(self.connections[server_name]) as session:
mcp_callbacks = self.callbacks.to_mcp_format(
context=CallbackContext(server_name=server_name)
)

async with create_session(
self.connections[server_name], mcp_callbacks=mcp_callbacks
) as session:
if auto_initialize:
await session.initialize()
yield session
Expand All @@ -145,13 +162,25 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
f"expected one of '{list(self.connections.keys())}'"
)
raise ValueError(msg)
return await load_mcp_tools(None, connection=self.connections[server_name])
return await load_mcp_tools(
None,
connection=self.connections[server_name],
callbacks=self.callbacks,
hooks=self.hooks,
server_name=server_name,
)

all_tools: list[BaseTool] = []
load_mcp_tool_tasks = []
for connection in self.connections.values():
for name, connection in self.connections.items():
load_mcp_tool_task = asyncio.create_task(
load_mcp_tools(None, connection=connection)
load_mcp_tools(
None,
connection=connection,
callbacks=self.callbacks,
hooks=self.hooks,
server_name=name,
)
)
load_mcp_tool_tasks.append(load_mcp_tool_task)
tools_list = await asyncio.gather(*load_mcp_tool_tasks)
Expand Down Expand Up @@ -218,6 +247,8 @@ def __aexit__(


__all__ = [
"Callbacks",
"Hooks",
"McpHttpClientFactory",
"MultiServerMCPClient",
"SSEConnection",
Expand Down
84 changes: 84 additions & 0 deletions langchain_mcp_adapters/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Hook interfaces and types for MCP client lifecycle management.

This module provides hook interfaces for intercepting and extending
MCP client behavior before and after tool calls.

In the future, we might add more hooks for other parts of the
request / result lifecycle, for example to support elicitation.
"""

from dataclasses import dataclass, field
from typing import Any, Protocol, TypedDict

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableConfig
from mcp.types import CallToolRequest, CallToolResult


@dataclass
class ToolHookContext:
"""Context object passed to hooks containing state and server information."""

server_name: str

state: dict[str, Any] = field(default_factory=dict)
runnable_config: RunnableConfig = field(default_factory=dict)
runtime: object = None


class CallToolSpecs(TypedDict, total=False):
headers: dict[str, Any]


class BeforeToolCallResult(TypedDict, total=False):
"""Result returned by before_tool_call hook."""

name: str
args: dict[str, Any]
headers: dict[str, Any]


class BeforeToolCallHook(Protocol):
"""Protocol for before_tool_call hook functions."""

async def __call__(
self,
request: CallToolRequest,
context: ToolHookContext,
) -> CallToolRequest | tuple[CallToolRequest, CallToolSpecs]:
"""Execute before tool call."""
...


UpdatedContent = str | list[str | dict[str, Any]]


class AfterToolCallHook(Protocol):
"""Protocol for after_tool_call hook functions."""

async def __call__(
self,
result: CallToolResult,
context: ToolHookContext,
) -> UpdatedContent | ToolMessage | None:
"""Execute after tool call."""
...


class Hooks:
"""Container for MCP client hook functions."""

def __init__(
self,
*,
before_tool_call: BeforeToolCallHook | None = None,
after_tool_call: AfterToolCallHook | None = None,
) -> None:
"""Initialize hooks.

Args:
before_tool_call: Hook called before tool execution
after_tool_call: Hook called after tool execution
"""
self.before_tool_call = before_tool_call
self.after_tool_call = after_tool_call
19 changes: 15 additions & 4 deletions langchain_mcp_adapters/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import httpx

from langchain_mcp_adapters.callbacks import _MCPCallbacks

EncodingErrorHandler = Literal["strict", "ignore", "replace"]

DEFAULT_ENCODING = "utf-8"
Expand Down Expand Up @@ -187,7 +189,7 @@ class WebsocketConnection(TypedDict):


@asynccontextmanager
async def _create_stdio_session( # noqa: PLR0913
async def _create_stdio_session(
*,
command: str,
args: list[str],
Expand Down Expand Up @@ -233,7 +235,7 @@ async def _create_stdio_session( # noqa: PLR0913


@asynccontextmanager
async def _create_sse_session( # noqa: PLR0913
async def _create_sse_session(
*,
url: str,
headers: dict[str, Any] | None = None,
Expand Down Expand Up @@ -273,7 +275,7 @@ async def _create_sse_session( # noqa: PLR0913


@asynccontextmanager
async def _create_streamable_http_session( # noqa: PLR0913
async def _create_streamable_http_session(
*,
url: str,
headers: dict[str, Any] | None = None,
Expand Down Expand Up @@ -356,11 +358,14 @@ async def _create_websocket_session(


@asynccontextmanager
async def create_session(connection: Connection) -> AsyncIterator[ClientSession]: # noqa: C901
async def create_session(
connection: Connection, *, mcp_callbacks: _MCPCallbacks
) -> AsyncIterator[ClientSession]:
"""Create a new session to an MCP server.

Args:
connection: Connection config to use to connect to the server
mcp_callbacks: mcp sdk compatible callbacks to use for the ClientSession

Raises:
ValueError: If transport is not recognized
Expand All @@ -381,6 +386,12 @@ async def create_session(connection: Connection) -> AsyncIterator[ClientSession]
transport = connection["transport"]
params = {k: v for k, v in connection.items() if k != "transport"}

params["session_kwargs"] = params.get("session_kwargs", {})
# right now the only callback supported on the ClientSession is the logging callback
# long term we'll also want to support sampling, elicitation, list roots, etc.
if mcp_callbacks.logging_callback is not None:
params["session_kwargs"]["logging_callback"] = mcp_callbacks.logging_callback

if transport == "sse":
if "url" not in params:
msg = "'url' parameter is required for SSE connection"
Expand Down
Loading
Loading