Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions langchain_mcp_adapters/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from contextlib import AsyncExitStack
from pathlib import Path
from types import TracebackType
from typing import Any, Literal, Optional, TypedDict, cast

Expand All @@ -12,8 +13,10 @@
from langchain_mcp_adapters.prompts import load_mcp_prompt
from langchain_mcp_adapters.tools import load_mcp_tools

EncodingErrorHandler = Literal["strict", "ignore", "replace"]

DEFAULT_ENCODING = "utf-8"
DEFAULT_ENCODING_ERROR_HANDLER = "strict"
DEFAULT_ENCODING_ERROR_HANDLER: EncodingErrorHandler = "strict"

DEFAULT_HTTP_TIMEOUT = 5
DEFAULT_SSE_READ_TIMEOUT = 60 * 5
Expand All @@ -31,17 +34,23 @@ class StdioConnection(TypedDict):
env: dict[str, str] | None
"""The environment to use when spawning the process."""

cwd: str | Path | None
"""The working directory to use when spawning the process."""

encoding: str
"""The text encoding used when sending/receiving messages to the server."""

encoding_error_handler: Literal["strict", "ignore", "replace"]
encoding_error_handler: EncodingErrorHandler
"""
The text encoding error handler.

See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""

session_kwargs: dict[str, Any] | None
"""Additional keyword arguments to pass to the ClientSession"""


class SSEConnection(TypedDict):
transport: Literal["sse"]
Expand All @@ -58,6 +67,9 @@ class SSEConnection(TypedDict):
sse_read_timeout: float
"""SSE read timeout"""

session_kwargs: dict[str, Any] | None
"""Additional keyword arguments to pass to the ClientSession"""


class MultiServerMCPClient:
"""Client for connecting to multiple MCP servers and loading LangChain-compatible tools from them."""
Expand Down Expand Up @@ -146,6 +158,7 @@ async def connect_to_server(
headers=kwargs.get("headers"),
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
sse_read_timeout=kwargs.get("sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT),
session_kwargs=kwargs.get("session_kwargs"),
)
elif transport == "stdio":
if "command" not in kwargs:
Expand All @@ -161,6 +174,7 @@ async def connect_to_server(
encoding_error_handler=kwargs.get(
"encoding_error_handler", DEFAULT_ENCODING_ERROR_HANDLER
),
session_kwargs=kwargs.get("session_kwargs"),
)
else:
raise ValueError(f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'")
Expand All @@ -176,6 +190,7 @@ async def connect_to_server_via_stdio(
encoding_error_handler: Literal[
"strict", "ignore", "replace"
] = DEFAULT_ENCODING_ERROR_HANDLER,
session_kwargs: dict[str, Any] | None = None,
) -> None:
"""Connect to a specific MCP server using stdio

Expand All @@ -186,6 +201,7 @@ async def connect_to_server_via_stdio(
env: Environment variables for the command
encoding: Character encoding
encoding_error_handler: How to handle encoding errors
session_kwargs: Additional keyword arguments to pass to the ClientSession
"""
# NOTE: execution commands (e.g., `uvx` / `npx`) require PATH envvar to be set.
# To address this, we automatically inject existing PATH envvar into the `env` value,
Expand All @@ -205,9 +221,10 @@ async def connect_to_server_via_stdio(
# Create and store the connection
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
read, write = stdio_transport
session_kwargs = session_kwargs or {}
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(ClientSession(read, write)),
await self.exit_stack.enter_async_context(ClientSession(read, write, **session_kwargs)),
)

await self._initialize_session_and_load_tools(server_name, session)
Expand All @@ -220,6 +237,7 @@ async def connect_to_server_via_sse(
headers: dict[str, Any] | None = None,
timeout: float = DEFAULT_HTTP_TIMEOUT,
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
session_kwargs: dict[str, Any] | None = None,
) -> None:
"""Connect to a specific MCP server using SSE

Expand All @@ -229,15 +247,17 @@ async def connect_to_server_via_sse(
headers: HTTP headers to send to the SSE endpoint
timeout: HTTP timeout
sse_read_timeout: SSE read timeout
session_kwargs: Additional keyword arguments to pass to the ClientSession
"""
# Create and store the connection
sse_transport = await self.exit_stack.enter_async_context(
sse_client(url, headers, timeout, sse_read_timeout)
)
read, write = sse_transport
session_kwargs = session_kwargs or {}
session = cast(
ClientSession,
await self.exit_stack.enter_async_context(ClientSession(read, write)),
await self.exit_stack.enter_async_context(ClientSession(read, write, **session_kwargs)),
)

await self._initialize_session_and_load_tools(server_name, session)
Expand All @@ -260,18 +280,8 @@ async def __aenter__(self) -> "MultiServerMCPClient":
try:
connections = self.connections or {}
for server_name, connection in connections.items():
connection_dict = connection.copy()
transport = connection_dict.pop("transport")
if transport == "stdio":
# connection_dict is a StdioConnection (with "transport" popped)
await self.connect_to_server_via_stdio(server_name, **connection_dict) # type: ignore
elif transport == "sse":
# connection_dict is a SSEConnection (with "transport" popped)
await self.connect_to_server_via_sse(server_name, **connection_dict) # type: ignore
else:
raise ValueError(
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
)
await self.connect_to_server(server_name, **connection)

return self
except Exception:
await self.exit_stack.aclose()
Expand Down