Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 253 additions & 212 deletions docs/_static/CLAUDE.md

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions marimo/_ai/_tools/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import dataclasses
import inspect
import re
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -237,7 +236,7 @@ def as_backend_tool(
# helpers
def _coerce_args(self, args: Any) -> ArgsT: # type: ignore[override]
"""If Args is a dataclass and args is a dict, construct it; else pass through."""
if dataclasses.is_dataclass(args):
if is_dataclass(args):
# Already parsed
return args # type: ignore[return-value]
return parse_raw(args, self.Args)
Expand Down
63 changes: 63 additions & 0 deletions marimo/_ai/_tools/tools/rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import marimo._utils.requests as requests
from marimo import _loggers
from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import EmptyArgs, SuccessResult

LOGGER = _loggers.marimo_logger()

# We load the rules remotely, so we can update these without requiring a new release.
# If requested, we can bundle this into the library instead.
MARIMO_RULES_URL = "https://docs.marimo.io/CLAUDE.md"


@dataclass
class GetMarimoRulesOutput(SuccessResult):
# Result type of the GetMarimoRules tool; adds two fields on top of SuccessResult.
# Text of the fetched rules document; left at None on the error path of handle().
rules_content: Optional[str] = None
# URL the rules were requested from; defaults to MARIMO_RULES_URL.
source_url: str = MARIMO_RULES_URL


class GetMarimoRules(ToolBase[EmptyArgs, GetMarimoRulesOutput]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really good! Maybe we should add this to the next_steps= suggestion for some of the other tools?
My suggestions would be:

  • GetNotebookErrors() if there are any notebook specific or marimo specific errors
  • GetLightweightCellMap() and GetActiveNotebooks() because these are entry points.
  • GetDatabaseTables() because the rules contain some SQL-specific guidelines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can do this in a follow-up. It's hard to tell why an AI would ask for rules, so I'm not sure whether the next steps would veer it off track.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sounds good. I'll keep this in case we decide to update it later. We can first see whether the chat LLMs either don't pull the rules initially, or pull them early on in the conversation and then forget them.

"""Get the official marimo rules and guidelines for AI assistants.
Returns:
The content of the rules file.
"""

def handle(self, args: EmptyArgs) -> GetMarimoRulesOutput:
    """Fetch the marimo rules document and wrap it in a tool result.

    Args:
        args: Unused; this tool takes no arguments.

    Returns:
        On success, an output carrying the response text and the source
        URL. If anything raises while fetching, an output with
        ``status="error"``, a message describing the failure, and
        troubleshooting next steps.
    """
    del args

    try:
        reply = requests.get(MARIMO_RULES_URL, timeout=10)
        # Treat HTTP error statuses as failures handled below.
        reply.raise_for_status()
        rules_text = reply.text()

        return GetMarimoRulesOutput(
            rules_content=rules_text,
            source_url=MARIMO_RULES_URL,
            next_steps=[
                "Follow the guidelines in the rules when working with marimo notebooks",
            ],
        )
    except Exception as err:
        # Best-effort tool: log the failure and report it in the result
        # instead of propagating the exception to the caller.
        reason = str(err)
        LOGGER.warning(
            "Failed to fetch marimo rules from %s: %s",
            MARIMO_RULES_URL,
            reason,
        )
        troubleshooting = [
            "Check internet connectivity",
            "Verify the rules URL is accessible",
            "Try again later if the service is temporarily unavailable",
        ]
        return GetMarimoRulesOutput(
            status="error",
            message=f"Failed to fetch marimo rules: {reason}",
            source_url=MARIMO_RULES_URL,
            next_steps=troubleshooting,
        )
2 changes: 2 additions & 0 deletions marimo/_ai/_tools/tools_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.errors import GetNotebookErrors
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
from marimo._ai._tools.tools.rules import GetMarimoRules
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables

SUPPORTED_BACKEND_AND_MCP_TOOLS: list[type[ToolBase[Any, Any]]] = [
GetMarimoRules,
GetActiveNotebooks,
GetCellRuntimeData,
GetLightweightCellMap,
Expand Down
11 changes: 7 additions & 4 deletions marimo/_mcp/server/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING

from marimo._loggers import marimo_logger
from marimo._mcp.server.main import setup_mcp_server

LOGGER = marimo_logger()

Expand All @@ -17,11 +16,15 @@ async def mcp_server_lifespan(app: "Starlette") -> AsyncIterator[None]:
"""Lifespan for MCP server functionality (exposing marimo as MCP server)."""

try:
session_manager = setup_mcp_server(app)
mcp_app = app.state.mcp
if mcp_app is None:
LOGGER.warning("MCP server not found in app state")
yield
return

async with session_manager.run():
# Session manager owns request lifecycle during app run
async with mcp_app.session_manager.run():
LOGGER.info("MCP server session manager started")
# Session manager owns request lifecycle during app run
yield

except ImportError as e:
Expand Down
34 changes: 25 additions & 9 deletions marimo/_mcp/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from typing import TYPE_CHECKING

from mcp.server.fastmcp import FastMCP

from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools_registry import SUPPORTED_BACKEND_AND_MCP_TOOLS
from marimo._loggers import marimo_logger
Expand All @@ -18,11 +16,10 @@


if TYPE_CHECKING:
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette


def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
def setup_mcp_server(app: "Starlette") -> None:
"""Create and configure MCP server for marimo integration.

Args:
Expand All @@ -33,17 +30,20 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
Returns:
None. The configured FastMCP server is stored on ``app.state.mcp``.
"""
from mcp.server.fastmcp import FastMCP
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.types import Receive, Scope, Send

mcp = FastMCP(
"marimo-mcp-server",
stateless_http=True,
log_level="WARNING",
# Change base path from /mcp to /server
streamable_http_path="/server",
)

# Change base path from /mcp to /server
mcp.settings.streamable_http_path = "/server"

# Register all tools
context = ToolContext(app=app)
for tool in SUPPORTED_BACKEND_AND_MCP_TOOLS:
Expand All @@ -53,7 +53,23 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
# Initialize streamable HTTP app
mcp_app = mcp.streamable_http_app()

# Middleware to require edit scope
class RequiresEditMiddleware(BaseHTTPMiddleware):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RequiresEditMiddleware() doesn't seem to be mcp specific at all. Maybe it would be better to move this logic to marimo/_server/api/middleware.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually just for MCP. We already support auth on other endpoints with @requires.

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
auth = scope.get("auth")
if auth is None or "edit" not in auth.scopes:
response = JSONResponse(
{"detail": "Forbidden"},
status_code=403,
)
return await response(scope, receive, send)

return await self.app(scope, receive, send)

mcp_app.add_middleware(RequiresEditMiddleware)

# Add to the top of the routes to avoid conflicts with other routes
app.routes.insert(0, Mount("/mcp", mcp_app))

return mcp.session_manager
app.state.mcp = mcp
10 changes: 7 additions & 3 deletions marimo/_server/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,18 @@ def print_mcp_server(mcp_url: str, server_token: str | None) -> None:
"""Print MCP server configuration when MCP is enabled."""
print_()
print_tabbed(
f"{_utf8('🔗')} {green('Experimental MCP Server Configuration', bold=True)}"
f"{_utf8('🔗')} {green('Experimental MCP server configuration', bold=True)}"
)
print_tabbed(
f"{_utf8('➜')} {green('MCP Server URL')}: {_colorized_url(mcp_url)}"
f"{_utf8('➜')} {green('MCP server URL')}: {_colorized_url(mcp_url)}"
)
# Add to Claude code
print_tabbed(
f"{_utf8('➜')} {green('Add to Claude Code')}: claude mcp add --transport http marimo {mcp_url}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also include 'add to Cursor'? Since the most likely use cases for this are Claude Code and Cursor. It can be something like:

print_tabbed(
    f"{_utf8('➜')} {green('Add to Cursor')}: Add the following to ~/.cursor/mcp.json:\n"
    f"""  {{
    "mcpServers": {{
      "marimo": {{
        "url": "{mcp_url}"
      }}
    }}
  }}"""
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems a bit verbose for the CLI. We could look into making the CLI more interactive in the future (e.g. press i for more info).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense

)
if server_token is not None:
print_tabbed(
f"{_utf8('➜')} {green('Add Header')}: Marimo-Server-Token: {muted(server_token)}"
f"{_utf8('➜')} {green('Add header')}: Marimo-Server-Token: {muted(server_token)}"
)
print_()

Expand Down
18 changes: 17 additions & 1 deletion marimo/_server/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from marimo._cli.print import echo
from marimo._config.manager import get_default_config_manager
from marimo._config.settings import GLOBAL_SETTINGS
from marimo._mcp.server.main import setup_mcp_server
from marimo._messaging.ops import StartupLogs
from marimo._runtime.requests import SerializedCLIArgs
from marimo._server.file_router import AppFileRouter
Expand Down Expand Up @@ -174,6 +175,16 @@ def start(
Start the server.
"""

# Defaults when mcp is enabled
if mcp:
# Turn on watch mode
watch = True
# Turn off skew protection for MCP server
# since it is more convenient to connect to.
# Skew protection is not a security thing, but rather
# prevents connecting to old servers.
skew_protection = False

# Find a free port if none is specified
# if the user specifies a port, we don't try to find a free one
port = port or find_free_port(DEFAULT_PORT, addr=host)
Expand Down Expand Up @@ -240,7 +251,9 @@ def start(
*LIFESPAN_REGISTRY.get_all(),
]

if mcp and mode == SessionMode.EDIT:
mcp_enabled = mcp and mode == SessionMode.EDIT

if mcp_enabled:
from marimo._mcp.server.lifespan import mcp_server_lifespan

lifespans_list.append(mcp_server_lifespan)
Expand All @@ -260,6 +273,9 @@ def start(
timeout=timeout,
)

if mcp_enabled:
setup_mcp_server(app)

app.state.port = external_port
app.state.host = external_host

Expand Down
2 changes: 1 addition & 1 deletion marimo/_utils/file_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create(path: Path, callback: Callback) -> FileWatcher:
LOGGER.debug("Using watchdog file watcher")
return _create_watchdog(path, callback, asyncio.get_event_loop())
else:
LOGGER.warning(
LOGGER.info(
"watchdog is not installed, using polling file watcher"
)
return PollingFileWatcher(path, callback, asyncio.get_event_loop())
Expand Down
11 changes: 0 additions & 11 deletions tests/_mcp/server/test_main.py

This file was deleted.

101 changes: 101 additions & 0 deletions tests/_mcp/server/test_mcp_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024 Marimo. All rights reserved.
import pytest

from marimo._mcp.server.lifespan import mcp_server_lifespan

pytest.importorskip("mcp", reason="MCP requires Python 3.10+")

from starlette.applications import Starlette
from starlette.authentication import AuthCredentials, SimpleUser
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import HTTPConnection
from starlette.testclient import TestClient

from marimo._mcp.server.main import setup_mcp_server
from marimo._server.api.middleware import AuthBackend
from tests._server.mocks import get_mock_session_manager


def create_test_app() -> Starlette:
    """Build a Starlette app wired for MCP tests.

    The app gets an authentication middleware backed by
    ``AuthBackend(should_authenticate=False)``, a mock session manager on
    ``app.state.session_manager``, and the MCP server set up on it.
    """
    auth_layer = Middleware(
        AuthenticationMiddleware,
        backend=AuthBackend(should_authenticate=False),
    )
    test_app = Starlette(middleware=[auth_layer])
    test_app.state.session_manager = get_mock_session_manager()
    setup_mcp_server(test_app)
    return test_app


def test_mcp_server_starts_up():
    """Test that MCP server can be set up and routes are registered.

    Only app state and the route table are inspected; no request is sent,
    so no ``TestClient`` is constructed.
    """
    app = create_test_app()

    # Verify the MCP server was attached to the app state
    assert hasattr(app.state, "mcp")

    # Verify a route whose path contains /mcp exists
    assert any("/mcp" in str(route.path) for route in app.routes)


async def test_mcp_server_requires_edit_scope():
"""Test that MCP server validates 'edit' scope is present."""
# NOTE(review): `app` is never referenced again in this test; the two apps
# exercised below are built separately.
app = create_test_app()

# Mock a request without edit scope
class MockAuthBackend:
async def authenticate(self, conn: HTTPConnection):
del conn
# Return user without edit scope
return AuthCredentials(scopes=["read"]), SimpleUser("test_user")

# Create app with authentication that doesn't include edit scope
app_no_edit = Starlette(
middleware=[
Middleware(
AuthenticationMiddleware,
backend=MockAuthBackend(),
),
],
)
app_no_edit.state.session_manager = get_mock_session_manager()
setup_mcp_server(app_no_edit)

client = TestClient(app_no_edit, raise_server_exceptions=False)

# Try to access MCP endpoint without edit scope
response = client.get("/mcp/server")
# Credentials carry only the "read" scope, so the request must be rejected
assert response.status_code == 403

# Mock a request with edit scope
class MockAuthBackendWithEdit:
async def authenticate(self, conn: HTTPConnection):
del conn
# Return user with edit scope
return AuthCredentials(scopes=["edit"]), SimpleUser("test_user")

# Create app with edit scope
app_with_edit = Starlette(
middleware=[
Middleware(
AuthenticationMiddleware,
backend=MockAuthBackendWithEdit(),
),
],
)

setup_mcp_server(app_with_edit)
# Enter the MCP server lifespan for the edit-scoped app
async with mcp_server_lifespan(app_with_edit):
app_with_edit.state.session_manager = get_mock_session_manager()

client_with_edit = TestClient(app_with_edit)

# Access should not be forbidden (may get other status codes based on MCP protocol)
response = client_with_edit.get("/mcp/server")
# Only the absence of 403 is checked; any other status is accepted
assert response.status_code != 403
Loading