diff --git a/docs/servers/middleware.mdx b/docs/servers/middleware.mdx index 66ebe638c..b38a0a649 100644 --- a/docs/servers/middleware.mdx +++ b/docs/servers/middleware.mdx @@ -449,6 +449,51 @@ mcp.add_middleware(DetailedTimingMiddleware()) The built-in versions include custom logger support, proper formatting, and **DetailedTimingMiddleware** provides operation-specific hooks like `on_call_tool` and `on_read_resource` for granular timing. +### Caching Middleware + +Caching middleware is essential for improving performance and reducing server load. FastMCP provides caching middleware at `fastmcp.server.middleware.caching`. + +Here's how to use the full version: + +```python +from fastmcp.server.middleware.caching import ResponseCachingMiddleware + +mcp.add_middleware(ResponseCachingMiddleware()) +``` + +Out of the box, it caches call/list tool, resources, and prompts to an in-memory cache. Sending a notification of a tool/resource/prompt change will invalidate the cache for the affected method. List calls are stored under global keys, if you share a key_value backend across servers, keep this in mind and consider using the PrefixCollectionsWrapper in py-key-value-aio to namespace collections by server. 
+ +Each method can be configured individually, for example, caching list tools for 30 seconds, skipping caching for tools other than `tool1` and not caching requests to read resources: + +```python +from fastmcp.server.middleware.caching import ResponseCachingMiddleware, CallToolSettings, ListToolsSettings, ReadResourceSettings + +mcp.add_middleware(ResponseCachingMiddleware( + list_tools_settings=ListToolsSettings( + ttl=30, + ), + call_tool_settings=CallToolSettings( + included_tools=["tool1"], + ), + read_resource_settings=ReadResourceSettings( + enabled=False + ) +)) +``` + +It can also be configured to cache to disk: + +```python +from fastmcp.server.middleware.caching import ResponseCachingMiddleware +from key_value.aio.stores.disk import DiskStore + +mcp.add_middleware(ResponseCachingMiddleware( + cache_storage=DiskStore(directory="cache"), +)) +``` + +See the Contrib modules for caching middleware implementations that support additional features like distributed caching. + ### Logging Middleware Request and response logging is crucial for debugging, monitoring, and understanding usage patterns in your MCP server. FastMCP provides comprehensive logging middleware at `fastmcp.server.middleware.logging`. 
diff --git a/pyproject.toml b/pyproject.toml index d81a55dfe..e6c93d38f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "pydantic[email]>=2.11.7", "pyperclip>=1.9.0", "openapi-core>=0.19.5", - "py-key-value-aio[disk,memory]>=0.2.2", + "py-key-value-aio[disk,memory]>=0.2.2,<0.3.0", "websockets>=15.0.1", ] diff --git a/src/fastmcp/server/middleware/caching.py b/src/fastmcp/server/middleware/caching.py new file mode 100644 index 000000000..133d6ca95 --- /dev/null +++ b/src/fastmcp/server/middleware/caching.py @@ -0,0 +1,469 @@ +"""A middleware for response caching.""" + +from collections.abc import Sequence +from logging import Logger +from typing import Any, TypedDict + +import mcp.types +import pydantic_core +from key_value.aio.adapters.pydantic import PydanticAdapter +from key_value.aio.protocols.key_value import AsyncKeyValue +from key_value.aio.stores.memory import MemoryStore +from key_value.aio.wrappers.limit_size import LimitSizeWrapper +from key_value.aio.wrappers.statistics import StatisticsWrapper +from key_value.aio.wrappers.statistics.wrapper import ( + KVStoreCollectionStatistics, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents +from pydantic import BaseModel, Field +from typing_extensions import NotRequired, Self, override + +from fastmcp.prompts.prompt import Prompt +from fastmcp.resources.resource import Resource +from fastmcp.server.middleware.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools.tool import Tool, ToolResult +from fastmcp.utilities.logging import get_logger + +logger: Logger = get_logger(name=__name__) + +# Constants +ONE_HOUR_IN_SECONDS = 3600 +FIVE_MINUTES_IN_SECONDS = 300 + +ONE_MB_IN_BYTES = 1024 * 1024 + +GLOBAL_KEY = "__global__" + + +class CachableReadResourceContents(BaseModel): + """A wrapper for ReadResourceContents that can be cached.""" + + content: str | bytes + mime_type: str | None = None + + def get_size(self) -> int: + return 
len(self.model_dump_json()) + + @classmethod + def get_sizes(cls, values: Sequence[Self]) -> int: + return sum([item.get_size() for item in values]) + + @classmethod + def wrap(cls, values: Sequence[ReadResourceContents]) -> list[Self]: + return [cls(content=item.content, mime_type=item.mime_type) for item in values] + + @classmethod + def unwrap(cls, values: Sequence[Self]) -> list[ReadResourceContents]: + return [ + ReadResourceContents(content=item.content, mime_type=item.mime_type) + for item in values + ] + + +class CachableToolResult(BaseModel): + content: list[mcp.types.ContentBlock] + structured_content: dict[str, Any] | None + + @classmethod + def wrap(cls, value: ToolResult) -> Self: + return cls(content=value.content, structured_content=value.structured_content) + + def unwrap(self) -> ToolResult: + return ToolResult( + content=self.content, structured_content=self.structured_content + ) + + +class SharedMethodSettings(TypedDict): + """Shared config for a cache method.""" + + ttl: NotRequired[int] + enabled: NotRequired[bool] + + +class ListToolsSettings(SharedMethodSettings): + """Configuration options for Tool-related caching.""" + + +class ListResourcesSettings(SharedMethodSettings): + """Configuration options for Resource-related caching.""" + + +class ListPromptsSettings(SharedMethodSettings): + """Configuration options for Prompt-related caching.""" + + +class CallToolSettings(SharedMethodSettings): + """Configuration options for Tool-related caching.""" + + included_tools: NotRequired[list[str]] + excluded_tools: NotRequired[list[str]] + + +class ReadResourceSettings(SharedMethodSettings): + """Configuration options for Resource-related caching.""" + + +class GetPromptSettings(SharedMethodSettings): + """Configuration options for Prompt-related caching.""" + + +class ResponseCachingStatistics(BaseModel): + list_tools: KVStoreCollectionStatistics | None = Field(default=None) + list_resources: KVStoreCollectionStatistics | None = Field(default=None) 
+ list_prompts: KVStoreCollectionStatistics | None = Field(default=None) + read_resource: KVStoreCollectionStatistics | None = Field(default=None) + get_prompt: KVStoreCollectionStatistics | None = Field(default=None) + call_tool: KVStoreCollectionStatistics | None = Field(default=None) + + +class ResponseCachingMiddleware(Middleware): + """The response caching middleware offers a simple way to cache responses to mcp methods. The Middleware + supports cache invalidation via notifications from the server. The Middleware implements TTL-based caching + but cache implementations may offer additional features like LRU eviction, size limits, and more. + + When items are retrieved from the cache they will no longer be the original objects, but rather no-op objects + this means that response caching may not be compatible with other middleware that expects original subclasses. + + Notes: + - Caches `tools/call`, `resources/read`, `prompts/get`, `tools/list`, `resources/list`, and `prompts/list` requests. + - Cache keys are derived from method name and arguments. + """ + + def __init__( + self, + cache_storage: AsyncKeyValue | None = None, + list_tools_settings: ListToolsSettings | None = None, + list_resources_settings: ListResourcesSettings | None = None, + list_prompts_settings: ListPromptsSettings | None = None, + read_resource_settings: ReadResourceSettings | None = None, + get_prompt_settings: GetPromptSettings | None = None, + call_tool_settings: CallToolSettings | None = None, + max_item_size: int = ONE_MB_IN_BYTES, + ): + """Initialize the response caching middleware. + + Args: + cache_storage: The cache backend to use. If None, an in-memory cache is used. + list_tools_settings: The settings for the list tools method. If None, the default settings are used (5 minute TTL). + list_resources_settings: The settings for the list resources method. If None, the default settings are used (5 minute TTL). + list_prompts_settings: The settings for the list prompts method. 
If None, the default settings are used (5 minute TTL). + read_resource_settings: The settings for the read resource method. If None, the default settings are used (1 hour TTL). + get_prompt_settings: The settings for the get prompt method. If None, the default settings are used (1 hour TTL). + call_tool_settings: The settings for the call tool method. If None, the default settings are used (1 hour TTL). + max_item_size: The maximum size of items eligible for caching. Defaults to 1MB. + """ + + self._backend: AsyncKeyValue = cache_storage or MemoryStore() + + # When the size limit is exceeded, the put will silently fail + self._size_limiter: LimitSizeWrapper = LimitSizeWrapper( + key_value=self._backend, max_size=max_item_size, raise_on_too_large=False + ) + self._stats: StatisticsWrapper = StatisticsWrapper(key_value=self._size_limiter) + + self._list_tools_settings: ListToolsSettings = ( + list_tools_settings or ListToolsSettings() + ) + self._list_resources_settings: ListResourcesSettings = ( + list_resources_settings or ListResourcesSettings() + ) + self._list_prompts_settings: ListPromptsSettings = ( + list_prompts_settings or ListPromptsSettings() + ) + + self._read_resource_settings: ReadResourceSettings = ( + read_resource_settings or ReadResourceSettings() + ) + self._get_prompt_settings: GetPromptSettings = ( + get_prompt_settings or GetPromptSettings() + ) + self._call_tool_settings: CallToolSettings = ( + call_tool_settings or CallToolSettings() + ) + + self._list_tools_cache: PydanticAdapter[list[Tool]] = PydanticAdapter( + key_value=self._stats, + pydantic_model=list[Tool], + default_collection="tools/list", + ) + + self._list_resources_cache: PydanticAdapter[list[Resource]] = PydanticAdapter( + key_value=self._stats, + pydantic_model=list[Resource], + default_collection="resources/list", + ) + + self._list_prompts_cache: PydanticAdapter[list[Prompt]] = PydanticAdapter( + key_value=self._stats, + pydantic_model=list[Prompt], + 
default_collection="prompts/list", + ) + + self._read_resource_cache: PydanticAdapter[ + list[CachableReadResourceContents] + ] = PydanticAdapter( + key_value=self._stats, + pydantic_model=list[CachableReadResourceContents], + default_collection="resources/read", + ) + + self._get_prompt_cache: PydanticAdapter[mcp.types.GetPromptResult] = ( + PydanticAdapter( + key_value=self._stats, + pydantic_model=mcp.types.GetPromptResult, + default_collection="prompts/get", + ) + ) + + self._call_tool_cache: PydanticAdapter[CachableToolResult] = PydanticAdapter( + key_value=self._stats, + pydantic_model=CachableToolResult, + default_collection="tools/call", + ) + + @override + async def on_list_tools( + self, + context: MiddlewareContext[mcp.types.ListToolsRequest], + call_next: CallNext[mcp.types.ListToolsRequest, Sequence[Tool]], + ) -> Sequence[Tool]: + """List tools from the cache, if caching is enabled, and the result is in the cache. Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + if self._list_tools_settings.get("enabled") is False: + return await call_next(context) + + if cached_value := await self._list_tools_cache.get(key=GLOBAL_KEY): + return cached_value + + tools: Sequence[Tool] = await call_next(context=context) + + # Turn any subclass of Tool into a Tool + cachable_tools: list[Tool] = [ + Tool( + name=tool.name, + title=tool.title, + description=tool.description, + parameters=tool.parameters, + output_schema=tool.output_schema, + annotations=tool.annotations, + meta=tool.meta, + tags=tool.tags, + enabled=tool.enabled, + ) + for tool in tools + ] + + await self._list_tools_cache.put( + key=GLOBAL_KEY, + value=cachable_tools, + ttl=self._list_tools_settings.get("ttl", FIVE_MINUTES_IN_SECONDS), + ) + + return cachable_tools + + @override + async def on_list_resources( + self, + context: MiddlewareContext[mcp.types.ListResourcesRequest], + call_next: CallNext[mcp.types.ListResourcesRequest, 
Sequence[Resource]], + ) -> Sequence[Resource]: + """List resources from the cache, if caching is enabled, and the result is in the cache. Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + if self._list_resources_settings.get("enabled") is False: + return await call_next(context) + + if cached_value := await self._list_resources_cache.get(key=GLOBAL_KEY): + return cached_value + + resources: Sequence[Resource] = await call_next(context=context) + + # Turn any subclass of Resource into a Resource + cachable_resources: list[Resource] = [ + Resource( + name=resource.name, + title=resource.title, + description=resource.description, + tags=resource.tags, + meta=resource.meta, + mime_type=resource.mime_type, + annotations=resource.annotations, + enabled=resource.enabled, + uri=resource.uri, + ) + for resource in resources + ] + + await self._list_resources_cache.put( + key=GLOBAL_KEY, + value=cachable_resources, + ttl=self._list_resources_settings.get("ttl", FIVE_MINUTES_IN_SECONDS), + ) + + return cachable_resources + + @override + async def on_list_prompts( + self, + context: MiddlewareContext[mcp.types.ListPromptsRequest], + call_next: CallNext[mcp.types.ListPromptsRequest, Sequence[Prompt]], + ) -> Sequence[Prompt]: + """List prompts from the cache, if caching is enabled, and the result is in the cache. 
Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + if self._list_prompts_settings.get("enabled") is False: + return await call_next(context) + + if cached_value := await self._list_prompts_cache.get(key=GLOBAL_KEY): + return cached_value + + prompts: Sequence[Prompt] = await call_next(context=context) + + # Turn any subclass of Prompt into a Prompt + cachable_prompts: list[Prompt] = [ + Prompt( + name=prompt.name, + title=prompt.title, + description=prompt.description, + tags=prompt.tags, + meta=prompt.meta, + enabled=prompt.enabled, + arguments=prompt.arguments, + ) + for prompt in prompts + ] + + await self._list_prompts_cache.put( + key=GLOBAL_KEY, + value=cachable_prompts, + ttl=self._list_prompts_settings.get("ttl", FIVE_MINUTES_IN_SECONDS), + ) + + return cachable_prompts + + @override + async def on_call_tool( + self, + context: MiddlewareContext[mcp.types.CallToolRequestParams], + call_next: CallNext[mcp.types.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Call a tool from the cache, if caching is enabled, and the result is in the cache. 
Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + tool_name = context.message.name + + if self._call_tool_settings.get( + "enabled" + ) is False or not self._matches_tool_cache_settings(tool_name=tool_name): + return await call_next(context=context) + + cache_key: str = f"{tool_name}:{_get_arguments_str(context.message.arguments)}" + + if cached_value := await self._call_tool_cache.get(key=cache_key): + return cached_value.unwrap() + + tool_result: ToolResult = await call_next(context=context) + cachable_tool_result: CachableToolResult = CachableToolResult.wrap( + value=tool_result + ) + + await self._call_tool_cache.put( + key=cache_key, + value=cachable_tool_result, + ttl=self._call_tool_settings.get("ttl", ONE_HOUR_IN_SECONDS), + ) + + return cachable_tool_result.unwrap() + + @override + async def on_read_resource( + self, + context: MiddlewareContext[mcp.types.ReadResourceRequestParams], + call_next: CallNext[ + mcp.types.ReadResourceRequestParams, Sequence[ReadResourceContents] + ], + ) -> Sequence[ReadResourceContents]: + """Read a resource from the cache, if caching is enabled, and the result is in the cache. 
Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + if self._read_resource_settings.get("enabled") is False: + return await call_next(context=context) + + cache_key: str = str(context.message.uri) + cached_value: list[CachableReadResourceContents] | None + + if cached_value := await self._read_resource_cache.get(key=cache_key): + return CachableReadResourceContents.unwrap(values=cached_value) + + value: Sequence[ReadResourceContents] = await call_next(context=context) + cached_value = CachableReadResourceContents.wrap(values=value) + + await self._read_resource_cache.put( + key=cache_key, + value=cached_value, + ttl=self._read_resource_settings.get("ttl", ONE_HOUR_IN_SECONDS), + ) + + return CachableReadResourceContents.unwrap(values=cached_value) + + @override + async def on_get_prompt( + self, + context: MiddlewareContext[mcp.types.GetPromptRequestParams], + call_next: CallNext[ + mcp.types.GetPromptRequestParams, mcp.types.GetPromptResult + ], + ) -> mcp.types.GetPromptResult: + """Get a prompt from the cache, if caching is enabled, and the result is in the cache. 
Otherwise, + otherwise call the next middleware and store the result in the cache if caching is enabled.""" + if self._get_prompt_settings.get("enabled") is False: + return await call_next(context=context) + + cache_key: str = f"{context.message.name}:{_get_arguments_str(arguments=context.message.arguments)}" + + if cached_value := await self._get_prompt_cache.get(key=cache_key): + return cached_value + + value: mcp.types.GetPromptResult = await call_next(context=context) + + await self._get_prompt_cache.put( + key=cache_key, + value=value, + ttl=self._get_prompt_settings.get("ttl", ONE_HOUR_IN_SECONDS), + ) + + return value + + def _matches_tool_cache_settings(self, tool_name: str) -> bool: + """Check if the tool matches the cache settings for tool calls.""" + + if included_tools := self._call_tool_settings.get("included_tools"): + if tool_name not in included_tools: + return False + + if excluded_tools := self._call_tool_settings.get("excluded_tools"): + if tool_name in excluded_tools: + return False + + return True + + def statistics(self) -> ResponseCachingStatistics: + """Get the statistics for the cache.""" + return ResponseCachingStatistics( + list_tools=self._stats.statistics.collections.get("tools/list"), + list_resources=self._stats.statistics.collections.get("resources/list"), + list_prompts=self._stats.statistics.collections.get("prompts/list"), + read_resource=self._stats.statistics.collections.get("resources/read"), + get_prompt=self._stats.statistics.collections.get("prompts/get"), + call_tool=self._stats.statistics.collections.get("tools/call"), + ) + + +def _get_arguments_str(arguments: dict[str, Any] | None) -> str: + """Get a string representation of the arguments.""" + + if arguments is None: + return "null" + + try: + return pydantic_core.to_json(value=arguments, fallback=str).decode() + + except TypeError: + return repr(arguments) diff --git a/tests/server/middleware/test_caching.py b/tests/server/middleware/test_caching.py new file mode 
100644 index 000000000..696933530 --- /dev/null +++ b/tests/server/middleware/test_caching.py @@ -0,0 +1,507 @@ +"""Tests for response caching middleware.""" + +import tempfile +from unittest.mock import AsyncMock, MagicMock + +import mcp.types +import pytest +from inline_snapshot import snapshot +from key_value.aio.stores.disk import DiskStore +from key_value.aio.stores.memory import MemoryStore +from key_value.aio.wrappers.statistics.wrapper import ( + GetStatistics, + KVStoreCollectionStatistics, + PutStatistics, +) +from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.types import PromptMessage, TextContent, TextResourceContents +from pydantic import AnyUrl, BaseModel + +from fastmcp import Context, FastMCP +from fastmcp.client.client import CallToolResult, Client +from fastmcp.client.transports import FastMCPTransport +from fastmcp.prompts.prompt import FunctionPrompt, Prompt +from fastmcp.resources.resource import Resource +from fastmcp.server.middleware.caching import ( + CallToolSettings, + ResponseCachingMiddleware, + ResponseCachingStatistics, +) +from fastmcp.server.middleware.middleware import CallNext, MiddlewareContext +from fastmcp.tools.tool import Tool, ToolResult + +TEST_URI = AnyUrl("https://test_uri") + +SAMPLE_READ_RESOURCE_CONTENTS = ReadResourceContents( + content="test_text", + mime_type="text/plain", +) + + +def sample_resource_fn() -> list[ReadResourceContents]: + return [SAMPLE_READ_RESOURCE_CONTENTS] + + +SAMPLE_PROMPT_CONTENTS = TextContent(type="text", text="test_text") + + +def sample_prompt_fn() -> PromptMessage: + return PromptMessage(role="user", content=SAMPLE_PROMPT_CONTENTS) + + +SAMPLE_RESOURCE = Resource.from_function( + fn=sample_resource_fn, uri=TEST_URI, name="test_resource" +) + +SAMPLE_PROMPT = Prompt.from_function(fn=sample_prompt_fn, name="test_prompt") +SAMPLE_GET_PROMPT_RESULT = mcp.types.GetPromptResult( + messages=[ + mcp.types.PromptMessage( + role="user", 
content=mcp.types.TextContent(type="text", text="test_text") + ) + ] +) +SAMPLE_TOOL = Tool(name="test_tool", parameters={"param1": "value1", "param2": 42}) +SAMPLE_TOOL_RESULT = ToolResult( + content=[TextContent(type="text", text="test_text")], + structured_content={"result": "test_result"}, +) +SAMPLE_TOOL_RESULT_LARGE = ToolResult( + content=[TextContent(type="text", text="test_text" * 100)], + structured_content={"result": "test_result"}, +) + + +class CrazyModel(BaseModel): + a: int + b: int + c: str + d: float + e: bool + f: list[int] + g: dict[str, int] + h: list[dict[str, int]] + i: dict[str, list[int]] + + +@pytest.fixture +def crazy_model() -> CrazyModel: + return CrazyModel( + a=5, + b=10, + c="test", + d=1.0, + e=True, + f=[1, 2, 3], + g={"a": 1, "b": 2}, + h=[{"a": 1, "b": 2}], + i={"a": [1, 2]}, + ) + + +class TrackingCalculator: + add_calls: int + multiply_calls: int + crazy_calls: int + very_large_response_calls: int + + def __init__(self): + self.add_calls = 0 + self.multiply_calls = 0 + self.crazy_calls = 0 + self.very_large_response_calls = 0 + + def add(self, a: int, b: int) -> int: + self.add_calls += 1 + return a + b + + def multiply(self, a: int, b: int) -> int: + self.multiply_calls += 1 + return a * b + + def very_large_response(self) -> str: + self.very_large_response_calls += 1 + return "istenchars" * 100000 # 1,000,000 characters, 1mb + + def crazy(self, a: CrazyModel) -> CrazyModel: + self.crazy_calls += 1 + return a + + def how_to_calculate(self, a: int, b: int) -> str: + return f"To calculate {a} + {b}, you need to add {a} and {b} together." 
+ + def get_add_calls(self) -> int: + return self.add_calls + + def get_multiply_calls(self) -> int: + return self.multiply_calls + + def get_crazy_calls(self) -> int: + return self.crazy_calls + + async def update_tool_list(self, context: Context): + await context.send_tool_list_changed() + + def add_tools(self, fastmcp: FastMCP, prefix: str = ""): + _ = fastmcp.add_tool(tool=Tool.from_function(fn=self.add, name=f"{prefix}add")) + _ = fastmcp.add_tool( + tool=Tool.from_function(fn=self.multiply, name=f"{prefix}multiply") + ) + _ = fastmcp.add_tool( + tool=Tool.from_function(fn=self.crazy, name=f"{prefix}crazy") + ) + _ = fastmcp.add_tool( + tool=Tool.from_function( + fn=self.very_large_response, name=f"{prefix}very_large_response" + ) + ) + _ = fastmcp.add_tool( + tool=Tool.from_function( + fn=self.update_tool_list, name=f"{prefix}update_tool_list" + ) + ) + + def add_prompts(self, fastmcp: FastMCP, prefix: str = ""): + _ = fastmcp.add_prompt( + prompt=FunctionPrompt.from_function( + fn=self.how_to_calculate, name=f"{prefix}how_to_calculate" + ) + ) + + def add_resources(self, fastmcp: FastMCP, prefix: str = ""): + _ = fastmcp.add_resource( + resource=Resource.from_function( + fn=self.get_add_calls, + uri="resource://add_calls", + name=f"{prefix}add_calls", + ) + ) + _ = fastmcp.add_resource( + resource=Resource.from_function( + fn=self.get_multiply_calls, + uri="resource://multiply_calls", + name=f"{prefix}multiply_calls", + ) + ) + _ = fastmcp.add_resource( + resource=Resource.from_function( + fn=self.get_crazy_calls, + uri="resource://crazy_calls", + name=f"{prefix}crazy_calls", + ) + ) + + +@pytest.fixture +def tracking_calculator() -> TrackingCalculator: + return TrackingCalculator() + + +@pytest.fixture +def mock_context() -> MiddlewareContext[mcp.types.CallToolRequestParams]: + """Create a mock middleware context for tool calls.""" + context = MagicMock(spec=MiddlewareContext[mcp.types.CallToolRequestParams]) + context.message = 
mcp.types.CallToolRequestParams( + name="test_tool", arguments={"param1": "value1", "param2": 42} + ) + context.method = "tools/call" + return context + + +@pytest.fixture +def mock_call_next() -> CallNext[mcp.types.CallToolRequestParams, ToolResult]: + """Create a mock call_next function.""" + return AsyncMock( + return_value=ToolResult( + content=[TextContent(type="text", text="test result")], + structured_content={"result": "success", "value": 123}, + ) + ) + + +@pytest.fixture +def sample_tool_result() -> ToolResult: + """Create a sample tool result for testing.""" + return ToolResult( + content=[TextContent(type="text", text="cached result")], + structured_content={"cached": True, "data": "test"}, + ) + + +class TestResponseCachingMiddleware: + """Test ResponseCachingMiddleware functionality.""" + + def test_initialization(self): + """Test middleware initialization.""" + assert ResponseCachingMiddleware( + call_tool_settings=CallToolSettings( + included_tools=["tool1"], + excluded_tools=["tool2"], + ), + ) + + @pytest.mark.parametrize( + ("tool_name", "included_tools", "excluded_tools", "result"), + [ + ("tool", ["tool", "tool2"], [], True), + ("tool", ["second tool", "third tool"], [], False), + ("tool", [], ["tool"], False), + ("tool", [], ["second tool"], True), + ("tool", ["tool", "second tool"], ["tool"], False), + ("tool", ["tool", "second tool"], ["second tool"], True), + ], + ids=[ + "tool is included", + "tool is not included", + "tool is excluded", + "tool is not excluded", + "tool is included and excluded (excluded takes precedence)", + "tool is included and not excluded", + ], + ) + def test_tool_call_filtering( + self, + tool_name: str, + included_tools: list[str], + excluded_tools: list[str], + result: bool, + ): + """Test tool filtering logic.""" + + middleware1 = ResponseCachingMiddleware( + call_tool_settings=CallToolSettings( + included_tools=included_tools, excluded_tools=excluded_tools + ), + ) + assert 
middleware1._matches_tool_cache_settings(tool_name=tool_name) is result + + +class TestResponseCachingMiddlewareIntegration: + """Integration tests with real FastMCP server.""" + + @pytest.fixture(params=["memory", "disk"]) + async def caching_server( + self, + tracking_calculator: TrackingCalculator, + request: pytest.FixtureRequest, + ): + """Create a FastMCP server for caching tests.""" + mcp = FastMCP("CachingTestServer") + + with tempfile.TemporaryDirectory() as temp_dir: + disk_store = DiskStore(directory=temp_dir) + response_caching_middleware = ResponseCachingMiddleware( + cache_storage=disk_store if request.param == "disk" else MemoryStore(), + ) + + mcp.add_middleware(middleware=response_caching_middleware) + + tracking_calculator.add_tools(fastmcp=mcp) + tracking_calculator.add_resources(fastmcp=mcp) + tracking_calculator.add_prompts(fastmcp=mcp) + + yield mcp + + await disk_store.close() + + @pytest.fixture + def non_caching_server(self, tracking_calculator: TrackingCalculator): + """Create a FastMCP server for non-caching tests.""" + mcp = FastMCP("NonCachingTestServer") + tracking_calculator.add_tools(fastmcp=mcp) + return mcp + + async def test_list_tools( + self, caching_server: FastMCP, tracking_calculator: TrackingCalculator + ): + """Test that tool list caching works with a real FastMCP server.""" + + async with Client(caching_server) as client: + pre_tool_list: list[mcp.types.Tool] = await client.list_tools() + assert len(pre_tool_list) == 5 + + # Add a tool and make sure it's missing from the list tool response + _ = caching_server.add_tool( + tool=Tool.from_function(fn=tracking_calculator.add, name="add_2") + ) + + post_tool_list: list[mcp.types.Tool] = await client.list_tools() + assert len(post_tool_list) == 5 + + assert pre_tool_list == post_tool_list + + async def test_call_tool( + self, + caching_server: FastMCP, + tracking_calculator: TrackingCalculator, + ): + """Test that caching works with a real FastMCP server.""" + 
tracking_calculator.add_tools(fastmcp=caching_server) + + async with Client[FastMCPTransport](transport=caching_server) as client: + call_tool_result_one: CallToolResult = await client.call_tool( + "add", {"a": 5, "b": 3} + ) + + assert tracking_calculator.add_calls == 1 + call_tool_result_two: CallToolResult = await client.call_tool( + "add", {"a": 5, "b": 3} + ) + assert call_tool_result_one == call_tool_result_two + + async def test_call_tool_very_large_value( + self, + caching_server: FastMCP, + tracking_calculator: TrackingCalculator, + ): + """Test that caching works with a real FastMCP server.""" + tracking_calculator.add_tools(fastmcp=caching_server) + + async with Client[FastMCPTransport](transport=caching_server) as client: + call_tool_result_one: CallToolResult = await client.call_tool( + "very_large_response", {} + ) + + assert tracking_calculator.very_large_response_calls == 1 + call_tool_result_two: CallToolResult = await client.call_tool( + "very_large_response", {} + ) + assert call_tool_result_one == call_tool_result_two + assert tracking_calculator.very_large_response_calls == 2 + + async def test_call_tool_crazy_value( + self, + caching_server: FastMCP, + tracking_calculator: TrackingCalculator, + crazy_model: CrazyModel, + ): + """Test that caching works with a real FastMCP server.""" + tracking_calculator.add_tools(fastmcp=caching_server) + + async with Client[FastMCPTransport](transport=caching_server) as client: + call_tool_result_one: CallToolResult = await client.call_tool( + "crazy", {"a": crazy_model} + ) + + assert tracking_calculator.crazy_calls == 1 + call_tool_result_two: CallToolResult = await client.call_tool( + "crazy", {"a": crazy_model} + ) + assert call_tool_result_one == call_tool_result_two + assert tracking_calculator.crazy_calls == 1 + + async def test_list_resources( + self, caching_server: FastMCP, tracking_calculator: TrackingCalculator + ): + """Test that list resources caching works with a real FastMCP server.""" + 
async with Client[FastMCPTransport](transport=caching_server) as client: + pre_resource_list: list[mcp.types.Resource] = await client.list_resources() + + assert len(pre_resource_list) == 3 + + tracking_calculator.add_resources(fastmcp=caching_server) + + post_resource_list: list[mcp.types.Resource] = await client.list_resources() + assert len(post_resource_list) == 3 + + assert pre_resource_list == post_resource_list + + async def test_read_resource( + self, caching_server: FastMCP, tracking_calculator: TrackingCalculator + ): + """Test that read resource caching works with a real FastMCP server.""" + async with Client[FastMCPTransport](transport=caching_server) as client: + pre_resource = await client.read_resource(uri="resource://add_calls") + assert isinstance(pre_resource[0], TextResourceContents) + assert pre_resource[0].text == "0" + + tracking_calculator.add_calls = 1 + + post_resource = await client.read_resource(uri="resource://add_calls") + assert isinstance(post_resource[0], TextResourceContents) + assert post_resource[0].text == "0" + assert pre_resource == post_resource + + async def test_list_prompts( + self, caching_server: FastMCP, tracking_calculator: TrackingCalculator + ): + """Test that list prompts caching works with a real FastMCP server.""" + async with Client[FastMCPTransport](transport=caching_server) as client: + pre_prompt_list: list[mcp.types.Prompt] = await client.list_prompts() + + assert len(pre_prompt_list) == 1 + + tracking_calculator.add_prompts(fastmcp=caching_server) + + post_prompt_list: list[mcp.types.Prompt] = await client.list_prompts() + + assert len(post_prompt_list) == 1 + + assert pre_prompt_list == post_prompt_list + + async def test_get_prompts( + self, caching_server: FastMCP, tracking_calculator: TrackingCalculator + ): + """Test that get prompts caching works with a real FastMCP server.""" + async with Client[FastMCPTransport](transport=caching_server) as client: + pre_prompt = await client.get_prompt( + 
name="how_to_calculate", arguments={"a": 5, "b": 3} + ) + + pre_prompt_content = pre_prompt.messages[0].content + assert isinstance(pre_prompt_content, TextContent) + assert ( + pre_prompt_content.text + == "To calculate 5 + 3, you need to add 5 and 3 together." + ) + + tracking_calculator.add_prompts(fastmcp=caching_server) + + post_prompt = await client.get_prompt( + name="how_to_calculate", arguments={"a": 5, "b": 3} + ) + + assert pre_prompt == post_prompt + + async def test_statistics( + self, + caching_server: FastMCP, + ): + """Test that statistics are collected correctly.""" + caching_middleware = caching_server.middleware[0] + assert isinstance(caching_middleware, ResponseCachingMiddleware) + + async with Client[FastMCPTransport](transport=caching_server) as client: + statistics = caching_middleware.statistics() + assert statistics == snapshot(ResponseCachingStatistics()) + + _ = await client.call_tool("add", {"a": 5, "b": 3}) + + statistics = caching_middleware.statistics() + assert statistics == snapshot( + ResponseCachingStatistics( + list_tools=KVStoreCollectionStatistics( + get=GetStatistics(count=2, hit=1, miss=1), + put=PutStatistics(count=1), + ), + call_tool=KVStoreCollectionStatistics( + get=GetStatistics(count=1, miss=1), put=PutStatistics(count=1) + ), + ) + ) + + _ = await client.call_tool("add", {"a": 5, "b": 3}) + + statistics = caching_middleware.statistics() + assert statistics == snapshot( + ResponseCachingStatistics( + list_tools=KVStoreCollectionStatistics( + get=GetStatistics(count=2, hit=1, miss=1), + put=PutStatistics(count=1), + ), + call_tool=KVStoreCollectionStatistics( + get=GetStatistics(count=2, hit=1, miss=1), + put=PutStatistics(count=1), + ), + ) + ) diff --git a/uv.lock b/uv.lock index 37903520c..9d140bc21 100644 --- a/uv.lock +++ b/uv.lock @@ -603,7 +603,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'openai'", specifier = ">=1.102.0" }, { name = "openapi-core", specifier = ">=0.19.5" }, { name = 
"openapi-pydantic", specifier = ">=0.5.1" }, - { name = "py-key-value-aio", extras = ["disk", "memory"], specifier = ">=0.2.2" }, + { name = "py-key-value-aio", extras = ["disk", "memory"], specifier = ">=0.2.2,<0.3.0" }, { name = "pydantic", extras = ["email"], specifier = ">=2.11.7" }, { name = "pyperclip", specifier = ">=1.9.0" }, { name = "python-dotenv", specifier = ">=1.1.0" }, @@ -1319,15 +1319,15 @@ wheels = [ [[package]] name = "py-key-value-aio" -version = "0.2.2" +version = "0.2.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "beartype" }, { name = "py-key-value-shared" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/d0/931ea2ca54eba5b1cf53e6fa29e371a58e53ce327cb84ae0317d1269400e/py_key_value_aio-0.2.2.tar.gz", hash = "sha256:e8e4ea8a9c5c5e7b1c79e019e47cd8595d0d4c2bc5be977e357de734f920c96f", size = 20877, upload-time = "2025-10-14T18:10:09.672Z" } +sdist = { url = "https://files.pythonhosted.org/packages/76/47/948cca79fdcdd6177e8852c74cfa3447bcfea1c4a133b3c532933e98eb9e/py_key_value_aio-0.2.5.tar.gz", hash = "sha256:41093d126b98e041d9b10dd38a4c28af8a9aa5ff25857d7a1018d6ee2ce4f66e", size = 29956, upload-time = "2025-10-16T16:56:29.154Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/22/eb/bb0b1cb92defee373635fc723af11e093c54b5ed614d825735c03decfc47/py_key_value_aio-0.2.2-py3-none-any.whl", hash = "sha256:59a2858807adc3bfdf24ac6e65c091ef914a871ea89f1293ccd550d48020d1a7", size = 44077, upload-time = "2025-10-14T18:10:08.874Z" }, + { url = "https://files.pythonhosted.org/packages/76/a1/d74a611a4f8b6db30e7eab6c5d5da4241b3f9bcedbb969570b508e0660bb/py_key_value_aio-0.2.5-py3-none-any.whl", hash = "sha256:ae7a3f85a5955ccdfa73fce967b7afe81a0e87ff9692dee93e5999fdbe5d0a11", size = 63327, upload-time = "2025-10-16T16:56:26.717Z" }, ] [package.optional-dependencies] @@ -1341,15 +1341,15 @@ memory = [ [[package]] name = "py-key-value-shared" -version = "0.2.2" +version = "0.2.5" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "beartype" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/13/35/c837273b0404ea285da8a881e4dbd47d096b866bcfacaf234ca4bd529c4c/py_key_value_shared-0.2.2.tar.gz", hash = "sha256:7e922efb721d6ba0ef23101a1d96a2a30fa2b55c2dade090f26f32f0edb09ff6", size = 7209, upload-time = "2025-10-14T18:10:10.598Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/a2/f8b1f65afd48b8774453187576a0a3ba776555d7eb8f2d7d251d82130635/py_key_value_shared-0.2.5.tar.gz", hash = "sha256:1484e6cb3a2aefa396d72e938b5acf6609d10a594f589b71bb42a7da0cbf9ebb", size = 8044, upload-time = "2025-10-16T16:56:29.941Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/9b/c56cc06403305c3cd8c6deb0eae91f59bbe442f0e1b195ebdb822718ea80/py_key_value_shared-0.2.2-py3-none-any.whl", hash = "sha256:5073cce73450471990e3fa01d2e2c158588a47e6324feaf067a29f0b189a7194", size = 12035, upload-time = "2025-10-14T18:10:09.247Z" }, + { url = "https://files.pythonhosted.org/packages/96/eb/5a9caf4204953520206b6e91ac380c0cc524614534c47944fecae6a2faf2/py_key_value_shared-0.2.5-py3-none-any.whl", hash = "sha256:1e439328cb6ce697660100cf0f395e92dd4f43d3405c6eddae5986de78402045", size = 14139, upload-time = "2025-10-16T16:56:27.325Z" }, ] [[package]]