Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 253 additions & 212 deletions docs/_static/CLAUDE.md

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions marimo/_ai/_tools/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import dataclasses
import inspect
import re
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -237,7 +236,7 @@ def as_backend_tool(
# helpers
def _coerce_args(self, args: Any) -> ArgsT: # type: ignore[override]
"""If Args is a dataclass and args is a dict, construct it; else pass through."""
if dataclasses.is_dataclass(args):
if is_dataclass(args):
# Already parsed
return args # type: ignore[return-value]
return parse_raw(args, self.Args)
Expand Down
63 changes: 63 additions & 0 deletions marimo/_ai/_tools/tools/rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import marimo._utils.requests as requests
from marimo import _loggers
from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import EmptyArgs, SuccessResult

LOGGER = _loggers.marimo_logger()

# We load the rules remotely, so we can update these without requiring a new release.
# If requested, we can bundle this into the library instead.
MARIMO_RULES_URL = "https://docs.marimo.io/CLAUDE.md"


@dataclass
class GetMarimoRulesOutput(SuccessResult):
rules_content: Optional[str] = None
source_url: str = MARIMO_RULES_URL


class GetMarimoRules(ToolBase[EmptyArgs, GetMarimoRulesOutput]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really good! Maybe we should add this to the next_steps= suggestion for some of the other tools?
My suggestions would be:

  • GetNotebookErrors() if there are any notebook specific or marimo specific errors
  • GetLightweightCellMap() and GetActiveNotebooks() because these are entry points.
  • GetDatabaseTables() because the rules contain some SQL-specific guidelines

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can do this in a followup. it hard to tell why an AI would ask for rules so im not sure if the next steps would veer it off track

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sounds good. I'll keep this incase we decide to update it later. We can see first if the Chat LLMs either don't pull rules initially or pull it early on in the conversation then forget it.

"""Get the official marimo rules and guidelines for AI assistants.
Returns:
The content of the rules file.
"""

def handle(self, args: EmptyArgs) -> GetMarimoRulesOutput:
del args

try:
response = requests.get(MARIMO_RULES_URL, timeout=10)
response.raise_for_status()

return GetMarimoRulesOutput(
rules_content=response.text(),
source_url=MARIMO_RULES_URL,
next_steps=[
"Follow the guidelines in the rules when working with marimo notebooks",
],
)

except Exception as e:
LOGGER.warning(
"Failed to fetch marimo rules from %s: %s",
MARIMO_RULES_URL,
str(e),
)

return GetMarimoRulesOutput(
status="error",
message=f"Failed to fetch marimo rules: {str(e)}",
source_url=MARIMO_RULES_URL,
next_steps=[
"Check internet connectivity",
"Verify the rules URL is accessible",
"Try again later if the service is temporarily unavailable",
],
)
2 changes: 2 additions & 0 deletions marimo/_ai/_tools/tools_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.errors import GetNotebookErrors
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
from marimo._ai._tools.tools.rules import GetMarimoRules
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables

SUPPORTED_BACKEND_AND_MCP_TOOLS: list[type[ToolBase[Any, Any]]] = [
GetMarimoRules,
GetActiveNotebooks,
GetCellRuntimeData,
GetLightweightCellMap,
Expand Down
11 changes: 7 additions & 4 deletions marimo/_mcp/server/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING

from marimo._loggers import marimo_logger
from marimo._mcp.server.main import setup_mcp_server

LOGGER = marimo_logger()

Expand All @@ -17,11 +16,15 @@ async def mcp_server_lifespan(app: "Starlette") -> AsyncIterator[None]:
"""Lifespan for MCP server functionality (exposing marimo as MCP server)."""

try:
session_manager = setup_mcp_server(app)
mcp_app = app.state.mcp
if mcp_app is None:
LOGGER.warning("MCP server not found in app state")
yield
return

async with session_manager.run():
# Session manager owns request lifecycle during app run
async with mcp_app.session_manager.run():
LOGGER.info("MCP server session manager started")
# Session manager owns request lifecycle during app run
yield

except ImportError as e:
Expand Down
34 changes: 25 additions & 9 deletions marimo/_mcp/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from typing import TYPE_CHECKING

from mcp.server.fastmcp import FastMCP

from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools_registry import SUPPORTED_BACKEND_AND_MCP_TOOLS
from marimo._loggers import marimo_logger
Expand All @@ -18,11 +16,10 @@


if TYPE_CHECKING:
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette


def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
def setup_mcp_server(app: "Starlette") -> None:
"""Create and configure MCP server for marimo integration.

Args:
Expand All @@ -33,17 +30,20 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
Returns:
StreamableHTTPSessionManager: MCP session manager
"""
from mcp.server.fastmcp import FastMCP
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.routing import Mount
from starlette.types import Receive, Scope, Send

mcp = FastMCP(
"marimo-mcp-server",
stateless_http=True,
log_level="WARNING",
# Change base path from /mcp to /server
streamable_http_path="/server",
)

# Change base path from /mcp to /server
mcp.settings.streamable_http_path = "/server"

# Register all tools
context = ToolContext(app=app)
for tool in SUPPORTED_BACKEND_AND_MCP_TOOLS:
Expand All @@ -53,7 +53,23 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
# Initialize streamable HTTP app
mcp_app = mcp.streamable_http_app()

# Middleware to require edit scope
class RequiresEditMiddleware(BaseHTTPMiddleware):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RequiresEditMiddleware() doesn't seem to be mcp specific at all. Maybe it would be better to move this logic to marimo/_server/api/middleware.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is actually just for MCP. we already support auth on other endpoints with @requires

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
auth = scope.get("auth")
if auth is None or "edit" not in auth.scopes:
response = JSONResponse(
{"detail": "Forbidden"},
status_code=403,
)
return await response(scope, receive, send)

return await self.app(scope, receive, send)

mcp_app.add_middleware(RequiresEditMiddleware)

# Add to the top of the routes to avoid conflicts with other routes
app.routes.insert(0, Mount("/mcp", mcp_app))

return mcp.session_manager
app.state.mcp = mcp
10 changes: 7 additions & 3 deletions marimo/_server/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,18 @@ def print_mcp_server(mcp_url: str, server_token: str | None) -> None:
"""Print MCP server configuration when MCP is enabled."""
print_()
print_tabbed(
f"{_utf8('🔗')} {green('Experimental MCP Server Configuration', bold=True)}"
f"{_utf8('🔗')} {green('Experimental MCP server configuration', bold=True)}"
)
print_tabbed(
f"{_utf8('➜')} {green('MCP Server URL')}: {_colorized_url(mcp_url)}"
f"{_utf8('➜')} {green('MCP server URL')}: {_colorized_url(mcp_url)}"
)
# Add to Claude code
print_tabbed(
f"{_utf8('➜')} {green('Add to Claude Code')}: claude mcp add --transport http marimo {mcp_url}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also include 'add to Cursor'? Since the most likely use cases for this are Claude Code and Cursor. It can be something like:

print_tabbed(
    f"{_utf8('➜')} {green('Add to Cursor')}: Add the following to ~/.cursor/mcp.json:\n"
    f"""  {{
    "mcpServers": {{
      "marimo": {{
        "url": "{mcp_url}"
      }}
    }}
  }}"""
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that seems a bit verbose for the CLI. we could look into making the CLI more interface in the future (e.g. press i for more info)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that makes sense

)
if server_token is not None:
print_tabbed(
f"{_utf8('➜')} {green('Add Header')}: Marimo-Server-Token: {muted(server_token)}"
f"{_utf8('➜')} {green('Add header')}: Marimo-Server-Token: {muted(server_token)}"
)
print_()

Expand Down
18 changes: 17 additions & 1 deletion marimo/_server/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from marimo._cli.print import echo
from marimo._config.manager import get_default_config_manager
from marimo._config.settings import GLOBAL_SETTINGS
from marimo._mcp.server.main import setup_mcp_server
from marimo._messaging.ops import StartupLogs
from marimo._runtime.requests import SerializedCLIArgs
from marimo._server.file_router import AppFileRouter
Expand Down Expand Up @@ -174,6 +175,16 @@ def start(
Start the server.
"""

# Defaults when mcp is enabled
if mcp:
# Turn on watch mode
watch = True
# Turn off skew protection for MCP server
# since it is more convenient to connect to.
# Skew protection is not a security thing, but rather
# prevents connecting to old servers.
skew_protection = False

# Find a free port if none is specified
# if the user specifies a port, we don't try to find a free one
port = port or find_free_port(DEFAULT_PORT, addr=host)
Expand Down Expand Up @@ -240,7 +251,9 @@ def start(
*LIFESPAN_REGISTRY.get_all(),
]

if mcp and mode == SessionMode.EDIT:
mcp_enabled = mcp and mode == SessionMode.EDIT

if mcp_enabled:
from marimo._mcp.server.lifespan import mcp_server_lifespan

lifespans_list.append(mcp_server_lifespan)
Expand All @@ -260,6 +273,9 @@ def start(
timeout=timeout,
)

if mcp_enabled:
setup_mcp_server(app)

app.state.port = external_port
app.state.host = external_host

Expand Down
2 changes: 1 addition & 1 deletion marimo/_utils/file_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create(path: Path, callback: Callback) -> FileWatcher:
LOGGER.debug("Using watchdog file watcher")
return _create_watchdog(path, callback, asyncio.get_event_loop())
else:
LOGGER.warning(
LOGGER.info(
"watchdog is not installed, using polling file watcher"
)
return PollingFileWatcher(path, callback, asyncio.get_event_loop())
Expand Down
79 changes: 79 additions & 0 deletions tests/_ai/tools/tools/test_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2025 Marimo. All rights reserved.

from __future__ import annotations

from unittest.mock import Mock, patch

import pytest

from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools.rules import GetMarimoRules
from marimo._ai._tools.types import EmptyArgs


@pytest.fixture
def tool() -> GetMarimoRules:
"""Create a GetMarimoRules tool instance."""
return GetMarimoRules(ToolContext())


def test_get_rules_success(tool: GetMarimoRules) -> None:
"""Test successfully fetching marimo rules."""
mock_response = Mock()
mock_response.text.return_value = "# Marimo Rules\n\nTest content"
mock_response.raise_for_status = Mock()

with patch("marimo._utils.requests.get", return_value=mock_response):
result = tool.handle(EmptyArgs())

assert result.status == "success"
assert result.rules_content == "# Marimo Rules\n\nTest content"
assert result.source_url == "https://docs.marimo.io/CLAUDE.md"
assert len(result.next_steps) == 1
assert "Follow the guidelines" in result.next_steps[0]
mock_response.raise_for_status.assert_called_once()


def test_get_rules_http_error(tool: GetMarimoRules) -> None:
"""Test handling HTTP errors when fetching rules."""
mock_response = Mock()
mock_response.raise_for_status.side_effect = Exception("404 Not Found")

with patch("marimo._utils.requests.get", return_value=mock_response):
result = tool.handle(EmptyArgs())

assert result.status == "error"
assert result.rules_content is None
assert "Failed to fetch marimo rules" in result.message
assert "404 Not Found" in result.message
assert result.source_url == "https://docs.marimo.io/CLAUDE.md"
assert len(result.next_steps) == 3
assert "Check internet connectivity" in result.next_steps[0]


def test_get_rules_network_error(tool: GetMarimoRules) -> None:
"""Test handling network errors when fetching rules."""
with patch(
"marimo._utils.requests.get",
side_effect=Exception("Connection refused"),
):
result = tool.handle(EmptyArgs())

assert result.status == "error"
assert result.rules_content is None
assert "Failed to fetch marimo rules" in result.message
assert "Connection refused" in result.message
assert len(result.next_steps) == 3


def test_get_rules_timeout(tool: GetMarimoRules) -> None:
"""Test handling timeout when fetching rules."""
with patch(
"marimo._utils.requests.get",
side_effect=Exception("Request timeout"),
):
result = tool.handle(EmptyArgs())

assert result.status == "error"
assert result.rules_content is None
assert "Request timeout" in result.message
11 changes: 0 additions & 11 deletions tests/_mcp/server/test_main.py

This file was deleted.

Loading
Loading