diff --git a/frontend/src/components/app-config/ai-config.tsx b/frontend/src/components/app-config/ai-config.tsx index e9bc68f3c0b..65079a56c87 100644 --- a/frontend/src/components/app-config/ai-config.tsx +++ b/frontend/src/components/app-config/ai-config.tsx @@ -77,6 +77,7 @@ import { SettingSubtitle } from "./common"; import { AWS_REGIONS } from "./constants"; import { IncorrectModelId } from "./incorrect-model-id"; import { IsOverridden } from "./is-overridden"; +import { MCPConfig } from "./mcp-config"; const formItemClasses = "flex flex-row items-center space-x-1 space-y-0"; @@ -1364,12 +1365,15 @@ export const AiConfig: React.FC = ({ config, onSubmit, }) => { + // MCP is not supported in WASM + const wasm = isWasm(); return ( AI Features AI Providers AI Models + {!wasm && MCP} @@ -1386,6 +1390,11 @@ export const AiConfig: React.FC = ({ + {!wasm && ( + + + + )} ); }; diff --git a/frontend/src/components/app-config/mcp-config.tsx b/frontend/src/components/app-config/mcp-config.tsx new file mode 100644 index 00000000000..253d3bdd595 --- /dev/null +++ b/frontend/src/components/app-config/mcp-config.tsx @@ -0,0 +1,128 @@ +/* Copyright 2024 Marimo. All rights reserved. */ + +import { CheckSquareIcon } from "lucide-react"; +import React from "react"; +import type { UseFormReturn } from "react-hook-form"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { FormField, FormItem } from "@/components/ui/form"; +import type { UserConfig } from "@/core/config/config-schema"; +import { Button } from "../ui/button"; +import { Kbd } from "../ui/kbd"; +import { SettingSubtitle } from "./common"; +import { useOpenSettingsToTab } from "./state"; + +interface MCPConfigProps { + form: UseFormReturn; + onSubmit: (values: UserConfig) => void; +} + +type MCPPreset = "marimo" | "context7"; + +interface PresetConfig { + id: MCPPreset; + title: string; + description: string; +} + +const PRESET_CONFIGS: PresetConfig[] = [ + { + id: "marimo", + title: "marimo (docs)", + description: "Access marimo documentation", + }, + { + id: "context7", + title: "Context7", + description: "Connect to Context7 MCP server", + }, +]; + +export const MCPConfig: React.FC = ({ form, onSubmit }) => { + const { handleClick } = useOpenSettingsToTab(); + + return ( +
+ MCP Servers +

+ Enable Model Context Protocol (MCP) servers to provide additional + capabilities and data sources for AI features. +

+

+ This feature requires the marimo[mcp]{" "} + package. See{" "} + {" "} + for more details. +

+ + { + const presets = field.value || []; + + const togglePreset = (preset: MCPPreset) => { + const newPresets = presets.includes(preset) + ? presets.filter((p: string) => p !== preset) + : [...presets, preset]; + field.onChange(newPresets); + onSubmit(form.getValues()); + }; + + return ( + +
+ {PRESET_CONFIGS.map((config) => { + const isChecked = presets.includes(config.id); + + return ( + togglePreset(config.id)} + > + +
+ + {config.title} + + + {isChecked ? : null} + +
+
+ + {config.description} + +
+ ); + })} +
+
+ ); + }} + /> +
+ ); +}; diff --git a/frontend/src/components/app-config/optional-features.tsx b/frontend/src/components/app-config/optional-features.tsx index 2ccdc3f4c6f..9fd716b2942 100644 --- a/frontend/src/components/app-config/optional-features.tsx +++ b/frontend/src/components/app-config/optional-features.tsx @@ -75,6 +75,12 @@ const OPTIONAL_DEPENDENCIES: OptionalFeature[] = [ additionalPackageInstalls: [], description: "AI features", }, + { + id: "mcp", + packagesRequired: [{ name: "mcp", minVersion: "1" }], + additionalPackageInstalls: [{ name: "pydantic", minVersion: "2" }], + description: "Connect to MCP servers", + }, { id: "ipy-export", packagesRequired: [{ name: "nbformat" }], diff --git a/frontend/src/components/chat/tool-call-accordion.tsx b/frontend/src/components/chat/tool-call-accordion.tsx index e09623d7a85..dfb53a8f046 100644 --- a/frontend/src/components/chat/tool-call-accordion.tsx +++ b/frontend/src/components/chat/tool-call-accordion.tsx @@ -106,7 +106,7 @@ const ResultRenderer: React.FC<{ result: unknown }> = ({ result }) => { // Otherwise, fall back to the current JSON viewer return ( -
+
{typeof result === "string" ? result : JSON.stringify(result, null, 2)}
); diff --git a/frontend/src/core/config/__tests__/config-schema.test.ts b/frontend/src/core/config/__tests__/config-schema.test.ts index 7cfbfcbf27e..fd3e803989f 100644 --- a/frontend/src/core/config/__tests__/config-schema.test.ts +++ b/frontend/src/core/config/__tests__/config-schema.test.ts @@ -75,6 +75,7 @@ test("default UserConfig - empty", () => { "overrides": {}, "preset": "default", }, + "mcp": {}, "package_management": { "manager": "pip", }, @@ -139,6 +140,7 @@ test("default UserConfig - one level", () => { "overrides": {}, "preset": "default", }, + "mcp": {}, "package_management": { "manager": "pip", }, diff --git a/frontend/src/core/config/config-schema.ts b/frontend/src/core/config/config-schema.ts index b630a4a6e6e..0e3122fef47 100644 --- a/frontend/src/core/config/config-schema.ts +++ b/frontend/src/core/config/config-schema.ts @@ -188,6 +188,12 @@ export const UserConfigSchema = z wasm: z.boolean().optional(), }) .optional(), + mcp: z + .looseObject({ + presets: z.array(z.enum(["marimo", "context7"])).optional(), + }) + .optional() + .prefault({}), }) .partial() .prefault(() => ({ @@ -201,6 +207,7 @@ export const UserConfigSchema = z server: {}, ai: {}, package_management: {}, + mcp: {}, })); export type UserConfig = MarimoConfig; export type SaveConfig = UserConfig["save"]; @@ -302,6 +309,7 @@ export function defaultUserConfig(): UserConfig { server: {}, ai: {}, package_management: {}, + mcp: {}, }; return UserConfigSchema.parse(defaultConfig) as UserConfig; } diff --git a/frontend/src/core/config/feature-flag.tsx b/frontend/src/core/config/feature-flag.tsx index adb669be43f..11c6b2131c6 100644 --- a/frontend/src/core/config/feature-flag.tsx +++ b/frontend/src/core/config/feature-flag.tsx @@ -11,7 +11,6 @@ export interface ExperimentalFeatures { wasm_layouts: boolean; // Used in playground (community cloud) rtc_v2: boolean; performant_table_charts: boolean; - mcp_docs: boolean; chat_modes: boolean; sql_linter: boolean; external_agents: boolean; @@ -25,7 +24,6 @@ const defaultValues: ExperimentalFeatures = { wasm_layouts: false, rtc_v2: false, performant_table_charts: false, - mcp_docs: false, chat_modes: false, sql_linter: true, external_agents: import.meta.env.DEV, diff --git a/marimo/_config/config.py b/marimo/_config/config.py index a6996688261..3587f10b985 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -14,6 +14,7 @@ from typing import NotRequired from typing import ( + TYPE_CHECKING, Any, Literal, Optional, @@ -508,7 +509,6 @@ class ExperimentalConfig(TypedDict, total=False): wasm_layouts: bool # Used in playground (community cloud) rtc_v2: bool performant_table_charts: bool - mcp_docs: bool chat_modes: bool sql_linter: bool sql_mode: bool @@ -543,8 +543,7 @@ class MarimoConfig(TypedDict): snippets: NotRequired[SnippetsConfig] datasources: NotRequired[DatasourcesConfig] sharing: NotRequired[SharingConfig] - # We don't support configuring MCP servers yet - # mcp: NotRequired[MCPConfig] + mcp: NotRequired[MCPConfig] @mddoc @@ -570,7 +569,12 @@ class MCPServerStreamableHttpConfig(TypedDict): disabled: NotRequired[Optional[bool]] -MCPServerConfig = Union[MCPServerStdioConfig, MCPServerStreamableHttpConfig] +if TYPE_CHECKING: + MCPServerConfig = Union[ + MCPServerStdioConfig, MCPServerStreamableHttpConfig + ] +else: + MCPServerConfig = dict[str, Any] @mddoc @@ -584,16 +588,7 @@ class MCPConfig(TypedDict): """ mcpServers: dict[str, MCPServerConfig] - - -DEFAULT_MCP_CONFIG: MCPConfig = MCPConfig( - mcpServers={ - "marimo": MCPServerStreamableHttpConfig( - url="https://mcp.marimo.app/mcp" - ), - # TODO(bjoaquinc): add more Marimo MCP servers here after they are implemented - } -) + presets: NotRequired[list[Literal["marimo", "context7"]]] @mddoc @@ -677,6 +672,10 @@ class PartialMarimoConfig(TypedDict, total=False): "custom_paths": [], "include_default_snippets": True, }, + "mcp": { + "mcpServers": {}, + "presets": [], + }, } diff --git a/marimo/_server/ai/mcp/__init__.py b/marimo/_server/ai/mcp/__init__.py new file mode 100644 index 00000000000..8829b786185 --- /dev/null +++ b/marimo/_server/ai/mcp/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2024 Marimo. All rights reserved. +"""MCP (Model Context Protocol) client implementation for marimo.""" + +from marimo._server.ai.mcp.client import ( + MCPClient, + MCPServerConnection, + MCPServerStatus, + get_mcp_client, +) +from marimo._server.ai.mcp.config import ( + MCP_PRESETS, + MCPConfigComparator, + MCPConfigDiff, + MCPServerDefinition, + MCPServerDefinitionFactory, + append_presets, +) +from marimo._server.ai.mcp.transport import ( + MCPTransportConnector, + MCPTransportRegistry, + MCPTransportType, + StdioTransportConnector, + StreamableHTTPTransportConnector, +) +from marimo._server.ai.mcp.types import MCPToolArgs + +__all__ = [ + # Client classes + "MCPClient", + "MCPServerConnection", + "MCPServerStatus", + "get_mcp_client", + # Config classes + "MCP_PRESETS", + "MCPConfigComparator", + "MCPConfigDiff", + "MCPServerDefinition", + "MCPServerDefinitionFactory", + "append_presets", + # Transport classes + "MCPTransportConnector", + "MCPTransportRegistry", + "MCPTransportType", + "StdioTransportConnector", + "StreamableHTTPTransportConnector", + # Types + "MCPToolArgs", +] diff --git a/marimo/_server/ai/mcp.py b/marimo/_server/ai/mcp/client.py similarity index 79% rename from marimo/_server/ai/mcp.py rename to marimo/_server/ai/mcp/client.py index db12f974aaa..3961b2708ec 100644 --- a/marimo/_server/ai/mcp.py +++ b/marimo/_server/ai/mcp/client.py @@ -2,31 +2,32 @@ from __future__ import annotations import asyncio -import os import time -from abc import ABC, abstractmethod from contextlib import AsyncExitStack from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union from marimo import _loggers -from marimo._config.config import ( - DEFAULT_MCP_CONFIG, - MCPConfig, - MCPServerConfig, - MCPServerStdioConfig, - MCPServerStreamableHttpConfig, -) +from marimo._config.config import MCPConfig from marimo._dependencies.dependencies import DependencyManager +from marimo._server.ai.mcp.config import ( + MCPConfigComparator, + MCPServerDefinition, + MCPServerDefinitionFactory, + append_presets, +) +from marimo._server.ai.mcp.transport import ( + MCPTransportRegistry, +) +from marimo._server.ai.mcp.types import MCPToolArgs if TYPE_CHECKING: - from typing import Protocol, TypedDict - from anyio.streams.memory import ( MemoryObjectReceiveStream, MemoryObjectSendStream, ) + from mcp import ClientSession # type: ignore[import-not-found] from mcp.shared.message import SessionMessage from mcp.types import ( # type: ignore[import-not-found] @@ -36,41 +37,9 @@ Tool, ) - class MCPToolMeta(TypedDict): - """Metadata that marimo adds to MCP tools.""" - - server_name: str - namespaced_name: str - - class MCPToolWithMeta(Protocol): - """MCP Tool with marimo-specific metadata.""" - - name: str - description: str | None - inputSchema: dict[str, Any] - meta: MCPToolMeta - - -# Type alias that matches the MCP SDK's CallToolRequestParams.arguments type -MCPToolArgs = Optional[dict[str, Any]] - -# Type alias for MCP transport connection streams -TransportConnectorResponse = tuple[ - "MemoryObjectReceiveStream[Union[SessionMessage, Exception]]", - "MemoryObjectSendStream[SessionMessage]", -] - LOGGER = _loggers.marimo_logger() -class MCPTransportType(str, Enum): - """Supported MCP transport types.""" - - # based on https://modelcontextprotocol.io/docs/concepts/transports - STDIO = "stdio" - STREAMABLE_HTTP = "streamable_http" - - class MCPServerStatus(Enum): """Status of an MCP server connection.""" @@ -80,160 +49,6 @@ class MCPServerStatus(Enum): ERROR = "error" -@dataclass -class MCPServerDefinition: - """Runtime server definition wrapping config with computed fields.""" - - name: str - transport: MCPTransportType - config: MCPServerConfig - timeout: float = 30.0 - - -class MCPServerDefinitionFactory: - """Factory for creating transport-specific server definitions.""" - - @classmethod - def from_config( - cls, name: str, config: MCPServerConfig - ) -> MCPServerDefinition: - """Create server definition with automatic transport detection. - - Args: - name: Server name - config: Server configuration from config file - - Returns: - Server definition with detected transport type - - Raises: - ValueError: If configuration type is not supported - """ - # Import here to avoid circular imports - - if "command" in config: - return MCPServerDefinition( - name=name, - transport=MCPTransportType.STDIO, - config=config, - timeout=30.0, # default timeout for STDIO - ) - elif "url" in config: - return MCPServerDefinition( - name=name, - transport=MCPTransportType.STREAMABLE_HTTP, - config=config, - timeout=config.get("timeout") or 30.0, - ) - else: - raise ValueError(f"Unsupported config type: {type(config)}") - - -class MCPTransportConnector(ABC): - """Abstract base class for MCP transport connectors.""" - - @abstractmethod - async def connect( - self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack - ) -> TransportConnectorResponse: - """Connect to the MCP server and return read/write streams. - - Args: - server_def: Server definition with transport-specific parameters - exit_stack: Async exit stack for resource management - - Returns: - Tuple of (read_stream, write_stream) for the ClientSession - """ - pass - - -class StdioTransportConnector(MCPTransportConnector): - """STDIO transport connector for process-based MCP servers.""" - - async def connect( - self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack - ) -> TransportConnectorResponse: - # Import MCP SDK components for stdio transport - from mcp import StdioServerParameters - from mcp.client.stdio import stdio_client - - # Type narrowing for mypy - assert "command" in server_def.config - config = cast(MCPServerStdioConfig, server_def.config) - - # Set up environment variables for the server process - env = os.environ.copy() - env.update(config.get("env") or {}) - - # Configure server parameters - server_params = StdioServerParameters( - command=config["command"], - args=config.get("args") or [], - env=env, - ) - - # Establish connection with proper resource management - read, write, *_ = await exit_stack.enter_async_context( - stdio_client(server_params) - ) - - return read, write - - -class StreamableHTTPTransportConnector(MCPTransportConnector): - """Streamable HTTP transport connector for modern HTTP-based MCP servers.""" - - async def connect( - self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack - ) -> TransportConnectorResponse: - # Import MCP SDK components for streamable HTTP transport - from mcp.client.streamable_http import streamablehttp_client - - # Type narrowing for mypy - assert "url" in server_def.config - config = cast(MCPServerStreamableHttpConfig, server_def.config) - - # Establish streamable HTTP connection - read, write, *_ = await exit_stack.enter_async_context( - streamablehttp_client( - config["url"], - headers=config.get("headers") or {}, - timeout=server_def.timeout, - ) - ) - - return read, write - - -class MCPTransportRegistry: - """Registry for MCP transport connectors.""" - - def __init__(self) -> None: - self._connectors: dict[MCPTransportType, MCPTransportConnector] = { - MCPTransportType.STDIO: StdioTransportConnector(), - MCPTransportType.STREAMABLE_HTTP: StreamableHTTPTransportConnector(), - } - - def get_connector( - self, transport_type: MCPTransportType - ) -> MCPTransportConnector: - """Get the appropriate transport connector for the given transport type. - - Args: - transport_type: The type of transport to connect with - - Returns: - Transport connector instance - - Raises: - ValueError: If transport type is not supported - """ - if transport_type not in self._connectors: - raise ValueError(f"Unsupported transport type: {transport_type}") - return self._connectors[transport_type] - - @dataclass class MCPServerConnection: """Represents a connection to an MCP server.""" @@ -259,9 +74,12 @@ class MCPServerConnection: class MCPClient: """Client for managing connections to multiple MCP servers.""" - def __init__(self, config: Optional[MCPConfig] = None): - """Initialize MCP client with server configuration.""" - self.config: MCPConfig = config or MCPConfig(mcpServers={}) + def __init__(self) -> None: + """Initialize MCP client. + + Note: For dynamic reconfiguration, use await client.configure(new_config) + which will handle adding/removing/updating connections automatically. + """ self.servers: dict[str, MCPServerDefinition] = {} self.connections: dict[str, MCPServerConnection] = {} self.tool_registry: dict[str, Tool] = {} @@ -274,15 +92,108 @@ def __init__(self, config: Optional[MCPConfig] = None): self.health_check_timeout: float = ( 5.0 # seconds - shorter timeout for health checks ) - self._parse_config() - def _parse_config(self) -> None: + async def configure(self, config: MCPConfig) -> None: + """Configure the MCP client with the given configuration. + + This method: + 1. Parses the new configuration + 2. Compares it with current configuration + 3. Disconnects from removed servers + 4. Disconnects and reconnects to updated servers + 5. Connects to new servers + + Args: + config: MCP configuration to apply + """ + # Parse new configuration + new_servers = self._parse_config(config) + + # Compute differences + diff = MCPConfigComparator.compute_diff(self.servers, new_servers) + + # Early return if no changes + if not diff.has_changes(): + LOGGER.debug( + "MCP configuration unchanged, skipping reconfiguration" + ) + return + + LOGGER.info( + f"MCP configuration changed: " + f"{len(diff.servers_to_add)} to add, " + f"{len(diff.servers_to_remove)} to remove, " + f"{len(diff.servers_to_update)} to update, " + f"{len(diff.servers_unchanged)} unchanged" + ) + + # Disconnect from removed servers + for server_name in diff.servers_to_remove: + LOGGER.info(f"Removing server: {server_name}") + await self.disconnect_from_server(server_name) + # Clean up from servers and connections registries + if server_name in self.servers: + del self.servers[server_name] + if server_name in self.connections: + del self.connections[server_name] + + # Disconnect from servers that need to be updated (will reconnect below) + for server_name in diff.servers_to_update.keys(): + LOGGER.info(f"Updating server: {server_name}") + await self.disconnect_from_server(server_name) + # Clean up old connection, will be recreated below + if server_name in self.connections: + del self.connections[server_name] + + # Update servers registry with new configuration + # Add new servers + self.servers.update(diff.servers_to_add) + # Update modified servers + self.servers.update(diff.servers_to_update) + + # Connect to new and updated servers + servers_to_connect = {**diff.servers_to_add, **diff.servers_to_update} + + if servers_to_connect: + # Connect to servers concurrently + tasks = [ + self.connect_to_server(server_name) + for server_name in servers_to_connect.keys() + ] + connection_results = await asyncio.gather( + *tasks, return_exceptions=True + ) + + for server_name, result in zip( + servers_to_connect.keys(), connection_results + ): + if isinstance(result, Exception): + LOGGER.error( + f"Failed to connect to {server_name}: {result}" + ) + elif not result: + LOGGER.warning( + f"Connection to {server_name} did not succeed" + ) + + def _parse_config( + self, config: MCPConfig + ) -> dict[str, MCPServerDefinition]: """Parse MCP server configuration. - Note: Servers with invalid configurations are logged but excluded from self.servers, + Note: Servers with invalid configurations are logged but excluded from returned dict, making them unavailable for connection attempts. + + Args: + config: MCP configuration to parse + + Returns: + Dictionary of server name to server definition for valid servers """ - mcp_servers = self.config.get("mcpServers", {}) + # Apply presets before parsing + config = append_presets(config) + mcp_servers = config.get("mcpServers", {}) + parsed_servers: dict[str, MCPServerDefinition] = {} for server_name, server_config in mcp_servers.items(): try: @@ -290,20 +201,22 @@ def _parse_config(self) -> None: server_name, server_config ) - self.servers[server_name] = server_def + parsed_servers[server_name] = server_def LOGGER.debug( - f"Registered MCP server: {server_name} (transport: {server_def.transport})" + f"Parsed MCP server: {server_name} (transport: {server_def.transport})" ) except KeyError as e: LOGGER.error( f"Invalid configuration for server {server_name}: missing {e}" ) - # Note: Server with invalid configuration is not added to self.servers + # Note: Server with invalid configuration is not added to parsed_servers except ValueError as e: LOGGER.error( f"Invalid configuration for server {server_name}: {e}" ) - # Note: Server with invalid configuration is not added to self.servers + # Note: Server with invalid configuration is not added to parsed_servers + + return parsed_servers async def _connection_lifecycle(self, server_name: str) -> None: """Minimal wrapper to run existing connection and disconnection logic in task-owned AsyncExitStack.""" @@ -498,8 +411,12 @@ def _create_namespaced_tool_name( return f"mcp_{server_name}{counter}_{tool_name}" async def connect_to_all_servers(self) -> dict[str, bool]: - """Connect to all configured MCP servers.""" - results = {} + """Connect to all configured MCP servers. + + Returns: + Dictionary mapping server names to connection success status + """ + results: dict[str, bool] = {} # Connect to servers concurrently tasks = [ @@ -545,7 +462,9 @@ def is_error_result(self, result: CallToolResult) -> bool: return hasattr(result, "isError") and result.isError is True async def invoke_tool( - self, namespaced_tool_name: str, params: CallToolRequestParams + self, + namespaced_tool_name: str, + params: CallToolRequestParams, ) -> CallToolResult: """Invoke an MCP tool using properly typed CallToolRequestParams.""" tool = self.tool_registry.get(namespaced_tool_name) @@ -619,7 +538,9 @@ async def invoke_tool( ) def create_tool_params( - self, namespaced_tool_name: str, arguments: MCPToolArgs = None + self, + namespaced_tool_name: str, + arguments: MCPToolArgs = None, ) -> CallToolRequestParams: """Create properly typed CallToolRequestParams for a tool.""" from mcp.types import CallToolRequestParams @@ -890,13 +811,21 @@ async def _cancel_health_monitoring( # Wait for all tasks to complete await asyncio.gather( - *self.health_check_tasks.values(), return_exceptions=True + *self.health_check_tasks.values(), + return_exceptions=True, ) self.health_check_tasks.clear() LOGGER.debug("Cancelled all health monitoring tasks") async def disconnect_from_server(self, server_name: str) -> bool: - """Disconnect from a specific MCP server.""" + """Disconnect from a specific MCP server. + + Args: + server_name: Name of the server to disconnect from + + Returns: + True if disconnection was successful or server wasn't connected, False otherwise + """ connection = self.connections.get(server_name) if not connection: return True @@ -942,19 +871,22 @@ async def disconnect_from_all_servers(self) -> None: _MCP_CLIENT: Optional[MCPClient] = None -def get_mcp_client(config: Optional[MCPConfig] = None) -> MCPClient: - """Get the global MCP client instance, initializing it if needed.""" +def get_mcp_client() -> MCPClient: + """Get the global MCP client instance, initializing it if needed. + + Note: The client must be configured using await client.configure(config) + before connecting to servers. + """ global _MCP_CLIENT if _MCP_CLIENT is None: if not DependencyManager.mcp.has(): - LOGGER.info("MCP SDK not available") - raise ModuleNotFoundError("MCP SDK not available") + msg = "MCP dependencies not available. Install with `pip install marimo[mcp]` or `uv add marimo[mcp]`" + LOGGER.info(msg) + raise ModuleNotFoundError( + msg, + name="mcp", + ) - _MCP_CLIENT = MCPClient(config or _get_default_config()) + _MCP_CLIENT = MCPClient() LOGGER.info("MCP client initialized") return _MCP_CLIENT - - -def _get_default_config() -> MCPConfig: - """Get default MCP configuration.""" - return DEFAULT_MCP_CONFIG diff --git a/marimo/_server/ai/mcp/config.py b/marimo/_server/ai/mcp/config.py new file mode 100644 index 00000000000..0073adea7ef --- /dev/null +++ b/marimo/_server/ai/mcp/config.py @@ -0,0 +1,197 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from marimo import _loggers +from marimo._server.ai.mcp.transport import MCPTransportType + +if TYPE_CHECKING: + from marimo._config.config import ( + MCPConfig, + MCPServerConfig, + ) + +LOGGER = _loggers.marimo_logger() + + +# MCP Server Presets +MCP_PRESETS: dict[str, MCPServerConfig] = { + "marimo": { + "url": "https://mcp.marimo.app/mcp", + }, + "context7": { + "url": "https://mcp.context7.com/mcp", + }, +} + + +def append_presets(config: MCPConfig) -> MCPConfig: + """Append preset MCP servers to the configuration. + + Presets are added to the mcpServers dict if they are specified in the + 'presets' list and not already present in mcpServers. + + Args: + config: MCP configuration potentially containing a presets list + + Returns: + Updated configuration with presets appended to mcpServers + """ + from marimo._config.config import MCPConfig + + presets_to_add = config.get("presets", []) + if not presets_to_add: + return config + + # Create a copy to avoid mutating the original + updated_config = MCPConfig(mcpServers=config["mcpServers"].copy()) + + # Add preset servers if not already present + for preset_name in presets_to_add: + if preset_name not in updated_config["mcpServers"]: + if preset_name in MCP_PRESETS: + updated_config["mcpServers"][preset_name] = MCP_PRESETS[ + preset_name + ] + + return updated_config + + +def is_mcp_config_empty(config: MCPConfig | None) -> bool: + """Check if the MCP configuration is empty.""" + if config is None: + return True + return not config.get("mcpServers") and not config.get("presets") + + +@dataclass +class MCPServerDefinition: + """Runtime server definition wrapping config with computed fields.""" + + name: str + transport: MCPTransportType + config: MCPServerConfig + timeout: float = 30.0 + + def __eq__(self, other: object) -> bool: + """Check if two server definitions are equal based on config content.""" + if not isinstance(other, MCPServerDefinition): + return NotImplemented + return ( + self.name == other.name + and self.transport == other.transport + and self.config == other.config + and self.timeout == other.timeout + ) + + def __hash__(self) -> int: + """Hash based on name for use in sets/dicts.""" + return hash(self.name) + + +class MCPServerDefinitionFactory: + """Factory for creating transport-specific server definitions.""" + + @classmethod + def from_config( + cls, name: str, config: MCPServerConfig + ) -> MCPServerDefinition: + """Create server definition with automatic transport detection. + + Args: + name: Server name + config: Server configuration from config file + + Returns: + Server definition with detected transport type + + Raises: + ValueError: If configuration type is not supported + """ + if "command" in config: + return MCPServerDefinition( + name=name, + transport=MCPTransportType.STDIO, + config=config, + timeout=30.0, # default timeout for STDIO + ) + elif "url" in config: + return MCPServerDefinition( + name=name, + transport=MCPTransportType.STREAMABLE_HTTP, + config=config, + timeout=config.get("timeout") or 30.0, + ) + else: + raise ValueError(f"Unsupported config type: {type(config)}") + + +@dataclass +class MCPConfigDiff: + """Represents changes between two MCP configurations.""" + + servers_to_add: dict[str, MCPServerDefinition] + servers_to_remove: set[str] + servers_to_update: dict[str, MCPServerDefinition] + servers_unchanged: set[str] + + def has_changes(self) -> bool: + """Check if there are any changes in the configuration.""" + return bool( + self.servers_to_add + or self.servers_to_remove + or self.servers_to_update + ) + + +class MCPConfigComparator: + """Utility for comparing MCP configurations and computing differences.""" + + @staticmethod + def compute_diff( + current_servers: dict[str, MCPServerDefinition], + new_servers: dict[str, MCPServerDefinition], + ) -> MCPConfigDiff: + """Compare current and new server configurations. + + Args: + current_servers: Currently configured servers + new_servers: New server configuration + + Returns: + MCPConfigDiff describing the changes needed + """ + current_names = set(current_servers.keys()) + new_names = set(new_servers.keys()) + + # Servers that need to be removed + servers_to_remove = current_names - new_names + + # Servers that need to be added + servers_to_add = { + name: new_servers[name] for name in (new_names - current_names) + } + + # Servers that might need to be updated + common_servers = current_names & new_names + servers_to_update: dict[str, MCPServerDefinition] = {} + servers_unchanged: set[str] = set() + + for name in common_servers: + current_def = current_servers[name] + new_def = new_servers[name] + + # Check if server definition has changed + if current_def != new_def: + servers_to_update[name] = new_def + else: + servers_unchanged.add(name) + + return MCPConfigDiff( + servers_to_add=servers_to_add, + servers_to_remove=servers_to_remove, + servers_to_update=servers_to_update, + servers_unchanged=servers_unchanged, + ) diff --git a/marimo/_server/ai/mcp/transport.py b/marimo/_server/ai/mcp/transport.py new file mode 100644 index 00000000000..724247ffd5d --- /dev/null +++ b/marimo/_server/ai/mcp/transport.py @@ -0,0 +1,130 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from enum import Enum +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from contextlib import AsyncExitStack + + from marimo._config.config import ( + MCPServerStdioConfig, + MCPServerStreamableHttpConfig, + ) + from marimo._server.ai.mcp.config import MCPServerDefinition + from marimo._server.ai.mcp.types import TransportConnectorResponse + + +class MCPTransportType(str, Enum): + """Supported MCP transport types.""" + + # based on https://modelcontextprotocol.io/docs/concepts/transports + STDIO = "stdio" + STREAMABLE_HTTP = "streamable_http" + + +class MCPTransportConnector(ABC): + """Abstract base class for MCP transport connectors.""" + + @abstractmethod + async def connect( + self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack + ) -> TransportConnectorResponse: + """Connect to the MCP server and return read/write streams. + + Args: + server_def: Server definition with transport-specific parameters + exit_stack: Async exit stack for resource management + + Returns: + Tuple of (read_stream, write_stream) for the ClientSession + """ + pass + + +class StdioTransportConnector(MCPTransportConnector): + """STDIO transport connector for process-based MCP servers.""" + + async def connect( + self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack + ) -> TransportConnectorResponse: + # Import MCP SDK components for stdio transport + from mcp import StdioServerParameters + from mcp.client.stdio import stdio_client + + # Type narrowing for mypy + assert "command" in server_def.config + config = cast("MCPServerStdioConfig", server_def.config) + + # Set up environment variables for the server process + env = os.environ.copy() + env.update(config.get("env") or {}) + + # Configure server parameters + server_params = StdioServerParameters( + command=config["command"], + args=config.get("args") or [], + env=env, + ) + + # Establish connection with proper resource management + read, write, *_ = await exit_stack.enter_async_context( + stdio_client(server_params) + ) + + return read, write + + +class StreamableHTTPTransportConnector(MCPTransportConnector): + """Streamable HTTP transport connector for modern HTTP-based MCP servers.""" + + async def connect( + self, server_def: MCPServerDefinition, exit_stack: AsyncExitStack + ) -> TransportConnectorResponse: + # Import MCP SDK components for streamable HTTP transport + from mcp.client.streamable_http import streamablehttp_client + + # Type narrowing for mypy + assert "url" in server_def.config + config = cast("MCPServerStreamableHttpConfig", server_def.config) + + # Establish streamable HTTP connection + read, write, *_ = await exit_stack.enter_async_context( + streamablehttp_client( + config["url"], + headers=config.get("headers", {}), + timeout=server_def.timeout, + ) + ) + + return read, write + + +class MCPTransportRegistry: + """Registry for MCP transport connectors.""" + + def __init__(self) -> None: + self._connectors: dict[MCPTransportType, MCPTransportConnector] = { + MCPTransportType.STDIO: StdioTransportConnector(), + MCPTransportType.STREAMABLE_HTTP: StreamableHTTPTransportConnector(), + } + + def get_connector( + self, transport_type: MCPTransportType + ) -> MCPTransportConnector: + """Get the appropriate transport connector for the given transport type. + + Args: + transport_type: The type of transport to connect with + + Returns: + Transport connector instance + + Raises: + ValueError: If transport type is not supported + """ + if transport_type not in self._connectors: + raise ValueError(f"Unsupported transport type: {transport_type}") + return self._connectors[transport_type] diff --git a/marimo/_server/ai/mcp/types.py b/marimo/_server/ai/mcp/types.py new file mode 100644 index 00000000000..502b7310d14 --- /dev/null +++ b/marimo/_server/ai/mcp/types.py @@ -0,0 +1,24 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from typing import Union + + from anyio.streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, + ) + + from mcp.shared.message import SessionMessage + + +# Type alias that matches the MCP SDK's CallToolRequestParams.arguments type +MCPToolArgs = Optional[dict[str, Any]] + +# Type alias for MCP transport connection streams +TransportConnectorResponse = tuple[ + "MemoryObjectReceiveStream[Union[SessionMessage, Exception]]", + "MemoryObjectSendStream[SessionMessage]", +] diff --git a/marimo/_server/ai/tools/tool_manager.py b/marimo/_server/ai/tools/tool_manager.py index 2843d08d799..cc616057e96 100644 --- a/marimo/_server/ai/tools/tool_manager.py +++ b/marimo/_server/ai/tools/tool_manager.py @@ -1,7 +1,6 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, Optional from starlette.applications import ( @@ -12,7 +11,7 @@ from marimo._ai._tools.base import ToolBase, ToolContext from marimo._ai._tools.tools_registry import SUPPORTED_BACKEND_AND_MCP_TOOLS from marimo._config.config import CopilotMode -from marimo._server.ai.mcp import get_mcp_client +from marimo._server.ai.mcp.client import get_mcp_client from marimo._server.ai.tools.types import ( FunctionArgs, ToolCallResult, @@ -20,7 +19,6 @@ ToolSource, ValidationFunction, ) -from marimo._server.api.deps import AppState from marimo._utils.once import once if TYPE_CHECKING: @@ -42,13 +40,6 @@ def __init__(self, app: Starlette) -> None: self._validation_functions: dict[str, ValidationFunction] = {} self.app: Starlette = app - @cached_property - def _enable_mcp_tools(self) -> bool: - # This may be stale but it is ok, since we want to enable MCP on startup - app_state = AppState.from_app(self.app) - config = app_state.config_manager.get_config() - return bool(config.get("experimental", {}).get("mcp_docs", False)) - @once def _init_backend_tools(self) -> None: """Initialize backend tools. We lazily register tools here instead of in the constructor for performance""" @@ -161,9 +152,6 @@ def _validate_backend_tool_arguments( def _list_mcp_tools(self) -> list[ToolDefinition]: """Get all MCP tools from the MCP client.""" - if not self._enable_mcp_tools: - return [] - try: mcp_client = get_mcp_client() mcp_tools = mcp_client.get_all_tools() diff --git a/marimo/_server/api/endpoints/config.py b/marimo/_server/api/endpoints/config.py index 4ef88ef8b7f..ffdc1c89229 100644 --- a/marimo/_server/api/endpoints/config.py +++ b/marimo/_server/api/endpoints/config.py @@ -1,7 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast from starlette.authentication import requires from starlette.background import BackgroundTask @@ -9,8 +9,12 @@ from marimo import _loggers from marimo._config.config import PartialMarimoConfig +from marimo._dependencies.dependencies import DependencyManager from marimo._messaging.msgspec_encoder import asdict +from marimo._messaging.ops import MissingPackageAlert +from marimo._runtime.packages.utils import is_python_isolated from marimo._runtime.requests import SetUserConfigRequest +from marimo._server.ai.mcp.config import is_mcp_config_empty from marimo._server.api.deps import AppState from marimo._server.api.utils import parse_request from marimo._server.models.models import ( @@ -18,6 +22,7 @@ SuccessResponse, ) from marimo._server.router import APIRouter +from marimo._server.sessions import send_message_to_consumer from marimo._types.ids import ConsumerId if TYPE_CHECKING: @@ -50,6 +55,8 @@ async def save_user_config( $ref: "#/components/schemas/SuccessResponse" """ # noqa: E501 app_state = AppState(request) + session_id = app_state.get_current_session_id() + session = app_state.get_current_session() # Allow unknown keys to handle backward/forward compatibility body = await parse_request( request, cls=SaveUserConfigurationRequest, allow_unknown_keys=True @@ -60,13 +67,45 @@ async def save_user_config( cast(PartialMarimoConfig, body.config) ) - background_task: Optional[BackgroundTask] = None - # Update the server's view of the config - if config["completion"]["copilot"]: - LOGGER.debug("Starting copilot server") - background_task = BackgroundTask( - app_state.session_manager.start_lsp_server - ) + async def handle_background_tasks() -> None: + # Update the server's view of the config + if config["completion"]["copilot"]: + LOGGER.debug("Starting copilot server") + await app_state.session_manager.start_lsp_server() + + # Reconfigure MCP servers if config changed + mcp_config = config.get("mcp") + + # Handle missing MCP dependencies + if ( + not is_mcp_config_empty(mcp_config) + and not DependencyManager.mcp.has() + ): + # If we're in an edit session, send an package installation request + if session_id is not None and session is not None: + send_message_to_consumer( + session=session, + operation=MissingPackageAlert( + packages=["mcp"], + isolated=is_python_isolated(), + ), + consumer_id=ConsumerId(session_id), + ) + + try: + from marimo._server.ai.mcp import get_mcp_client + + if mcp_config: + LOGGER.debug("Reconfiguring MCP servers with updated config") + mcp_client = get_mcp_client() + await mcp_client.configure(mcp_config) + LOGGER.info( + f"MCP servers reconfigured: {list(mcp_client.servers.keys())}" + ) + except Exception as e: + LOGGER.warning(f"Failed to reconfigure MCP servers: {e}") + + background_task = BackgroundTask(handle_background_tasks) # Update the kernel's view of the config # Session could be None if the user is on the home page diff --git a/marimo/_server/api/lifespans.py b/marimo/_server/api/lifespans.py index 58e1f758e1e..c8a65c93875 100644 --- a/marimo/_server/api/lifespans.py +++ b/marimo/_server/api/lifespans.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional from marimo import _loggers +from marimo._server.ai.mcp.config import is_mcp_config_empty from marimo._server.ai.tools.tool_manager import setup_tool_manager from marimo._server.api.deps import AppState, AppStateBase from marimo._server.api.interrupt import InterruptHandler @@ -16,7 +17,8 @@ from marimo._server.model import SessionMode from marimo._server.print import ( print_experimental_features, - print_mcp, + print_mcp_client, + print_mcp_server, print_shutdown, print_startup, ) @@ -87,29 +89,32 @@ async def mcp(app: Starlette) -> AsyncIterator[None]: state = AppState.from_app(app) session_mgr = state.session_manager user_config = state.config_manager.get_config() - mcp_docs_enabled = user_config.get("experimental", {}).get( - "mcp_docs", False - ) + mcp_config = user_config.get("mcp") # Only start MCP servers in Edit mode - if session_mgr.mode != SessionMode.EDIT or not mcp_docs_enabled: + if session_mgr.mode != SessionMode.EDIT: yield return - LOGGER.warning("MCP servers are experimental and may not work as expected") + # Only start MCP servers if the config is not empty + if not mcp_config or is_mcp_config_empty(mcp_config): + yield + return async def background_connect_mcp_servers() -> Optional[MCPClient]: try: from marimo._server.ai.mcp import get_mcp_client mcp_client = get_mcp_client() - await mcp_client.connect_to_all_servers() + print_mcp_client(mcp_config) + await mcp_client.configure(mcp_config) + LOGGER.info( f"MCP servers connected: {list(mcp_client.servers.keys())}" ) return mcp_client except Exception as e: - LOGGER.warning(f"Failed to connect MCP servers in background: {e}") + LOGGER.warning(f"Failed to connect MCP servers: {e}") return None task = asyncio.create_task(background_connect_mcp_servers()) @@ -173,7 +178,7 @@ async def logging(app: Starlette) -> AsyncIterator[None]: server_token = None if skew_protection_enabled: server_token = str(state.session_manager.skew_protection_token) - print_mcp(mcp_url, server_token) + print_mcp_server(mcp_url, server_token) yield diff --git a/marimo/_server/print.py b/marimo/_server/print.py index 944a304506b..c5b818b8ec3 100644 --- a/marimo/_server/print.py +++ b/marimo/_server/print.py @@ -6,7 +6,7 @@ from typing import Optional from marimo._cli.print import bold, green, muted -from marimo._config.config import MarimoConfig +from marimo._config.config import MarimoConfig, MCPConfig from marimo._server.utils import print_, print_tabbed UTF8_SUPPORTED = False @@ -146,6 +146,7 @@ def print_experimental_features(config: MarimoConfig) -> None: "reactive_tests", "toplevel_defs", "setup_cell", + "mcp_docs", } keys = keys - finished_experiments @@ -157,7 +158,7 @@ def print_experimental_features(config: MarimoConfig) -> None: ) -def print_mcp(mcp_url: str, server_token: str | None) -> None: +def print_mcp_server(mcp_url: str, server_token: str | None) -> None: """Print MCP server configuration when MCP is enabled.""" print_() print_tabbed( @@ -171,3 +172,15 @@ def print_mcp(mcp_url: str, server_token: str | None) -> None: f"{_utf8('➜')} {green('Add Header')}: Marimo-Server-Token: {muted(server_token)}" ) print_() + + +def print_mcp_client(config: MCPConfig) -> None: + keys = set(config.get("mcpServers", {}).keys()) | set( + config.get("presets", []) + ) + if len(keys) == 0: + return + + print_tabbed( + f"{_utf8('🌐')} {green('MCP servers', bold=True)}: {', '.join(keys)}" + ) diff --git a/packages/openapi/api.yaml b/packages/openapi/api.yaml index f09c60427e6..f8fd8558d30 100644 --- a/packages/openapi/api.yaml +++ b/packages/openapi/api.yaml @@ -2037,6 +2037,29 @@ components: - keys title: ListSecretKeysResponse type: object + MCPConfig: + description: 'Configuration for MCP servers + + + Note: the field name `mcpServers` is camelCased to match MCP server + + config conventions used by popular AI applications (e.g. Cursor, Claude Desktop, + etc.)' + properties: + mcpServers: + additionalProperties: + type: object + type: object + presets: + items: + enum: + - context7 + - marimo + type: array + required: + - mcpServers + title: MCPConfig + type: object MarimoAncestorPreventedError: properties: blamed_cell: @@ -2093,6 +2116,8 @@ components: $ref: '#/components/schemas/KeymapConfig' language_servers: $ref: '#/components/schemas/LanguageServersConfig' + mcp: + $ref: '#/components/schemas/MCPConfig' package_management: $ref: '#/components/schemas/PackageManagementConfig' runtime: @@ -3586,7 +3611,7 @@ components: type: object info: title: marimo API - version: 0.16.2 + version: 0.16.3 openapi: 3.1.0 paths: /@file/{filename_and_length}: diff --git a/packages/openapi/src/api.ts b/packages/openapi/src/api.ts index 40bb0f68886..3ee67d9b42c 100644 --- a/packages/openapi/src/api.ts +++ b/packages/openapi/src/api.ts @@ -3737,6 +3737,19 @@ export interface components { ListSecretKeysResponse: { keys: components["schemas"]["SecretKeysWithProvider"][]; }; + /** + * MCPConfig + * @description Configuration for MCP servers + * + * Note: the field name `mcpServers` is camelCased to match MCP server + * config conventions used by popular AI applications (e.g. Cursor, Claude Desktop, etc.) + */ + MCPConfig: { + mcpServers: { + [key: string]: Record; + }; + presets?: ("context7" | "marimo")[]; + }; /** MarimoAncestorPreventedError */ MarimoAncestorPreventedError: { blamed_cell: string | null; @@ -3766,6 +3779,7 @@ export interface components { formatting: components["schemas"]["FormattingConfig"]; keymap: components["schemas"]["KeymapConfig"]; language_servers?: components["schemas"]["LanguageServersConfig"]; + mcp?: components["schemas"]["MCPConfig"]; package_management: components["schemas"]["PackageManagementConfig"]; runtime: components["schemas"]["RuntimeConfig"]; save: components["schemas"]["SaveConfig"]; diff --git a/tests/_server/ai/test_mcp.py b/tests/_server/ai/test_mcp.py index 70f68f7987e..6b48ef5be06 100644 --- a/tests/_server/ai/test_mcp.py +++ b/tests/_server/ai/test_mcp.py @@ -11,7 +11,9 @@ ) from marimo._dependencies.dependencies import DependencyManager from marimo._server.ai.mcp import ( + MCP_PRESETS, MCPClient, + MCPConfigComparator, MCPServerConnection, MCPServerDefinition, MCPServerDefinitionFactory, @@ -20,6 +22,7 @@ MCPTransportType, StdioTransportConnector, StreamableHTTPTransportConnector, + append_presets, get_mcp_client, ) @@ -197,6 +200,275 @@ def test_from_config_transport_detection( assert server_def.timeout == config_kwargs["timeout"] +class TestMCPConfigComparator: + """Test cases for MCPConfigComparator utility class.""" + + def test_compute_diff_no_changes(self): + """Test that compute_diff detects no changes when configs are identical.""" + server1 = MCPServerDefinition( + name="server1", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test", args=[], env={}), + timeout=30.0, + ) + + current = {"server1": server1} + new = {"server1": server1} + + diff = MCPConfigComparator.compute_diff(current, new) + + assert not diff.has_changes() + assert len(diff.servers_to_add) == 0 + assert len(diff.servers_to_remove) == 0 + assert len(diff.servers_to_update) == 0 + assert "server1" in diff.servers_unchanged + + def test_compute_diff_add_servers(self): + """Test that compute_diff detects new servers.""" + server1 = MCPServerDefinition( + name="server1", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test1", args=[], env={}), + timeout=30.0, + ) + server2 = MCPServerDefinition( + name="server2", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test2", args=[], env={}), + timeout=30.0, + ) + + current = {"server1": server1} + new = {"server1": server1, "server2": server2} + + diff = MCPConfigComparator.compute_diff(current, new) + + assert diff.has_changes() + assert "server2" in diff.servers_to_add + assert len(diff.servers_to_remove) == 0 + assert len(diff.servers_to_update) == 0 + assert "server1" in diff.servers_unchanged + + def test_compute_diff_remove_servers(self): + """Test that compute_diff detects removed servers.""" + server1 = MCPServerDefinition( + name="server1", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test1", args=[], env={}), + timeout=30.0, + ) + server2 = MCPServerDefinition( + name="server2", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test2", args=[], env={}), + timeout=30.0, + ) + + current = {"server1": server1, "server2": server2} + new = {"server1": server1} + + diff = MCPConfigComparator.compute_diff(current, new) + + assert diff.has_changes() + assert "server2" in diff.servers_to_remove + assert len(diff.servers_to_add) == 0 + assert len(diff.servers_to_update) == 0 + assert "server1" in diff.servers_unchanged + + def test_compute_diff_update_servers(self): + """Test that compute_diff detects modified servers.""" + server1_old = MCPServerDefinition( + name="server1", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig( + command="test", args=["--old"], env={} + ), + timeout=30.0, + ) + server1_new = MCPServerDefinition( + name="server1", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig( + command="test", args=["--new"], env={} + ), + timeout=30.0, + ) + + current = {"server1": server1_old} + new = {"server1": server1_new} + + diff = MCPConfigComparator.compute_diff(current, new) + + assert diff.has_changes() + assert "server1" in diff.servers_to_update + assert len(diff.servers_to_add) == 0 + assert len(diff.servers_to_remove) == 0 + assert len(diff.servers_unchanged) == 0 + + def test_compute_diff_mixed_changes(self): + """Test compute_diff with multiple types of changes.""" + server1 = MCPServerDefinition( + name="unchanged", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test1", args=[], env={}), + timeout=30.0, + ) + server2_old = MCPServerDefinition( + name="updated", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig( + command="test2", args=["--old"], env={} + ), + timeout=30.0, + ) + server2_new = MCPServerDefinition( + name="updated", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig( + command="test2", args=["--new"], env={} + ), + timeout=30.0, + ) + server3 = MCPServerDefinition( + name="removed", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test3", args=[], env={}), + timeout=30.0, + ) + server4 = MCPServerDefinition( + name="added", + transport=MCPTransportType.STDIO, + config=MCPServerStdioConfig(command="test4", args=[], env={}), + timeout=30.0, + ) + + current = { + "unchanged": server1, + "updated": server2_old, + "removed": server3, + } + new = {"unchanged": server1, "updated": server2_new, "added": server4} + + diff = MCPConfigComparator.compute_diff(current, new) + + assert diff.has_changes() + assert "unchanged" in diff.servers_unchanged + assert "updated" in diff.servers_to_update + assert "removed" in diff.servers_to_remove + assert "added" in diff.servers_to_add + + +class TestMCPPresets: + """Test cases for MCP preset configuration system.""" + + def test_preset_definitions_exist(self): + """Test that expected presets are defined.""" + assert "marimo" in MCP_PRESETS + assert "context7" in MCP_PRESETS + + # Verify preset structure + assert "url" in MCP_PRESETS["marimo"] + assert "url" in MCP_PRESETS["context7"] + + def test_append_presets_no_presets_list(self): + """Test append_presets with config that has no presets list.""" + config = MCPConfig( + mcpServers={ + "custom": MCPServerStdioConfig(command="test", args=[]) + } + ) + + result = append_presets(config) + + # Should return config unchanged + assert "custom" in result["mcpServers"] + assert len(result["mcpServers"]) == 1 + + def test_append_presets_empty_presets_list(self): + """Test append_presets with empty presets list.""" + config = MCPConfig(mcpServers={}, presets=[]) + + result = append_presets(config) + + assert len(result["mcpServers"]) == 0 + + def test_append_presets_adds_marimo_preset(self): + """Test that marimo preset is added when specified.""" + config = MCPConfig(mcpServers={}, presets=["marimo"]) + + result = append_presets(config) + + assert "marimo" in result["mcpServers"] + assert ( + result["mcpServers"]["marimo"]["url"] + == MCP_PRESETS["marimo"]["url"] + ) + + def test_append_presets_adds_context7_preset(self): + """Test that context7 preset is added when specified.""" + config = MCPConfig(mcpServers={}, presets=["context7"]) + + result = append_presets(config) + + assert "context7" in result["mcpServers"] + assert ( + result["mcpServers"]["context7"]["url"] + == MCP_PRESETS["context7"]["url"] + ) + + def test_append_presets_adds_multiple_presets(self): + """Test that multiple presets can be added.""" + config = MCPConfig(mcpServers={}, presets=["marimo", "context7"]) + + result = append_presets(config) + + assert "marimo" in result["mcpServers"] + assert "context7" in result["mcpServers"] + assert len(result["mcpServers"]) == 2 + + def test_append_presets_preserves_existing_servers(self): + """Test that existing servers are preserved when adding presets.""" + config = MCPConfig( + mcpServers={ + "custom": MCPServerStdioConfig(command="test", args=[]) + }, + presets=["marimo"], + ) + + result = append_presets(config) + + assert "custom" in result["mcpServers"] + assert "marimo" in result["mcpServers"] + assert len(result["mcpServers"]) == 2 + + def test_append_presets_does_not_override_existing(self): + """Test that presets don't override existing servers with same name.""" + custom_url = "https://custom.marimo.app/mcp" + config = MCPConfig( + mcpServers={ + "marimo": MCPServerStreamableHttpConfig(url=custom_url) + }, + presets=["marimo"], + ) + + result = append_presets(config) + + # Original server should be preserved + assert result["mcpServers"]["marimo"]["url"] == custom_url + assert len(result["mcpServers"]) == 1 + + def test_append_presets_does_not_mutate_original(self): + """Test that append_presets doesn't mutate the original config.""" + config = MCPConfig(mcpServers={}, presets=["marimo"]) + + result = append_presets(config) + + # Original config should be unchanged + assert "marimo" not in config["mcpServers"] + # Result should have the preset + assert "marimo" in result["mcpServers"] + + class TestMCPTransportConnectors: """Test cases for transport connector classes.""" @@ -288,8 +560,7 @@ class TestMCPClientConfiguration: def test_init_with_empty_config(self): """Test MCPClient initialization with empty config.""" - config = MCPConfig(mcpServers={}) - client = MCPClient(config) + client = MCPClient() assert client.servers == {} assert client.connections == {} assert client.tool_registry == {} @@ -337,7 +608,11 @@ def test_parse_config_valid_servers( ): """Test parsing valid server configurations.""" config = MCPConfig(mcpServers=server_configs) - client = MCPClient(config) + client = MCPClient() + + # Parse the config to populate servers + parsed_servers = client._parse_config(config) + client.servers = parsed_servers assert len(client.servers) == len(expected_servers) for server_name in expected_servers: @@ -346,12 +621,275 @@ def test_parse_config_valid_servers( assert server_def.name == server_name +@pytest.mark.skipif( + not DependencyManager.mcp.has(), reason="MCP SDK not available" +) +class TestMCPClientReconfiguration: + """Test cases for MCPClient dynamic reconfiguration functionality.""" + + async def test_configure_noop_when_no_changes(self, mock_session_setup): + """Test that configure() does nothing when config hasn't changed.""" + del mock_session_setup + config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test", args=[], env={} + ) + } + ) + client = MCPClient() + + # Initial configure + await client.configure(config) + + # Track calls to connect_to_server + original_connect = client.connect_to_server + connect_calls = [] + + async def track_connect(server_name: str): + connect_calls.append(server_name) + return await original_connect(server_name) + + client.connect_to_server = track_connect + + # Configure with same config + await client.configure(config) + + # Should not have called connect_to_server + assert len(connect_calls) == 0 + + async def test_configure_adds_new_servers(self, mock_session_setup): + """Test that configure() adds new servers.""" + del mock_session_setup + initial_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=[], env={} + ) + } + ) + client = MCPClient() + await client.configure(initial_config) + + # New config with additional server + new_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=[], env={} + ), + "server2": MCPServerStdioConfig( + command="test2", args=[], env={} + ), + } + ) + + # Mock the connection methods + mock_connect = AsyncMock(return_value=True) + with patch.object(client, "connect_to_server", mock_connect): + await client.configure(new_config) + + # Verify server2 was added + assert "server1" in client.servers + assert "server2" in client.servers + assert mock_connect.called + # Should only connect to server2 (the new one) + assert mock_connect.call_count == 1 + mock_connect.assert_called_with("server2") + + async def test_configure_removes_old_servers(self, mock_session_setup): + """Test that configure() removes servers not in new config.""" + del mock_session_setup + initial_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=[], env={} + ), + "server2": MCPServerStdioConfig( + command="test2", args=[], env={} + ), + } + ) + client = MCPClient() + await client.configure(initial_config) + + # Create mock connections + client.connections["server1"] = create_test_server_connection( + "server1", MCPServerStatus.CONNECTED + ) + client.connections["server2"] = create_test_server_connection( + "server2", MCPServerStatus.CONNECTED + ) + + # New config with only server1 + new_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=[], env={} + ) + } + ) + + # Mock disconnect_from_server + mock_disconnect = AsyncMock(return_value=True) + with patch.object(client, "disconnect_from_server", mock_disconnect): + await client.configure(new_config) + + # Verify server2 was removed + assert "server1" in client.servers + assert "server2" not in client.servers + assert "server2" not in client.connections + + # Should have called disconnect for server2 + mock_disconnect.assert_called_once_with("server2") + + async def test_configure_updates_modified_servers( + self, mock_session_setup + ): + """Test that configure() reconnects to servers with changed config.""" + del mock_session_setup + initial_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=["--old"], env={} + ) + } + ) + client = MCPClient() + await client.configure(initial_config) + + # Create mock connection + client.connections["server1"] = create_test_server_connection( + "server1", MCPServerStatus.CONNECTED + ) + + # New config with modified server1 + new_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=["--new"], env={} + ) + } + ) + + # Mock methods + mock_disconnect = AsyncMock(return_value=True) + mock_connect = AsyncMock(return_value=True) + with ( + patch.object(client, "disconnect_from_server", mock_disconnect), + patch.object(client, "connect_to_server", mock_connect), + ): + await client.configure(new_config) + + # Should have disconnected and reconnected to server1 + mock_disconnect.assert_called_once_with("server1") + mock_connect.assert_called_once_with("server1") + + # Verify config was updated + assert client.servers["server1"].config["args"] == ["--new"] + + async def test_configure_mixed_changes(self, mock_session_setup): + """Test configure() with add, remove, and update operations.""" + del mock_session_setup + initial_config = MCPConfig( + mcpServers={ + "keep_unchanged": MCPServerStdioConfig( + command="test1", args=[], env={} + ), + "to_update": MCPServerStdioConfig( + command="test2", args=["--old"], env={} + ), + "to_remove": MCPServerStdioConfig( + command="test3", args=[], env={} + ), + } + ) + client = MCPClient() + await client.configure(initial_config) + + # Create mock connections + for name in ["keep_unchanged", "to_update", "to_remove"]: + client.connections[name] = create_test_server_connection( + name, MCPServerStatus.CONNECTED + ) + + # New config + new_config = MCPConfig( + mcpServers={ + "keep_unchanged": MCPServerStdioConfig( + command="test1", args=[], env={} + ), + "to_update": MCPServerStdioConfig( + command="test2", args=["--new"], env={} + ), + "to_add": MCPServerStdioConfig( + command="test4", args=[], env={} + ), + } + ) + + # Mock methods + mock_disconnect = AsyncMock(return_value=True) + mock_connect = AsyncMock(return_value=True) + with ( + patch.object(client, "disconnect_from_server", mock_disconnect), + patch.object(client, "connect_to_server", mock_connect), + ): + await client.configure(new_config) + + # Verify results + assert "keep_unchanged" in client.servers + assert "to_update" in client.servers + assert "to_add" in client.servers + assert "to_remove" not in client.servers + assert "to_remove" not in client.connections + + # Verify disconnect was called for removed and updated + assert mock_disconnect.call_count == 2 + disconnect_calls = [ + call[0][0] for call in mock_disconnect.call_args_list + ] + assert "to_remove" in disconnect_calls + assert "to_update" in disconnect_calls + + # Verify connect was called for added and updated + assert mock_connect.call_count == 2 + connect_calls = [call[0][0] for call in mock_connect.call_args_list] + assert "to_add" in connect_calls + assert "to_update" in connect_calls + + async def test_configure_connection_failures_logged( + self, mock_session_setup + ): + """Test that configure() handles connection failures gracefully.""" + del mock_session_setup + initial_config = MCPConfig(mcpServers={}) + client = MCPClient() + await client.configure(initial_config) + + new_config = MCPConfig( + mcpServers={ + "server1": MCPServerStdioConfig( + command="test1", args=[], env={} + ) + } + ) + + # Mock connect_to_server to fail + mock_connect = AsyncMock(side_effect=Exception("Connection failed")) + with patch.object(client, "connect_to_server", mock_connect): + # Should not raise, just log + await client.configure(new_config) + + # Server should still be in registry even if connection failed + assert "server1" in client.servers + + class TestMCPClientToolManagement: """Test cases for MCPClient tool management functionality.""" def test_create_namespaced_tool_name_no_conflict(self): """Test creating namespaced tool name without conflicts.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() name = client._create_namespaced_tool_name("github", "create_issue") assert name == "mcp_github_create_issue" @@ -360,7 +898,7 @@ def test_create_namespaced_tool_name_no_conflict(self): ) def test_create_namespaced_tool_name_with_conflicts(self): """Test creating namespaced tool name with conflicts and counter resolution.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import Tool @@ -393,7 +931,7 @@ def test_create_namespaced_tool_name_with_conflicts(self): ) def test_add_server_tools(self): """Test adding tools from a server to registry and connection.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import Tool # Create server connection @@ -433,7 +971,7 @@ def test_add_server_tools(self): ) def test_remove_server_tools(self): """Test removing tools from a server.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import Tool # Create tools from different servers @@ -519,7 +1057,7 @@ def test_remove_server_tools(self): ) def test_get_tools_by_server(self, server_name, expected_tool_count): """Test getting tools by server name.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import Tool # Add tools from different servers @@ -554,7 +1092,7 @@ class TestMCPClientToolExecution: def test_create_tool_params(self): """Test creating properly typed CallToolRequestParams.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Add a mock tool to the registry mock_tool = create_test_tool() @@ -603,7 +1141,7 @@ async def test_invoke_tool_error_cases( self, tool_setup, connection_setup, expected_error_pattern ): """Test invoke_tool error handling scenarios.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import Tool # Setup tool if provided @@ -655,7 +1193,7 @@ async def test_invoke_tool_error_cases( async def test_invoke_tool_success(self): """Test successful tool invocation.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import CallToolResult, TextContent # Setup tool @@ -694,7 +1232,7 @@ async def test_invoke_tool_success(self): async def test_invoke_tool_timeout(self): """Test tool invocation timeout handling.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup tool mock_tool = create_test_tool() @@ -757,7 +1295,7 @@ def test_result_handling_helpers( """Test CallToolResult helper methods.""" from mcp.types import CallToolResult, TextContent - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Create result content = [TextContent(**item) for item in result_content] @@ -784,7 +1322,7 @@ class TestMCPClientConnectionManagement: async def test_discover_tools_success(self): """Test successful tool discovery from an MCP server.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() from mcp.types import ListToolsResult, Tool # Create mock connection with session @@ -820,7 +1358,7 @@ async def test_discover_tools_success(self): async def test_discover_tools_no_session(self): """Test tool discovery with no active session.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Create connection without session connection = create_test_server_connection(session=None) @@ -864,7 +1402,8 @@ async def test_connect_to_server_success( ) } ) - client = MCPClient(config) + client = MCPClient() + await client.configure(config) # Test connection result = await client.connect_to_server("test_server") @@ -893,7 +1432,8 @@ async def test_connect_to_server_edge_cases( command="python", args=["test.py"] ) - client = MCPClient(config) + client = MCPClient() + await client.configure(config) if already_connected: # Setup existing connection @@ -942,7 +1482,8 @@ async def test_connect_to_all_servers_mixed_results( ), } ) - client = MCPClient(config) + client = MCPClient() + await client.configure(config) results = await client.connect_to_all_servers() @@ -960,7 +1501,7 @@ class TestMCPClientDisconnectionManagement: async def test_disconnect_from_server_success(self): """Test successful disconnection from a connected server.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup a connected server using existing patterns connection = create_test_server_connection( @@ -988,7 +1529,7 @@ async def test_disconnect_from_server_success(self): async def test_disconnect_from_server_already_disconnected(self): """Test disconnection from server that's already disconnected.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Call disconnect on non-existent server result = await client.disconnect_from_server("nonexistent_server") @@ -998,7 +1539,7 @@ async def test_disconnect_from_server_already_disconnected(self): async def test_disconnect_from_server_with_exception(self): """Test disconnection failure handling (validates our new comment).""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup connection with task that will raise exception when awaited connection = create_test_server_connection( @@ -1031,7 +1572,7 @@ async def blocking_failing_task(): async def test_disconnect_from_server_cleanup_verification(self): """Test that disconnection properly cleans up server state.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup connected server with tools and monitoring connection = create_test_server_connection( @@ -1098,7 +1639,7 @@ async def test_disconnect_from_server_cleanup_verification(self): ) async def test_disconnect_from_all_servers_scenarios(self, server_setups): """Test disconnect_from_all_servers with various success/failure combinations.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup connections based on test parameters for setup in server_setups: @@ -1130,7 +1671,7 @@ async def test_disconnect_from_all_servers_scenarios(self, server_setups): async def test_disconnect_from_all_servers_with_health_monitoring(self): """Test that disconnect_from_all_servers cancels health monitoring first.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup connections with health monitoring tasks server_names = ["server1", "server2"] @@ -1160,7 +1701,7 @@ async def test_disconnect_from_all_servers_with_health_monitoring(self): async def test_disconnect_cross_task_scenario(self): """Test disconnection in cross-task scenarios (like server shutdown).""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Setup connection that simulates cross-task issues connection = create_test_server_connection( @@ -1203,7 +1744,7 @@ class TestMCPClientHealthMonitoring: ) async def test_perform_health_check_success(self): """Test successful health check.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Create connection with mock session server_def = MCPServerDefinitionFactory.from_config( @@ -1242,7 +1783,7 @@ async def test_perform_health_check_failure_cases( self, session_setup, ping_behavior, expected_result ): """Test health check failure scenarios.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() # Create connection server_def = MCPServerDefinitionFactory.from_config( @@ -1264,7 +1805,7 @@ async def test_perform_health_check_failure_cases( async def test_perform_health_check_timeout(self): """Test health check timeout handling.""" - client = MCPClient(MCPConfig(mcpServers={})) + client = MCPClient() client.health_check_timeout = 0.1 # Very short timeout # Create connection with session that hangs @@ -1329,12 +1870,12 @@ def test_get_mcp_client_singleton(self): @pytest.mark.skipif( not DependencyManager.mcp.has(), reason="MCP SDK not available" ) - def test_get_mcp_client_with_custom_config(self): + async def test_get_mcp_client_with_custom_config(self): """Test get_mcp_client with custom configuration.""" # Reset global client for this test - import marimo._server.ai.mcp as mcp_module + import marimo._server.ai.mcp.client as client_module - mcp_module._MCP_CLIENT = None + client_module._MCP_CLIENT = None custom_config = MCPConfig( mcpServers={ @@ -1344,5 +1885,6 @@ def test_get_mcp_client_with_custom_config(self): } ) - client = get_mcp_client(custom_config) + client = get_mcp_client() + await client.configure(custom_config) assert "custom_server" in client.servers diff --git a/tests/_server/templates/snapshots/export1.txt b/tests/_server/templates/snapshots/export1.txt index 07847959377..961492ac909 100644 --- a/tests/_server/templates/snapshots/export1.txt +++ b/tests/_server/templates/snapshots/export1.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true}, diff --git a/tests/_server/templates/snapshots/export2.txt b/tests/_server/templates/snapshots/export2.txt index 535eb3d9b97..abb8a2ac667 100644 --- a/tests/_server/templates/snapshots/export2.txt +++ b/tests/_server/templates/snapshots/export2.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true}, diff --git a/tests/_server/templates/snapshots/export3.txt b/tests/_server/templates/snapshots/export3.txt index 1ff6b8a8f5f..658b38ef273 100644 --- a/tests/_server/templates/snapshots/export3.txt +++ b/tests/_server/templates/snapshots/export3.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true}, diff --git a/tests/_server/templates/snapshots/export4.txt b/tests/_server/templates/snapshots/export4.txt index 84621c4bfd2..1e142889085 100644 --- a/tests/_server/templates/snapshots/export4.txt +++ b/tests/_server/templates/snapshots/export4.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"css_file": "custom.css", "sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true}, diff --git a/tests/_server/templates/snapshots/export5.txt b/tests/_server/templates/snapshots/export5.txt index 9f61f9a0ae5..80ba99eae4b 100644 --- a/tests/_server/templates/snapshots/export5.txt +++ b/tests/_server/templates/snapshots/export5.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"app_title": "My App", "html_head_file": "head.html", "sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true}, diff --git a/tests/_server/templates/snapshots/export6.txt b/tests/_server/templates/snapshots/export6.txt index 6b1e2e64c63..88891f7c391 100644 --- a/tests/_server/templates/snapshots/export6.txt +++ b/tests/_server/templates/snapshots/export6.txt @@ -59,7 +59,7 @@ "mode": "read", "version": "0.0.0", "serverToken": "token", - "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "custom_css": ["custom1.css", "custom2.css"], "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, + "config": {"ai": {"models": {"custom_models": [], "displayed_models": []}}, "completion": {"activate_on_typing": true, "copilot": false}, "display": {"cell_output": "below", "code_editor_font_size": 14, "custom_css": ["custom1.css", "custom2.css"], "dataframes": "rich", "default_table_max_columns": 50, "default_table_page_size": 10, "default_width": "medium", "reference_highlighting": false, "theme": "light"}, "formatting": {"line_length": 79}, "keymap": {"overrides": {}, "preset": "default"}, "language_servers": {"pylsp": {"enable_flake8": false, "enable_mypy": true, "enable_pydocstyle": false, "enable_pyflakes": false, "enable_pylint": false, "enable_ruff": true, "enabled": true}}, "mcp": {"mcpServers": {}, "presets": []}, "package_management": {"manager": "pixi"}, "runtime": {"auto_instantiate": true, "auto_reload": "off", "default_sql_output": "auto", "on_cell_change": "autorun", "output_max_bytes": 8000000, "reactive_tests": true, "std_stream_max_bytes": 1000000, "watcher_on_save": "lazy"}, "save": {"autosave": "after_delay", "autosave_delay": 1000, "format_on_save": false}, "server": {"browser": "default", "follow_symlink": false}, "snippets": {"custom_paths": [], "include_default_snippets": true}}, "configOverrides": {"formatting": {"line_length": 100}}, "appConfig": {"sql_output": "auto", "width": "compact"}, "view": {"showAppCode": true},