Skip to content

Commit c89c9d2

Browse files
committed
feat: in-memory transport support
In-memory transport is officially supported by FastMCP v2. However, FastMCP v1 already supports in-memory transport (but undocumented) See also: jlowin/fastmcp#758
1 parent 4db3ccb commit c89c9d2

File tree

2 files changed

+188
-1
lines changed

2 files changed

+188
-1
lines changed

langchain_mcp_adapters/sessions.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
from datetime import timedelta
1111
from typing import TYPE_CHECKING, Any, Literal, Protocol
1212

13+
import anyio
1314
from mcp import ClientSession, StdioServerParameters
1415
from mcp.client.sse import sse_client
1516
from mcp.client.stdio import stdio_client
1617
from mcp.client.streamable_http import streamablehttp_client
18+
from mcp.server import Server
19+
from mcp.server.fastmcp import FastMCP as FastMCP1Server
20+
from mcp.shared.memory import create_client_server_memory_streams
1721
from typing_extensions import NotRequired, TypedDict
1822

1923
if TYPE_CHECKING:
@@ -183,8 +187,31 @@ class WebsocketConnection(TypedDict):
183187
"""Additional keyword arguments to pass to the ClientSession"""
184188

185189

190+
class InMemoryConnection(TypedDict):
191+
"""Configuration for In-memory transport connections to MCP servers."""
192+
193+
transport: Literal["in_memory"]
194+
195+
server: Server[Any] | FastMCP1Server
196+
"""The Server instance to connect to."""
197+
198+
raise_exceptions: NotRequired[bool]
199+
"""When False, exceptions are returned as messages to the client.
200+
When True, exceptions are raised, which will cause the server to shut down
201+
but also make tracing exceptions much easier during testing and when using
202+
in-process servers.
203+
"""
204+
205+
session_kwargs: NotRequired[dict[str, Any] | None]
206+
"""Additional keyword arguments to pass to the ClientSession"""
207+
208+
186209
Connection = (
187-
StdioConnection | SSEConnection | StreamableHttpConnection | WebsocketConnection
210+
StdioConnection
211+
| SSEConnection
212+
| StreamableHttpConnection
213+
| WebsocketConnection
214+
| InMemoryConnection
188215
)
189216

190217

