Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def multiply(a: int, b: int) -> int:
return a * b

if __name__ == "__main__":
mcp.run()
mcp.run(transport="stdio")
```

### Client
Expand Down Expand Up @@ -103,7 +103,11 @@ async def get_weather(location: str) -> int:
return "It's always sunny in New York"

if __name__ == "__main__":
mcp.run()
mcp.run(transport="sse")
```

```bash
python weather_server.py
```

### Client
Expand All @@ -115,19 +119,21 @@ from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI
model = ChatOpenAI(model="gpt-4o")

async with MultiServerMCPClient() as client:
await client.connect_to_server(
"math",
command="python",
# Make sure to update to the full absolute path to your math_server.py file
args=["/path/to/math_server.py"],
)
await client.connect_to_server(
"weather",
command="python",
# Make sure to update to the full absolute path to your weather_server.py file
args=["/path/to/weather_server.py"],
)
async with MultiServerMCPClient(
{
"math": {
"command": "python",
# Make sure to update to the full absolute path to your math_server.py file
"args": ["/path/to/math_server.py"],
"transport": "stdio",
},
"weather": {
# make sure you start your weather server on port 8000
"url": "http://localhost:8000/sse",
"transport": "sse",
}
}
) as client:
agent = create_react_agent(model, client.get_tools())
math_response = await agent.ainvoke({"messages": "what's (3 + 5) x 12?"})
weather_response = await agent.ainvoke({"messages": "what is the weather in nyc?"})
Expand Down
192 changes: 180 additions & 12 deletions langchain_mcp_adapters/client.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,169 @@
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Literal, cast
from typing import Literal, TypedDict, cast

from langchain_core.tools import BaseTool
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client

from langchain_mcp_adapters.tools import load_mcp_tools

DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING_ERROR_HANDLER = "strict"


class StdioConnection(TypedDict):
transport: Literal["stdio"]

command: str
"""The executable to run to start the server."""

args: list[str]
"""Command line arguments to pass to the executable."""

env: dict[str, str] | None
"""The environment to use when spawning the process."""

encoding: str
"""The text encoding used when sending/receiving messages to the server."""

encoding_error_handler: Literal["strict", "ignore", "replace"]
"""
The text encoding error handler.

See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""


class SSEConnection(TypedDict):
transport: Literal["sse"]

url: str
"""The URL of the SSE endpoint to connect to."""


class MultiServerMCPClient:
"""Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""

def __init__(self) -> None:
def __init__(self, connections: dict[str, StdioConnection | SSEConnection] = None) -> None:
"""Initialize a MultiServerMCPClient with MCP servers connections.

Args:
connections: A dictionary mapping server names to connection configurations.
Each configuration can be either a StdioConnection or SSEConnection.
If None, no initial connections are established.

Example:

```python
async with MultiServerMCPClient(
{
"math": {
"command": "python",
# Make sure to update to the full absolute path to your math_server.py file
"args": ["/path/to/math_server.py"],
"transport": "stdio",
},
"weather": {
# make sure you start your weather server on port 8000
"url": "http://localhost:8000/sse",
"transport": "sse",
}
}
) as client:
all_tools = client.get_tools()
...
```
"""
self.connections = connections
self.exit_stack = AsyncExitStack()
self.sessions: dict[str, ClientSession] = {}
self.server_name_to_tools: dict[str, list[BaseTool]] = {}

async def _initialize_session_and_load_tools(
self, server_name: str, session: ClientSession
) -> None:
"""Initialize a session and load tools from it.

Args:
server_name: Name to identify this server connection
session: The ClientSession to initialize
"""
# Initialize the session
await session.initialize()
self.sessions[server_name] = session

# Load tools from this server
server_tools = await load_mcp_tools(session)
self.server_name_to_tools[server_name] = server_tools

async def connect_to_server(
self,
server_name: str,
*,
transport: Literal["stdio", "sse"] = "stdio",
**kwargs,
) -> None:
"""Connect to an MCP server using either stdio or SSE.

This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
based on the provided transport parameter.

Args:
server_name: Name to identify this server connection
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
**kwargs: Additional arguments to pass to the specific connection method

Raises:
ValueError: If transport is not recognized
ValueError: If required parameters for the specified transport are missing
"""
if transport == "sse":
if "url" not in kwargs:
raise ValueError("'url' parameter is required for SSE connection")
await self.connect_to_server_via_sse(server_name, url=kwargs["url"])
elif transport == "stdio":
if "command" not in kwargs:
raise ValueError("'command' parameter is required for stdio connection")
if "args" not in kwargs:
raise ValueError("'args' parameter is required for stdio connection")
await self.connect_to_server_via_stdio(
server_name,
command=kwargs["command"],
args=kwargs["args"],
env=kwargs.get("env"),
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
encoding_error_handler=kwargs.get(
"encoding_error_handler", DEFAULT_ENCODING_ERROR_HANDLER
),
)
else:
raise ValueError(f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'")

async def connect_to_server_via_stdio(
self,
server_name: str,
*,
command: str,
args: list[str],
env: dict[str, str] | None = None,
encoding: str = "utf-8",
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict",
encoding: str = DEFAULT_ENCODING,
encoding_error_handler: Literal[
"strict", "ignore", "replace"
] = DEFAULT_ENCODING_ERROR_HANDLER,
) -> None:
"""Connect to a specific MCP server"""
"""Connect to a specific MCP server using stdio

Args:
server_name: Name to identify this server connection
command: Command to execute
args: Arguments for the command
env: Environment variables for the command
encoding: Character encoding
encoding_error_handler: How to handle encoding errors
"""
server_params = StdioServerParameters(
command=command,
args=args,
Expand All @@ -44,13 +180,29 @@ async def connect_to_server(
await self.exit_stack.enter_async_context(ClientSession(read, write)),
)

# Initialize the session
await session.initialize()
self.sessions[server_name] = session
await self._initialize_session_and_load_tools(server_name, session)

# Load tools from this server
server_tools = await load_mcp_tools(session)
self.server_name_to_tools[server_name] = server_tools
async def connect_to_server_via_sse(
self,
server_name: str,
*,
url: str,
) -> None:
"""Connect to a specific MCP server using SSE

Args:
server_name: Name to identify this server connection
url: URL of the SSE server
"""
# Create and store the connection
sse_transport = await self.exit_stack.enter_async_context(sse_client(url))
read, write = sse_transport
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(ClientSession(read, write)),
)

await self._initialize_session_and_load_tools(server_name, session)

def get_tools(self) -> list[BaseTool]:
"""Get a list of all tools from all connected servers."""
Expand All @@ -60,7 +212,23 @@ def get_tools(self) -> list[BaseTool]:
return all_tools

async def __aenter__(self) -> "MultiServerMCPClient":
return self
try:
connections = self.connections or {}
for server_name, connection in connections.items():
connection_dict = connection.copy()
transport = connection_dict.pop("transport")
if transport == "stdio":
await self.connect_to_server_via_stdio(server_name, **connection_dict)
elif transport == "sse":
await self.connect_to_server_via_sse(server_name, **connection_dict)
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
return self
except Exception:
await self.exit_stack.aclose()
raise

async def __aexit__(
self,
Expand Down