diff --git a/frontend/src/components/app-config/mcp-config.tsx b/frontend/src/components/app-config/mcp-config.tsx index 253d3bdd595..c8c4a03eec5 100644 --- a/frontend/src/components/app-config/mcp-config.tsx +++ b/frontend/src/components/app-config/mcp-config.tsx @@ -1,6 +1,6 @@ /* Copyright 2024 Marimo. All rights reserved. */ -import { CheckSquareIcon } from "lucide-react"; +import { CheckSquareIcon, Loader2, RefreshCwIcon } from "lucide-react"; import React from "react"; import type { UseFormReturn } from "react-hook-form"; import { @@ -12,6 +12,8 @@ import { } from "@/components/ui/card"; import { FormField, FormItem } from "@/components/ui/form"; import type { UserConfig } from "@/core/config/config-schema"; +import { useMCPRefresh, useMCPStatus } from "../mcp/hooks"; +import { McpStatusText } from "../mcp/mcp-status-indicator"; import { Button } from "../ui/button"; import { Kbd } from "../ui/kbd"; import { SettingSubtitle } from "./common"; @@ -45,10 +47,48 @@ const PRESET_CONFIGS: PresetConfig[] = [ export const MCPConfig: React.FC = ({ form, onSubmit }) => { const { handleClick } = useOpenSettingsToTab(); + const { data: status, refetch, isFetching } = useMCPStatus(); + const { refresh, isRefreshing } = useMCPRefresh(); + + const handleRefresh = async () => { + await refresh(); + refetch(); + }; return (
- MCP Servers +
+ MCP Servers +
+ {status && } + +
+
+ {status?.error && ( +
+ {status.error} +
+ )} + {status?.servers && ( +
+ {Object.entries(status.servers).map(([server, status]) => ( +
+ {server}: +
+ ))} +
+ )}

Enable Model Context Protocol (MCP) servers to provide additional capabilities and data sources for AI features. diff --git a/frontend/src/components/chat/chat-panel.tsx b/frontend/src/components/chat/chat-panel.tsx index ecd1f4bf97b..dce1e4263b8 100644 --- a/frontend/src/components/chat/chat-panel.tsx +++ b/frontend/src/components/chat/chat-panel.tsx @@ -63,6 +63,7 @@ import { } from "../editor/ai/completion-utils"; import { PanelEmptyState } from "../editor/chrome/panels/empty-state"; import { CopyClipboardIcon } from "../icons/copy-icon"; +import { MCPStatusIndicator } from "../mcp/mcp-status-indicator"; import { Input } from "../ui/input"; import { Tooltip, TooltipProvider } from "../ui/tooltip"; import { toast } from "../ui/use-toast"; @@ -120,6 +121,7 @@ const ChatHeader: React.FC = ({

+ + + + +
+
+

MCP Server Status

+ +
+ {status && ( +
+ {hasServers && ( +
+ Overall: + +
+ )} + {status.error && ( +
+ {status.error} +
+ )} + {hasServers && ( +
+
+ Servers: +
+ {Object.entries(servers).map(([name, serverStatus]) => ( +
+ + {name} + + +
+ ))} +
+ )} + {!hasServers && ( +
+ No MCP servers configured.
Configure under{" "} + Settings > AI > MCP +
+ )} +
+ )} +
+
+ + ); +}; + +export const McpStatusText: React.FC<{ + status: + | "ok" + | "partial" + | "error" + | "failed" + | "disconnected" + | "pending" + | "connected"; +}> = ({ status }) => { + return ( + + {status} + + ); +}; diff --git a/frontend/src/core/network/CachingRequestRegistry.ts b/frontend/src/core/network/CachingRequestRegistry.ts index 87cce30563e..549be17c44b 100644 --- a/frontend/src/core/network/CachingRequestRegistry.ts +++ b/frontend/src/core/network/CachingRequestRegistry.ts @@ -56,9 +56,9 @@ export class CachingRequestRegistry { const promise = this.delegate.request(req); this.cache.set(key, promise); - return promise.catch((err) => { + return promise.catch((error) => { this.cache.delete(key); - throw err; + throw error; }); } diff --git a/marimo/_cli/development/commands.py b/marimo/_cli/development/commands.py index a3a71781a94..4f4231c96b2 100644 --- a/marimo/_cli/development/commands.py +++ b/marimo/_cli/development/commands.py @@ -209,6 +209,8 @@ def _generate_server_api_schema() -> dict[str, Any]: models.UpdateComponentValuesRequest, models.InvokeAiToolRequest, models.InvokeAiToolResponse, + models.MCPStatusResponse, + models.MCPRefreshResponse, requests.CodeCompletionRequest, requests.DeleteCellRequest, requests.HTTPRequest, diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py index 298a7a1e6e8..0155646fa1a 100644 --- a/marimo/_server/api/endpoints/ai.py +++ b/marimo/_server/api/endpoints/ai.py @@ -1,7 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from starlette.authentication import requires from starlette.exceptions import HTTPException @@ -22,6 +22,7 @@ get_edit_model, get_max_tokens, ) +from marimo._server.ai.mcp import MCPServerStatus, get_mcp_client from marimo._server.ai.prompts import ( FIM_MIDDLE_TAG, FIM_PREFIX_TAG, @@ -47,6 +48,8 @@ from marimo._server.models.models import ( InvokeAiToolRequest, InvokeAiToolResponse, + MCPRefreshResponse, + MCPStatusResponse, ) from marimo._server.responses import StructResponse from marimo._server.router import APIRouter @@ -379,3 +382,184 @@ async def invoke_tool( error=f"Tool invocation failed: {str(e)}", ) ) + + +@router.get("/mcp/status") +@requires("edit") +async def mcp_status( + *, + request: Request, +) -> Response: + """ + responses: + 200: + description: Get MCP server status + content: + application/json: + schema: + $ref: "#/components/schemas/MCPStatusResponse" + """ + app_state = AppState(request) + app_state.require_current_session() + + try: + # Try to get MCP client + mcp_client = get_mcp_client() + + # Get all server statuses + server_statuses = mcp_client.get_all_server_statuses() + + # Map internal status enum to API status strings + status_map: dict[ + MCPServerStatus, + Literal["pending", "connected", "disconnected", "failed"], + ] = { + MCPServerStatus.CONNECTED: "connected", + MCPServerStatus.CONNECTING: "pending", + MCPServerStatus.DISCONNECTED: "disconnected", + MCPServerStatus.ERROR: "failed", + } + + servers = { + name: status_map.get(status, "failed") + for name, status in server_statuses.items() + } + + # Determine overall status + overall_status: Literal["ok", "partial", "error"] = "ok" + if not servers: + # No servers configured + overall_status = "ok" + error = None + elif all(s == "connected" for s in servers.values()): + # All servers connected + overall_status = "ok" + error = None + elif any(s == "connected" for s in servers.values()): + # Some servers connected + overall_status = "partial" + failed_servers = [ + name for name, status in servers.items() if status == "failed" + ] + error = ( + f"Some servers failed to connect: {', '.join(failed_servers)}" + ) + else: + # No servers connected or all failed + overall_status = "error" + error = "No MCP servers connected" + + return StructResponse( + MCPStatusResponse( + status=overall_status, + error=error, + servers=servers, + ) + ) + + except ModuleNotFoundError: + # MCP dependencies not installed + return StructResponse( + MCPStatusResponse( + status="error", + error="Missing dependencies. Install with: pip install marimo[mcp]", + servers={}, + ) + ) + except Exception as e: + LOGGER.error(f"Error getting MCP status: {e}") + return StructResponse( + MCPStatusResponse( + status="error", + error=str(e), + servers={}, + ) + ) + + +@router.post("/mcp/refresh") +@requires("edit") +async def mcp_refresh( + *, + request: Request, +) -> Response: + """ + responses: + 200: + description: Refresh MCP server configuration + content: + application/json: + schema: + $ref: "#/components/schemas/MCPRefreshResponse" + """ + app_state = AppState(request) + app_state.require_current_session() + + try: + # Get the MCP client + mcp_client = get_mcp_client() + + # Get current config + config = app_state.app_config_manager.get_config(hide_secrets=False) + mcp_config = config.get("mcp") + + if mcp_config is None: + return StructResponse( + MCPRefreshResponse( + success=False, + error="MCP configuration is not set", + servers={}, + ) + ) + + # Reconfigure the client with the current configuration + # This will handle disconnecting/reconnecting as needed + await mcp_client.configure(mcp_config) + + # Get updated server statuses + server_statuses = mcp_client.get_all_server_statuses() + + # Map status to success boolean + servers = { + name: status == MCPServerStatus.CONNECTED + for name, status in server_statuses.items() + } + + # Overall success if all servers are connected (or no servers) + success = len(servers) == 0 or all(servers.values()) + + error = None + if not success: + failed_servers = [ + name for name, connected in servers.items() if not connected + ] + error = ( + f"Some servers failed to connect: {', '.join(failed_servers)}" + ) + + return StructResponse( + MCPRefreshResponse( + success=success, + error=error, + servers=servers, + ) + ) + + except ModuleNotFoundError: + # MCP dependencies not installed + return StructResponse( + MCPRefreshResponse( + success=False, + error="Missing dependencies. Install with: pip install marimo[mcp]", + servers={}, + ) + ) + except Exception as e: + LOGGER.error(f"Error refreshing MCP: {e}") + return StructResponse( + MCPRefreshResponse( + success=False, + error=str(e), + servers={}, + ) + ) diff --git a/marimo/_server/api/endpoints/export.py b/marimo/_server/api/endpoints/export.py index 41b4c846cf9..106de58ca2d 100644 --- a/marimo/_server/api/endpoints/export.py +++ b/marimo/_server/api/endpoints/export.py @@ -9,6 +9,7 @@ from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse from marimo import _loggers +from marimo._dependencies.dependencies import DependencyManager from marimo._messaging.msgspec_encoder import asdict from marimo._server.api.deps import AppState from marimo._server.api.status import HTTPStatus @@ -350,6 +351,11 @@ async def auto_export_as_ipynb( return PlainTextResponse(status_code=HTTPStatus.NOT_MODIFIED) async def _background_export() -> None: + # Check has nbformat installed + if not DependencyManager.nbformat.has(): + LOGGER.error("Cannot snapshot to IPYNB: nbformat not installed") + return + # Reload the file manager to get the latest state session.app_file_manager.reload() diff --git a/marimo/_server/models/models.py b/marimo/_server/models/models.py index bd0857dc8ef..51a58a6aab0 100644 --- a/marimo/_server/models/models.py +++ b/marimo/_server/models/models.py @@ -2,7 +2,7 @@ from __future__ import annotations import os -from typing import Any, Optional +from typing import Any, Literal, Optional import msgspec @@ -161,3 +161,16 @@ class InvokeAiToolResponse(BaseResponse): tool_name: str result: Any error: Optional[str] = None + + +class MCPStatusResponse(msgspec.Struct, rename="camel"): + status: Literal["ok", "partial", "error"] + error: Optional[str] = None + servers: dict[ + str, Literal["pending", "connected", "disconnected", "failed"] + ] = {} # server_name -> status + + +class MCPRefreshResponse(BaseResponse): + error: Optional[str] = None + servers: dict[str, bool] = {} # server_name -> connected diff --git a/packages/openapi/api.yaml b/packages/openapi/api.yaml index f8fd8558d30..56f98178f38 100644 --- a/packages/openapi/api.yaml +++ b/packages/openapi/api.yaml @@ -2060,6 +2060,49 @@ components: - mcpServers title: MCPConfig type: object + MCPRefreshResponse: + properties: + error: + anyOf: + - type: string + - type: 'null' + default: null + servers: + additionalProperties: + type: boolean + default: {} + type: object + success: + type: boolean + required: + - success + title: MCPRefreshResponse + type: object + MCPStatusResponse: + properties: + error: + anyOf: + - type: string + - type: 'null' + default: null + servers: + additionalProperties: + enum: + - connected + - disconnected + - failed + - pending + default: {} + type: object + status: + enum: + - error + - ok + - partial + required: + - status + title: MCPStatusResponse + type: object MarimoAncestorPreventedError: properties: blamed_cell: @@ -3690,6 +3733,24 @@ paths: schema: $ref: '#/components/schemas/InvokeAiToolResponse' description: Tool invocation result + /api/ai/mcp/refresh: + post: + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/MCPRefreshResponse' + description: Refresh MCP server configuration + /api/ai/mcp/status: + get: + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/MCPStatusResponse' + description: Get MCP server status /api/datasources/preview_column: post: requestBody: diff --git a/packages/openapi/src/api.ts b/packages/openapi/src/api.ts index 3ee67d9b42c..2fc720898af 100644 --- a/packages/openapi/src/api.ts +++ b/packages/openapi/src/api.ts @@ -201,6 +201,76 @@ export interface paths { patch?: never; trace?: never; }; + "/api/ai/mcp/refresh": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + post: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Refresh MCP server configuration */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["MCPRefreshResponse"]; + }; + }; + }; + }; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/ai/mcp/status": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Get MCP server status */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["MCPStatusResponse"]; + }; + }; + }; + }; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/datasources/preview_column": { parameters: { query?: never; @@ -3750,6 +3820,27 @@ export interface components { }; presets?: ("context7" | "marimo")[]; }; + /** MCPRefreshResponse */ + MCPRefreshResponse: { + /** @default null */ + error?: string | null; + /** @default {} */ + servers?: { + [key: string]: boolean; + }; + success: boolean; + }; + /** MCPStatusResponse */ + MCPStatusResponse: { + /** @default null */ + error?: string | null; + /** @default {} */ + servers?: { + [key: string]: "connected" | "disconnected" | "failed" | "pending"; + }; + /** @enum {unknown} */ + status: "error" | "ok" | "partial"; + }; /** MarimoAncestorPreventedError */ MarimoAncestorPreventedError: { blamed_cell: string | null; diff --git a/tests/_server/ai/test_mcp.py b/tests/_server/ai/test_mcp.py index 6b48ef5be06..f3279f864e4 100644 --- a/tests/_server/ai/test_mcp.py +++ b/tests/_server/ai/test_mcp.py @@ -1447,6 +1447,7 @@ async def test_connect_to_server_edge_cases( result = await client.connect_to_server("test_server") assert result == expected_result + @pytest.mark.xfail(reason="Flaky test") @patch("mcp.ClientSession") async def test_connect_to_all_servers_mixed_results( self, mock_session_class diff --git a/tests/_server/api/endpoints/test_ai.py b/tests/_server/api/endpoints/test_ai.py index 188c2d90dd7..2e6256686c9 100644 --- a/tests/_server/api/endpoints/test_ai.py +++ b/tests/_server/api/endpoints/test_ai.py @@ -1294,6 +1294,46 @@ def test_invoke_tool_without_session(client: TestClient) -> None: # Should fail without proper session assert response.status_code in [400, 401, 403], response.text + +class TestMCPEndpoints: + """Tests for MCP status and refresh endpoints.""" + + @staticmethod + @with_session(SESSION_ID) + def test_mcp_status(client: TestClient) -> None: + """Test MCP status endpoint returns error when dependencies not installed.""" + response = client.get( + "/api/ai/mcp/status", + headers=HEADERS, + ) + + assert response.status_code == 200, response.text + data = response.json() + + # Should have required fields + assert "status" in data + assert "servers" in data + # Will likely error due to missing dependencies or no config + assert data["status"] in ["ok", "partial", "error"] + + @staticmethod + @with_session(SESSION_ID) + def test_mcp_refresh(client: TestClient) -> None: + """Test MCP refresh endpoint returns error when dependencies not installed.""" + response = client.post( + "/api/ai/mcp/refresh", + headers=HEADERS, + ) + + assert response.status_code == 200, response.text + data = response.json() + + # Should have required fields + assert "success" in data + assert "servers" in data + # Will likely fail due to missing dependencies or no config + assert isinstance(data["success"], bool) + @staticmethod @with_session(SESSION_ID) @patch("marimo._server.api.endpoints.ai.get_tool_manager")