diff --git a/marimo/_mcp/server/lifespan.py b/marimo/_mcp/server/lifespan.py index bd98f66f9ca..d81e36f7c05 100644 --- a/marimo/_mcp/server/lifespan.py +++ b/marimo/_mcp/server/lifespan.py @@ -1,17 +1,19 @@ # Copyright 2025 Marimo. All rights reserved. import contextlib from collections.abc import AsyncIterator - -from starlette.applications import Starlette +from typing import TYPE_CHECKING from marimo._loggers import marimo_logger from marimo._mcp.server.main import setup_mcp_server LOGGER = marimo_logger() +if TYPE_CHECKING: + from starlette.applications import Starlette + @contextlib.asynccontextmanager -async def mcp_server_lifespan(app: Starlette) -> AsyncIterator[None]: +async def mcp_server_lifespan(app: "Starlette") -> AsyncIterator[None]: """Lifespan for MCP server functionality (exposing marimo as MCP server).""" try: diff --git a/marimo/_mcp/server/main.py b/marimo/_mcp/server/main.py index 4ff3f552732..d27166f9422 100644 --- a/marimo/_mcp/server/main.py +++ b/marimo/_mcp/server/main.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING from mcp.server.fastmcp import FastMCP -from starlette.routing import Mount from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools_registry import SUPPORTED_BACKEND_AND_MCP_TOOLS @@ -34,6 +33,7 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager": Returns: StreamableHTTPSessionManager: MCP session manager """ + from starlette.routing import Mount mcp = FastMCP( "marimo-mcp-server", diff --git a/marimo/_plugins/stateless/mpl/_mpl.py b/marimo/_plugins/stateless/mpl/_mpl.py index 46ca1f8820c..3772c2bc8b2 100644 --- a/marimo/_plugins/stateless/mpl/_mpl.py +++ b/marimo/_plugins/stateless/mpl/_mpl.py @@ -17,8 +17,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union -from starlette.responses import HTMLResponse, Response - from marimo import _loggers from marimo._output.builder import h from marimo._output.formatting import as_html @@ -42,6 +40,7 @@ from matplotlib.figure import Figure, SubFigure from starlette.applications import Starlette from starlette.requests import Request + from starlette.responses import HTMLResponse, Response from starlette.websockets import WebSocket @@ -222,10 +221,12 @@ def _template(fig_id: str, port: int) -> str: } +# Toplevel for reuse in endpoints. async def mpl_js(request: Request) -> Response: from matplotlib.backends.backend_webagg_core import ( FigureManagerWebAgg, ) + from starlette.responses import Response del request return Response( @@ -235,6 +236,8 @@ async def mpl_js(request: Request) -> Response: async def mpl_custom_css(request: Request) -> Response: + from starlette.responses import Response + del request return Response( content=css_content, @@ -242,12 +245,14 @@ async def mpl_custom_css(request: Request) -> Response: ) +# Over all application for handling figures on a per kernel basis def create_application() -> Starlette: import matplotlib as mpl from matplotlib.backends.backend_webagg_core import ( FigureManagerWebAgg, ) from starlette.applications import Starlette + from starlette.responses import HTMLResponse, Response from starlette.routing import Mount, Route, WebSocketRoute from starlette.staticfiles import StaticFiles from starlette.websockets import ( diff --git a/marimo/_server/api/endpoints/mpl.py b/marimo/_server/api/endpoints/mpl.py index 5d8ad7c5957..bf3ba6fddde 100644 --- a/marimo/_server/api/endpoints/mpl.py +++ b/marimo/_server/api/endpoints/mpl.py @@ -7,15 +7,12 @@ from typing import TYPE_CHECKING, Any, Callable, Optional import websockets -from starlette.responses import Response - -# import StaticFiles from starlette -from starlette.staticfiles import StaticFiles if TYPE_CHECKING: from collections.abc import Awaitable from starlette.requests import Request + from starlette.responses import Response from starlette.websockets import WebSocket @@ -39,6 +36,7 @@ def mpl_fallback_handler( Args: path_prefix: Prefix to add to path when calling _mpl_handler (default "") """ + from starlette.responses import Response def decorator( func: Callable[[Request], Awaitable[Response]], @@ -80,6 +78,7 @@ async def mpl_static(request: Request) -> Response: from matplotlib.backends.backend_webagg_core import ( FigureManagerWebAgg, ) + from starlette.staticfiles import StaticFiles static_app = StaticFiles( directory=FigureManagerWebAgg.get_static_file_path() # type: ignore[no-untyped-call] @@ -93,6 +92,7 @@ async def mpl_images(request: Request) -> Response: """Fallback for image files from matplotlib.""" path = request.path_params["path"] import matplotlib as mpl + from starlette.staticfiles import StaticFiles static_app = StaticFiles(directory=Path(mpl.get_data_path(), "images")) return await static_app.get_response(path, request.scope) @@ -130,6 +130,8 @@ async def _mpl_handler( Returns: Response from the matplotlib server or error response """ + from starlette.responses import Response + # Proxy to matplotlib server # Determine the target port port = figure_endpoints.get(figurenum, None)