@@ -234,6 +261,46 @@ async def _create_stdio_session(
234261
yield session
235262

236263

264+
@asynccontextmanager
265+
async def _create_inmemory_session(
266+
*,
267+
server: Server[Any] | FastMCP1Server,
268+
raise_exceptions: bool = False,
269+
session_kwargs: dict[str, Any] | None = None,
270+
) -> AsyncIterator[ClientSession]:
271+
async with create_client_server_memory_streams() as (
272+
client_streams,
273+
server_streams,
274+
):
275+
if isinstance(server, FastMCP1Server):
276+
server = server._mcp_server # type: ignore[reportPrivateUsage]
277+
278+
# https://github.com/jlowin/fastmcp/pull/758
279+
client_read, client_write = client_streams
280+
server_read, server_write = server_streams
281+
282+
# Create a cancel scope for the server task
283+
async with anyio.create_task_group() as tg:
284+
tg.start_soon(
285+
lambda: server.run(
286+
server_read,
287+
server_write,
288+
server.create_initialization_options(),
289+
raise_exceptions=raise_exceptions,
290+
)
291+
)
292+
293+
try:
294+
async with ClientSession(
295+
client_read,
296+
client_write,
297+
**(session_kwargs or {}),
298+
) as client_session:
299+
yield client_session
300+
finally:
301+
tg.cancel_scope.cancel()
302+
303+
237304
@asynccontextmanager
238305
async def _create_sse_session(
239306
*,
@@ -423,6 +490,12 @@ async def create_session(
423490
raise ValueError(msg)
424491
async with _create_websocket_session(**params) as session:
425492
yield session
493+
elif transport == "in_memory":
494+
if "server" not in params:
495+
msg = "'server' parameter is required for In-memory connection"
496+
raise ValueError(msg)
497+
async with _create_inmemory_session(**params) as session:
498+
yield session
426499
else:
427500
msg = (
428501
f"Unsupported transport: {transport}. "

tests/test_inmemory.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import importlib.util
2+
import os
3+
from pathlib import Path
4+
5+
from langchain_core.messages import AIMessage
6+
from langchain_core.tools import BaseTool
7+
8+
from langchain_mcp_adapters.client import MultiServerMCPClient
9+
10+
11+
def _load_module(module_name: str, server_path: str) -> any:
12+
module_spec = importlib.util.spec_from_file_location(module_name, server_path)
13+
assert module_spec is not None
14+
15+
module = importlib.util.module_from_spec(module_spec)
16+
module_spec.loader.exec_module(module)
17+
return module
18+
19+
20+
async def test_multi_server_mcp_client(
21+
socket_enabled,
22+
websocket_server,
23+
websocket_server_port: int,
24+
):
25+
"""Test that MultiServerMCPClient can connect to multiple servers and load tools."""
26+
# Get the absolute path to the server scripts
27+
current_dir = Path(__file__).parent
28+
math_server_path = os.path.join(current_dir, "servers/math_server.py")
29+
weather_server_path = os.path.join(current_dir, "servers/weather_server.py")
30+
# import weather_server
31+
weather_server_module = _load_module("weather_server", weather_server_path)
32+
33+
client = MultiServerMCPClient(
34+
{
35+
"math": {
36+
"command": "python3",
37+
"args": [math_server_path],
38+
"transport": "stdio",
39+
},
40+
"weather": {
41+
"server": weather_server_module.mcp,
42+
"transport": "in_memory",
43+
},
44+
},
45+
)
46+
# Check that we have tools from both servers
47+
all_tools = await client.get_tools()
48+
49+
# Should have 3 tools (add, multiply, get_weather)
50+
assert len(all_tools) == 3
51+
52+
# Check that tools are BaseTool instances
53+
for tool in all_tools:
54+
assert isinstance(tool, BaseTool)
55+
56+
# Verify tool names
57+
tool_names = {tool.name for tool in all_tools}
58+
assert tool_names == {"add", "multiply", "get_weather"}
59+
60+
# Check math server tools
61+
math_tools = await client.get_tools(server_name="math")
62+
assert len(math_tools) == 2
63+
math_tool_names = {tool.name for tool in math_tools}
64+
assert math_tool_names == {"add", "multiply"}
65+
66+
# Check weather server tools
67+
weather_tools = await client.get_tools(server_name="weather")
68+
assert len(weather_tools) == 1
69+
assert weather_tools[0].name == "get_weather"
70+
71+
# Test that we can call a math tool
72+
add_tool = next(tool for tool in all_tools if tool.name == "add")
73+
result = await add_tool.ainvoke({"a": 2, "b": 3})
74+
assert result == "5"
75+
76+
# Test that we can call a weather tool
77+
weather_tool = next(tool for tool in all_tools if tool.name == "get_weather")
78+
result = await weather_tool.ainvoke({"location": "London"})
79+
assert result == "It's always sunny in London"
80+
81+
# Test the multiply tool
82+
multiply_tool = next(tool for tool in all_tools if tool.name == "multiply")
83+
result = await multiply_tool.ainvoke({"a": 4, "b": 5})
84+
assert result == "20"
85+
86+
87+
async def test_get_prompt():
88+
"""Test retrieving prompts from MCP servers."""
89+
# Get the absolute path to the server scripts
90+
current_dir = Path(__file__).parent
91+
math_server_path = os.path.join(current_dir, "servers/math_server.py")
92+
# import weather_server
93+
math_server_module = _load_module("math_server", math_server_path)
94+
95+
client = MultiServerMCPClient(
96+
{
97+
"math": {
98+
"server": math_server_module.mcp,
99+
"transport": "in_memory",
100+
}
101+
},
102+
)
103+
# Test getting a prompt from the math server
104+
messages = await client.get_prompt(
105+
"math",
106+
"configure_assistant",
107+
arguments={"skills": "math, addition, multiplication"},
108+
)
109+
110+
# Check that we got an AIMessage back
111+
assert len(messages) == 1
112+
assert isinstance(messages[0], AIMessage)
113+
assert "You are a helpful assistant" in messages[0].content
114+
assert "math, addition, multiplication" in messages[0].content

0 commit comments

Comments
 (0)