Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/fastmcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
progress_handler: ProgressHandler | None = None,
timeout: datetime.timedelta | float | int | None = None,
init_timeout: datetime.timedelta | float | int | None = None,
client_info: mcp.types.Implementation | None = None,
auth: httpx.Auth | Literal["oauth"] | str | None = None,
):
self.transport = cast(ClientTransportT, infer_transport(transport))
Expand Down Expand Up @@ -180,6 +181,7 @@ def __init__(
"logging_callback": create_log_callback(log_handler),
"message_handler": message_handler,
"read_timeout_seconds": timeout,
"client_info": client_info,
}

if roots is not None:
Expand Down
59 changes: 35 additions & 24 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,15 @@
import warnings
from collections.abc import AsyncIterator, Callable
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypedDict,
TypeVar,
cast,
overload,
)
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast, overload

import anyio
import httpx
import mcp.types
from mcp import ClientSession, StdioServerParameters
from mcp.client.session import (
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.client.session import ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.server.fastmcp import FastMCP as FastMCP1Server
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.memory import create_client_server_memory_streams
from pydantic import AnyUrl
from typing_extensions import Unpack

Expand Down Expand Up @@ -65,11 +53,12 @@
class SessionKwargs(TypedDict, total=False):
"""Keyword arguments for the MCP ClientSession constructor."""

read_timeout_seconds: datetime.timedelta | None
sampling_callback: SamplingFnT | None
list_roots_callback: ListRootsFnT | None
logging_callback: LoggingFnT | None
message_handler: MessageHandlerFnT | None
read_timeout_seconds: datetime.timedelta | None
client_info: mcp.types.Implementation | None


class ClientTransport(abc.ABC):
Expand Down Expand Up @@ -662,24 +651,46 @@ class FastMCPTransport(ClientTransport):
tests or scenarios where client and server run in the same runtime.
"""

def __init__(self, mcp: FastMCP | FastMCP1Server):
def __init__(self, mcp: FastMCP | FastMCP1Server, raise_exceptions: bool = False):
"""Initialize a FastMCPTransport from a FastMCP server instance."""

# Accept both FastMCP 2.x and FastMCP 1.0 servers. Both expose a
# ``_mcp_server`` attribute pointing to the underlying MCP server
# implementation, so we can treat them identically.
self.server = mcp
self.raise_exceptions = raise_exceptions

@contextlib.asynccontextmanager
async def connect_session(
self, **session_kwargs: Unpack[SessionKwargs]
) -> AsyncIterator[ClientSession]:
# create_connected_server_and_client_session manages the session lifecycle itself
async with create_connected_server_and_client_session(
server=self.server._mcp_server,
**session_kwargs,
) as session:
yield session
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: self.server._mcp_server.run(
server_read,
server_write,
self.server._mcp_server.create_initialization_options(),
raise_exceptions=self.raise_exceptions,
)
)

try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
**session_kwargs,
) as client_session:
yield client_session
finally:
tg.cancel_scope.cancel()

def __repr__(self) -> str:
return f"<FastMCPTransport(server='{self.server.name}')>"
Expand Down
10 changes: 10 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import sys
from typing import cast
from unittest.mock import AsyncMock

import mcp
import pytest
from mcp import McpError
from mcp.client.auth import OAuthClientProvider
Expand Down Expand Up @@ -275,6 +277,14 @@ async def test_client_connection(fastmcp_server):
assert not client.is_connected()


async def test_initialize_called_once(fastmcp_server, monkeypatch):
mock_initialize = AsyncMock()
monkeypatch.setattr(mcp.ClientSession, "initialize", mock_initialize)
client = Client(transport=FastMCPTransport(fastmcp_server))
async with client:
assert mock_initialize.call_count == 1


async def test_initialize_result_connected(fastmcp_server):
"""Test that initialize_result returns the correct result when connected."""
client = Client(transport=FastMCPTransport(fastmcp_server))
Expand Down
14 changes: 7 additions & 7 deletions tests/server/test_server_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def tool_with_context(x: int, ctx: Context) -> str:
result = await client.call_tool("tool_with_context", {"x": 42})
assert len(result) == 1
content = result[0]
assert content.text == "2" # type: ignore[attr-defined]
assert content.text == "1" # type: ignore[attr-defined]

async def test_async_context(self):
"""Test that context works in async functions."""
Expand All @@ -632,7 +632,7 @@ async def async_tool(x: int, ctx: Context) -> str:
result = await client.call_tool("async_tool", {"x": 42})
assert len(result) == 1
content = result[0]
assert content.text == "Async request 2: 42" # type: ignore[attr-defined]
assert content.text == "Async request 1: 42" # type: ignore[attr-defined]

async def test_optional_context(self):
"""Test that context is optional."""
Expand Down Expand Up @@ -696,7 +696,7 @@ async def __call__(self, x: int, ctx: Context) -> int:

async with Client(mcp) as client:
result = await client.call_tool("MyTool", {"x": 2})
assert result[0].text == "4" # type: ignore[attr-defined]
assert result[0].text == "3" # type: ignore[attr-defined]


class TestResource:
Expand Down Expand Up @@ -780,7 +780,7 @@ def resource_with_context(ctx: Context) -> str:

async with Client(mcp) as client:
result = await client.read_resource(AnyUrl("resource://test"))
assert result[0].text == "2" # type: ignore[attr-defined]
assert result[0].text == "1" # type: ignore[attr-defined]


class TestResourceTemplates:
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def resource_template(param: str, ctx: Context) -> str:

async with Client(mcp) as client:
result = await client.read_resource(AnyUrl("resource://test"))
assert result[0].text.startswith("Resource template: test 2") # type: ignore[attr-defined]
assert result[0].text.startswith("Resource template: test 1") # type: ignore[attr-defined]

async def test_resource_template_context_with_callable_object(self):
mcp = FastMCP()
Expand All @@ -1031,7 +1031,7 @@ def __call__(self, param: str, ctx: Context) -> str:

async with Client(mcp) as client:
result = await client.read_resource(AnyUrl("resource://test"))
assert result[0].text.startswith("Resource template: test 2") # type: ignore[attr-defined]
assert result[0].text.startswith("Resource template: test 1") # type: ignore[attr-defined]


class TestPrompts:
Expand Down Expand Up @@ -1249,4 +1249,4 @@ def __call__(self, name: str, ctx: Context) -> str:
assert len(result.messages) == 1
message = result.messages[0]
assert message.role == "user"
assert message.content.text == "Hello, World! 2" # type: ignore[attr-defined]
assert message.content.text == "Hello, World! 1" # type: ignore[attr-defined]