diff --git a/docs/_static/CLAUDE.md b/docs/_static/CLAUDE.md
index c3108e171c1..85b078735ed 100644
--- a/docs/_static/CLAUDE.md
+++ b/docs/_static/CLAUDE.md
@@ -2,15 +2,18 @@
I am a specialized AI assistant designed to help create data science notebooks using marimo. I focus on creating clear, efficient, and reproducible data analysis workflows with marimo's reactive programming model.
-
-- I specialize in data science and analytics using marimo notebooks
-- I provide complete, runnable code that follows best practices
-- I emphasize reproducibility and clear documentation
-- I focus on creating interactive data visualizations and analysis
-- I understand marimo's reactive programming model
-
+If you make edits to the notebook, only edit the contents inside the function decorated with `@app.cell`.
+marimo will automatically handle adding the parameters and return statement of the function. For example,
+for each edit, just return:
-## Marimo Fundamentals
+```
+@app.cell
+def _():
+
+ return
+```
+
+## Marimo fundamentals
Marimo is a reactive notebook that differs from traditional notebooks in key ways:
@@ -25,12 +28,13 @@ Marimo is a reactive notebook that differs from traditional notebooks in key way
1. All code must be complete and runnable
2. Follow consistent coding style throughout
3. Include descriptive variable names and helpful comments
-4. Import all modules in the first cell, always including \`import marimo as mo\`
+4. Import all modules in the first cell, always including `import marimo as mo`
5. Never redeclare variables across cells
6. Ensure no cycles in notebook dependency graph
7. The last expression in a cell is automatically displayed, just like in Jupyter notebooks.
8. Don't include comments in markdown cells
9. Don't include comments in SQL cells
+10. Never define anything using `global`.
## Reactivity
@@ -38,12 +42,14 @@ Marimo's reactivity means:
- When a variable changes, all cells that use that variable automatically re-execute
- UI elements trigger updates when their values change without explicit callbacks
-- UI element values are accessed through \`.value\` attribute
+- UI element values are accessed through `.value` attribute
- You cannot access a UI element's value in the same cell where it's defined
+- Variables prefixed with an underscore (e.g. `_my_var`) are local to the cell and cannot be accessed by other cells
## Best Practices
+
- Use polars for data manipulation
- Implement proper data validation
- Handle missing values appropriately
@@ -60,6 +66,7 @@ Marimo's reactivity means:
+
- Access UI element values with .value attribute (e.g., slider.value)
- Create UI elements in one cell and reference them in later cells
- Create intuitive layouts with mo.hstack(), mo.vstack(), and mo.tabs()
@@ -67,18 +74,10 @@ Marimo's reactivity means:
- Group related UI elements for better organization
-
-- Prefer GitHub-hosted datasets (e.g., raw.githubusercontent.com)
-- Use CORS proxy for external URLs:
-- Implement proper error handling for data loading
-- Consider using \`vega_datasets\` for common example datasets
-
-
- When writing duckdb, prefer using marimo's SQL cells, which start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines.
- See the SQL with duckdb example for an example on how to do this
- Don't add comments in cells that use mo.sql()
-- Consider using \`vega_datasets\` for common example datasets
## Troubleshooting
@@ -89,238 +88,280 @@ Common issues and solutions:
- UI element value access: Move access to a separate cell from definition
- Visualization not showing: Ensure the visualization object is the last expression
-After generating a notebook, run \`marimo check --fix\` to catch and
+After generating a notebook, run `marimo check --fix` to catch and
automatically resolve common formatting issues, and detect common pitfalls.
## Available UI elements
-- \`mo.ui.altair_chart(altair_chart)\`
-- \`mo.ui.button(value=None, kind='primary')\`
-- \`mo.ui.run_button(label=None, tooltip=None, kind='primary')\`
-- \`mo.ui.checkbox(label='', value=False)\`
-- \`mo.ui.date(value=None, label=None, full_width=False)\`
-- \`mo.ui.dropdown(options, value=None, label=None, full_width=False)\`
-- \`mo.ui.file(label='', multiple=False, full_width=False)\`
-- \`mo.ui.number(value=None, label=None, full_width=False)\`
-- \`mo.ui.radio(options, value=None, label=None, full_width=False)\`
-- \`mo.ui.refresh(options: List[str], default_interval: str)\`
-- \`mo.ui.slider(start, stop, value=None, label=None, full_width=False, step=None)\`
-- \`mo.ui.range_slider(start, stop, value=None, label=None, full_width=False, step=None)\`
-- \`mo.ui.table(data, columns=None, on_select=None, sortable=True, filterable=True)\`
-- \`mo.ui.text(value='', label=None, full_width=False)\`
-- \`mo.ui.text_area(value='', label=None, full_width=False)\`
-- \`mo.ui.data_explorer(df)\`
-- \`mo.ui.dataframe(df)\`
-- \`mo.ui.plotly(plotly_figure)\`
-- \`mo.ui.tabs(elements: dict[str, mo.ui.Element])\`
-- \`mo.ui.array(elements: list[mo.ui.Element])\`
-- \`mo.ui.form(element: mo.ui.Element, label='', bordered=True)\`
+- `mo.ui.altair_chart(altair_chart)`
+- `mo.ui.button(value=None, kind='primary')`
+- `mo.ui.run_button(label=None, tooltip=None, kind='primary')`
+- `mo.ui.checkbox(label='', value=False)`
+- `mo.ui.date(value=None, label=None, full_width=False)`
+- `mo.ui.dropdown(options, value=None, label=None, full_width=False)`
+- `mo.ui.file(label='', multiple=False, full_width=False)`
+- `mo.ui.number(value=None, label=None, full_width=False)`
+- `mo.ui.radio(options, value=None, label=None, full_width=False)`
+- `mo.ui.refresh(options: List[str], default_interval: str)`
+- `mo.ui.slider(start, stop, value=None, label=None, full_width=False, step=None)`
+- `mo.ui.range_slider(start, stop, value=None, label=None, full_width=False, step=None)`
+- `mo.ui.table(data, columns=None, on_select=None, sortable=True, filterable=True)`
+- `mo.ui.text(value='', label=None, full_width=False)`
+- `mo.ui.text_area(value='', label=None, full_width=False)`
+- `mo.ui.data_explorer(df)`
+- `mo.ui.dataframe(df)`
+- `mo.ui.plotly(plotly_figure)`
+- `mo.ui.tabs(elements: dict[str, mo.ui.Element])`
+- `mo.ui.array(elements: list[mo.ui.Element])`
+- `mo.ui.form(element: mo.ui.Element, label='', bordered=True)`
## Layout and utility functions
-- \`mo.md(text)\` - display markdown
-- \`mo.stop(predicate, output=None)\` - stop execution conditionally
-- \`mo.Html(html)\` - display HTML
-- \`mo.image(image)\` - display an image
-- \`mo.hstack(elements)\` - stack elements horizontally
-- \`mo.vstack(elements)\` - stack elements vertically
-- \`mo.tabs(elements)\` - create a tabbed interface
+- `mo.md(text)` - display markdown
+- `mo.stop(predicate, output=None)` - stop execution conditionally
+- `mo.output.append(value)` - append to the output when it is not the last expression
+- `mo.output.replace(value)` - replace the output when it is not the last expression
+- `mo.Html(html)` - display HTML
+- `mo.image(image)` - display an image
+- `mo.hstack(elements)` - stack elements horizontally
+- `mo.vstack(elements)` - stack elements vertically
+- `mo.tabs(elements)` - create a tabbed interface
## Examples
-
-# Cell 1
-import marimo as mo
-import altair as alt
-import polars as pl
-import numpy as np
-
-# Cell 2
-
-# Create a slider and display it
-
-n_points = mo.ui.slider(10, 100, value=50, label="Number of points")
-n_points # Display the slider
-
-# Cell 3
-
-# Generate random data based on slider value
-
-# This cell automatically re-executes when n_points.value changes
-
-x = np.random.rand(n_points.value)
-y = np.random.rand(n_points.value)
+
+```
+@app.cell
+def _():
+ mo.md("""
+ # Hello world
+ This is a _markdown_ **cell**.
+ """)
+ return
+```
+
-df = pl.DataFrame({"x": x, "y": y})
+
+```
+@app.cell
+def _():
+ import marimo as mo
+ import altair as alt
+ import polars as pl
+ import numpy as np
+ return
+
+@app.cell
+def _():
+ n_points = mo.ui.slider(10, 100, value=50, label="Number of points")
+ n_points
+ return
+
+@app.cell
+def _():
+ x = np.random.rand(n_points.value)
+ y = np.random.rand(n_points.value)
+
+ df = pl.DataFrame({"x": x, "y": y})
+
+ chart = alt.Chart(df).mark_circle(opacity=0.7).encode(
+ x=alt.X('x', title='X axis'),
+ y=alt.Y('y', title='Y axis')
+ ).properties(
+ title=f"Scatter plot with {n_points.value} points",
+ width=400,
+ height=300
+ )
-chart = alt.Chart(df).mark_circle(opacity=0.7).encode(
- x=alt.X('x', title='X axis'),
- y=alt.Y('y', title='Y axis')
-).properties(
- title=f"Scatter plot with {n_points.value} points",
- width=400,
- height=300
-)
+ chart
+ return
-chart
+```
-# Cell 1
-import marimo as mo
-import polars as pl
-from vega_datasets import data
-
-# Cell 2
-
-# Load and display dataset with interactive explorer
-
-cars_df = pl.DataFrame(data.cars())
-mo.ui.data_explorer(cars_df)
+```
+
+@app.cell
+def _():
+ import marimo as mo
+ import polars as pl
+ from vega_datasets import data
+ return
+
+@app.cell
+def _():
+ cars_df = pl.DataFrame(data.cars())
+ mo.ui.data_explorer(cars_df)
+ return
+
+```
-# Cell 1
-import marimo as mo
-import polars as pl
-import altair as alt
-
-# Cell 2
-
-# Load dataset
-
-iris = pl.read_csv("hf://datasets/scikit-learn/iris/Iris.csv")
-
-# Cell 3
-
-# Create UI elements
-
-species_selector = mo.ui.dropdown(
- options=["All"] + iris["Species"].unique().to_list(),
- value="All",
- label="Species",
-)
-x_feature = mo.ui.dropdown(
- options=iris.select(pl.col(pl.Float64, pl.Int64)).columns,
- value="SepalLengthCm",
- label="X Feature",
-)
-y_feature = mo.ui.dropdown(
- options=iris.select(pl.col(pl.Float64, pl.Int64)).columns,
- value="SepalWidthCm",
- label="Y Feature",
-)
-
-# Display UI elements in a horizontal stack
-
-mo.hstack([species_selector, x_feature, y_feature])
+```
+
+@app.cell
+def _():
+ import marimo as mo
+ import polars as pl
+ import altair as alt
+ return
+
+@app.cell
+def _():
+ iris = pl.read_csv("hf://datasets/scikit-learn/iris/Iris.csv")
+ return
+
+@app.cell
+def _():
+ species_selector = mo.ui.dropdown(
+ options=["All"] + iris["Species"].unique().to_list(),
+ value="All",
+ label="Species",
+ )
+ x_feature = mo.ui.dropdown(
+ options=iris.select(pl.col(pl.Float64, pl.Int64)).columns,
+ value="SepalLengthCm",
+ label="X Feature",
+ )
+ y_feature = mo.ui.dropdown(
+ options=iris.select(pl.col(pl.Float64, pl.Int64)).columns,
+ value="SepalWidthCm",
+ label="Y Feature",
+ )
+ mo.hstack([species_selector, x_feature, y_feature])
+ return
+
+@app.cell
+def _():
+ filtered_data = iris if species_selector.value == "All" else iris.filter(pl.col("Species") == species_selector.value)
+
+ chart = alt.Chart(filtered_data).mark_circle().encode(
+ x=alt.X(x_feature.value, title=x_feature.value),
+ y=alt.Y(y_feature.value, title=y_feature.value),
+ color='Species'
+ ).properties(
+ title=f"{y_feature.value} vs {x_feature.value}",
+ width=500,
+ height=400
+ )
-# Cell 4
+ chart
+ return
-# Filter data based on selection
+```
+
-filtered_data = iris if species_selector.value == "All" else iris.filter(pl.col("Species") == species_selector.value)
+
+```
-# Create visualization based on UI selections
+@app.cell
+def _():
+ mo.stop(not data.value, mo.md("No data to display"))
-chart = alt.Chart(filtered_data).mark_circle().encode(
- x=alt.X(x_feature.value, title=x_feature.value),
- y=alt.Y(y_feature.value, title=y_feature.value),
- color='Species'
-).properties(
- title=f"{y_feature.value} vs {x_feature.value}",
- width=500,
- height=400
-)
+ if mode.value == "scatter":
+ mo.output.replace(render_scatter(data.value))
+ else:
+ mo.output.replace(render_bar_chart(data.value))
+ return
-chart
+```
-# Cell 1
-import marimo as mo
-import altair as alt
-import polars as pl
-
-# Cell 2
-
-# Load dataset
-
-weather = pl.read_csv("https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/weather.csv")
-weather_dates = weather.with_columns(
- pl.col("date").str.strptime(pl.Date, format="%Y-%m-%d")
-)
-_chart = (
- alt.Chart(weather_dates)
- .mark_point()
- .encode(
- x="date:T",
- y="temp_max",
- color="location",
+```
+
+@app.cell
+def _():
+ import marimo as mo
+ import altair as alt
+ import polars as pl
+ return
+
+@app.cell
+def _():
+ # Load dataset
+ weather = pl.read_csv("https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/weather.csv")
+ weather_dates = weather.with_columns(
+ pl.col("date").str.strptime(pl.Date, format="%Y-%m-%d")
+ )
+ base_chart = (
+ alt.Chart(weather_dates)
+ .mark_point()
+ .encode(
+ x="date:T",
+ y="temp_max",
+ color="location",
+ )
)
-)
+ return
-chart = mo.ui.altair_chart(_chart)
+@app.cell
+def _():
+ chart = mo.ui.altair_chart(base_chart)
-chart
+ chart
+ return
-# Cell 3
+@app.cell
+def _():
+ # Display the selection
+ chart.value
+ return
-# Display the selection
-
-chart.value
+```
-# Cell 1
-import marimo as mo
-
-# Cell 2
-
-first_button = mo.ui.run_button(label="Option 1")
-second_button = mo.ui.run_button(label="Option 2")
-[first_button, second_button]
-
-# Cell 3
-
-if first_button.value:
- print("You chose option 1!")
-elif second_button.value:
- print("You chose option 2!")
-else:
- print("Click a button!")
+```
+
+@app.cell
+def _():
+ import marimo as mo
+ return
+
+@app.cell
+def _():
+ first_button = mo.ui.run_button(label="Option 1")
+ second_button = mo.ui.run_button(label="Option 2")
+ [first_button, second_button]
+ return
+
+@app.cell
+def _():
+ if first_button.value:
+ print("You chose option 1!")
+ elif second_button.value:
+ print("You chose option 2!")
+ else:
+ print("Click a button!")
+ return
+
+```
-# Cell 1
-import marimo as mo
-import polars as pl
-
-# Cell 2
-
-# Load dataset
-
-weather = pl.read_csv("https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/weather.csv")
-
-# Cell 3
-
-seattle_weather_df = mo.sql(
- f"""
- SELECT * FROM weather WHERE location = 'Seattle';
- """
-)
-
-
-
-# Cell 1
-import marimo as mo
-
-# Cell 2
-
-mo.md(
- r"""
-The quadratic function $f$ is defined as
+```
+
+@app.cell
+def _():
+ import marimo as mo
+ import polars as pl
+ return
+
+@app.cell
+def _():
+ weather = pl.read_csv('https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/weather.csv')
+ return
+
+@app.cell
+def _():
+ seattle_weather_df = mo.sql(
+ f"""
+ SELECT * FROM weather WHERE location = 'Seattle';
+ """
+ )
+ return
-$$f(x) = x^2.$$
-"""
-)
+```
diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py
index d6a9ffc2f32..0fa688e0b17 100644
--- a/marimo/_ai/_tools/base.py
+++ b/marimo/_ai/_tools/base.py
@@ -1,7 +1,6 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations
-import dataclasses
import inspect
import re
from abc import ABC, abstractmethod
@@ -237,7 +236,7 @@ def as_backend_tool(
# helpers
def _coerce_args(self, args: Any) -> ArgsT: # type: ignore[override]
"""If Args is a dataclass and args is a dict, construct it; else pass through."""
- if dataclasses.is_dataclass(args):
+ if is_dataclass(args):
# Already parsed
return args # type: ignore[return-value]
return parse_raw(args, self.Args)
diff --git a/marimo/_ai/_tools/tools/rules.py b/marimo/_ai/_tools/tools/rules.py
new file mode 100644
index 00000000000..62ea2eccbba
--- /dev/null
+++ b/marimo/_ai/_tools/tools/rules.py
@@ -0,0 +1,63 @@
+# Copyright 2025 Marimo. All rights reserved.
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional
+
+import marimo._utils.requests as requests
+from marimo import _loggers
+from marimo._ai._tools.base import ToolBase
+from marimo._ai._tools.types import EmptyArgs, SuccessResult
+
+LOGGER = _loggers.marimo_logger()
+
+# We load the rules remotely, so we can update these without requiring a new release.
+# If requested, we can bundle this into the library instead.
+MARIMO_RULES_URL = "https://docs.marimo.io/CLAUDE.md"
+
+
+@dataclass
+class GetMarimoRulesOutput(SuccessResult):
+ """Result returned by the GetMarimoRules tool."""
+
+ # Text of the fetched rules document; left as None when the fetch fails.
+ rules_content: Optional[str] = None
+ # URL the rules were requested from (defaults to MARIMO_RULES_URL).
+ source_url: str = MARIMO_RULES_URL
+
+
+class GetMarimoRules(ToolBase[EmptyArgs, GetMarimoRulesOutput]):
+    """Get the official marimo rules and guidelines for AI assistants.
+
+    Returns:
+        The content of the rules file.
+    """
+
+    def handle(self, args: EmptyArgs) -> GetMarimoRulesOutput:
+        # The tool takes no input, so the placeholder argument is discarded.
+        del args
+
+        try:
+            # Any exception raised by the request, by raise_for_status(), or
+            # while reading the body is converted into an error result below.
+            reply = requests.get(MARIMO_RULES_URL, timeout=10)
+            reply.raise_for_status()
+            outcome = GetMarimoRulesOutput(
+                rules_content=reply.text(),
+                source_url=MARIMO_RULES_URL,
+                next_steps=[
+                    "Follow the guidelines in the rules when working with marimo notebooks",
+                ],
+            )
+        except Exception as err:
+            # Failures are logged and reported to the caller as an error
+            # result instead of being re-raised.
+            reason = str(err)
+            LOGGER.warning(
+                "Failed to fetch marimo rules from %s: %s",
+                MARIMO_RULES_URL,
+                reason,
+            )
+            outcome = GetMarimoRulesOutput(
+                status="error",
+                message=f"Failed to fetch marimo rules: {reason}",
+                source_url=MARIMO_RULES_URL,
+                next_steps=[
+                    "Check internet connectivity",
+                    "Verify the rules URL is accessible",
+                    "Try again later if the service is temporarily unavailable",
+                ],
+            )
+
+        return outcome
diff --git a/marimo/_ai/_tools/tools_registry.py b/marimo/_ai/_tools/tools_registry.py
index 8d6729a05d2..3a6000c39af 100644
--- a/marimo/_ai/_tools/tools_registry.py
+++ b/marimo/_ai/_tools/tools_registry.py
@@ -9,9 +9,11 @@
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.errors import GetNotebookErrors
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
+from marimo._ai._tools.tools.rules import GetMarimoRules
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables
SUPPORTED_BACKEND_AND_MCP_TOOLS: list[type[ToolBase[Any, Any]]] = [
+ GetMarimoRules,
GetActiveNotebooks,
GetCellRuntimeData,
GetLightweightCellMap,
diff --git a/marimo/_mcp/server/lifespan.py b/marimo/_mcp/server/lifespan.py
index d81e36f7c05..804ff35fabb 100644
--- a/marimo/_mcp/server/lifespan.py
+++ b/marimo/_mcp/server/lifespan.py
@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING
from marimo._loggers import marimo_logger
-from marimo._mcp.server.main import setup_mcp_server
LOGGER = marimo_logger()
@@ -17,11 +16,15 @@ async def mcp_server_lifespan(app: "Starlette") -> AsyncIterator[None]:
"""Lifespan for MCP server functionality (exposing marimo as MCP server)."""
try:
- session_manager = setup_mcp_server(app)
+ mcp_app = app.state.mcp
+ if mcp_app is None:
+ LOGGER.warning("MCP server not found in app state")
+ yield
+ return
- async with session_manager.run():
+ # Session manager owns request lifecycle during app run
+ async with mcp_app.session_manager.run():
LOGGER.info("MCP server session manager started")
- # Session manager owns request lifecycle during app run
yield
except ImportError as e:
diff --git a/marimo/_mcp/server/main.py b/marimo/_mcp/server/main.py
index d27166f9422..0538ec2293a 100644
--- a/marimo/_mcp/server/main.py
+++ b/marimo/_mcp/server/main.py
@@ -8,8 +8,6 @@
from typing import TYPE_CHECKING
-from mcp.server.fastmcp import FastMCP
-
from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools_registry import SUPPORTED_BACKEND_AND_MCP_TOOLS
from marimo._loggers import marimo_logger
@@ -18,11 +16,10 @@
if TYPE_CHECKING:
- from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
-def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
+def setup_mcp_server(app: "Starlette") -> None:
"""Create and configure MCP server for marimo integration.
Args:
@@ -33,17 +30,20 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
Returns:
- StreamableHTTPSessionManager: MCP session manager
+ None. The configured FastMCP server instance is stored on `app.state.mcp`.
"""
+ from mcp.server.fastmcp import FastMCP
+ from starlette.middleware.base import BaseHTTPMiddleware
+ from starlette.responses import JSONResponse
from starlette.routing import Mount
+ from starlette.types import Receive, Scope, Send
mcp = FastMCP(
"marimo-mcp-server",
stateless_http=True,
log_level="WARNING",
+ # Change base path from /mcp to /server
+ streamable_http_path="/server",
)
- # Change base path from /mcp to /server
- mcp.settings.streamable_http_path = "/server"
-
# Register all tools
context = ToolContext(app=app)
for tool in SUPPORTED_BACKEND_AND_MCP_TOOLS:
@@ -53,7 +53,23 @@ def setup_mcp_server(app: "Starlette") -> "StreamableHTTPSessionManager":
# Initialize streamable HTTP app
mcp_app = mcp.streamable_http_app()
+ # Middleware to require edit scope
+ class RequiresEditMiddleware(BaseHTTPMiddleware):
+ async def __call__(
+ self, scope: Scope, receive: Receive, send: Send
+ ) -> None:
+ auth = scope.get("auth")
+ if auth is None or "edit" not in auth.scopes:
+ response = JSONResponse(
+ {"detail": "Forbidden"},
+ status_code=403,
+ )
+ return await response(scope, receive, send)
+
+ return await self.app(scope, receive, send)
+
+ mcp_app.add_middleware(RequiresEditMiddleware)
+
# Add to the top of the routes to avoid conflicts with other routes
app.routes.insert(0, Mount("/mcp", mcp_app))
-
- return mcp.session_manager
+ app.state.mcp = mcp
diff --git a/marimo/_server/print.py b/marimo/_server/print.py
index 5c29264741f..6cce28cb398 100644
--- a/marimo/_server/print.py
+++ b/marimo/_server/print.py
@@ -163,14 +163,18 @@ def print_mcp_server(mcp_url: str, server_token: str | None) -> None:
"""Print MCP server configuration when MCP is enabled."""
print_()
print_tabbed(
- f"{_utf8('🔗')} {green('Experimental MCP Server Configuration', bold=True)}"
+ f"{_utf8('🔗')} {green('Experimental MCP server configuration', bold=True)}"
)
print_tabbed(
- f"{_utf8('➜')} {green('MCP Server URL')}: {_colorized_url(mcp_url)}"
+ f"{_utf8('➜')} {green('MCP server URL')}: {_colorized_url(mcp_url)}"
+ )
+ # Add to Claude code
+ print_tabbed(
+ f"{_utf8('➜')} {green('Add to Claude Code')}: claude mcp add --transport http marimo {mcp_url}"
)
if server_token is not None:
print_tabbed(
- f"{_utf8('➜')} {green('Add Header')}: Marimo-Server-Token: {muted(server_token)}"
+ f"{_utf8('➜')} {green('Add header')}: Marimo-Server-Token: {muted(server_token)}"
)
print_()
diff --git a/marimo/_server/start.py b/marimo/_server/start.py
index 6bd169b3611..8631ff4e249 100644
--- a/marimo/_server/start.py
+++ b/marimo/_server/start.py
@@ -13,6 +13,7 @@
from marimo._cli.print import echo
from marimo._config.manager import get_default_config_manager
from marimo._config.settings import GLOBAL_SETTINGS
+from marimo._mcp.server.main import setup_mcp_server
from marimo._messaging.ops import StartupLogs
from marimo._runtime.requests import SerializedCLIArgs
from marimo._server.file_router import AppFileRouter
@@ -174,6 +175,16 @@ def start(
Start the server.
"""
+ # Defaults when mcp is enabled
+ if mcp:
+ # Turn on watch mode
+ watch = True
+ # Turn off skew protection for MCP server
+ # since it is more convenient to connect to.
+ # Skew protection is not a security thing, but rather
+ # prevents connecting to old servers.
+ skew_protection = False
+
# Find a free port if none is specified
# if the user specifies a port, we don't try to find a free one
port = port or find_free_port(DEFAULT_PORT, addr=host)
@@ -240,7 +251,9 @@ def start(
*LIFESPAN_REGISTRY.get_all(),
]
- if mcp and mode == SessionMode.EDIT:
+ mcp_enabled = mcp and mode == SessionMode.EDIT
+
+ if mcp_enabled:
from marimo._mcp.server.lifespan import mcp_server_lifespan
lifespans_list.append(mcp_server_lifespan)
@@ -260,6 +273,9 @@ def start(
timeout=timeout,
)
+ if mcp_enabled:
+ setup_mcp_server(app)
+
app.state.port = external_port
app.state.host = external_host
diff --git a/marimo/_utils/file_watcher.py b/marimo/_utils/file_watcher.py
index 3d0f809058a..7c6a98f07df 100644
--- a/marimo/_utils/file_watcher.py
+++ b/marimo/_utils/file_watcher.py
@@ -24,7 +24,7 @@ def create(path: Path, callback: Callback) -> FileWatcher:
LOGGER.debug("Using watchdog file watcher")
return _create_watchdog(path, callback, asyncio.get_event_loop())
else:
- LOGGER.warning(
+ LOGGER.info(
"watchdog is not installed, using polling file watcher"
)
return PollingFileWatcher(path, callback, asyncio.get_event_loop())
diff --git a/tests/_ai/tools/tools/test_rules.py b/tests/_ai/tools/tools/test_rules.py
new file mode 100644
index 00000000000..d966987b29a
--- /dev/null
+++ b/tests/_ai/tools/tools/test_rules.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Marimo. All rights reserved.
+
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+from marimo._ai._tools.base import ToolContext
+from marimo._ai._tools.tools.rules import GetMarimoRules
+from marimo._ai._tools.types import EmptyArgs
+
+
+@pytest.fixture
+def tool() -> GetMarimoRules:
+    """Provide a GetMarimoRules tool built on a default ToolContext."""
+    context = ToolContext()
+    return GetMarimoRules(context)
+
+
+def test_get_rules_success(tool: GetMarimoRules) -> None:
+    """A successful fetch yields the rules text and a success status."""
+    rules_text = "# Marimo Rules\n\nTest content"
+
+    # Stand-in response whose status check passes and whose body is the text.
+    fake_response = Mock()
+    fake_response.raise_for_status = Mock()
+    fake_response.text.return_value = rules_text
+
+    with patch("marimo._utils.requests.get", return_value=fake_response):
+        outcome = tool.handle(EmptyArgs())
+
+    assert outcome.status == "success"
+    assert outcome.rules_content == rules_text
+    assert outcome.source_url == "https://docs.marimo.io/CLAUDE.md"
+    assert len(outcome.next_steps) == 1
+    assert "Follow the guidelines" in outcome.next_steps[0]
+    fake_response.raise_for_status.assert_called_once()
+
+
+def test_get_rules_http_error(tool: GetMarimoRules) -> None:
+    """An exception from raise_for_status() is reported as an error result."""
+    # Response whose status check raises, simulating an HTTP error.
+    failing_response = Mock()
+    failing_response.raise_for_status = Mock(
+        side_effect=Exception("404 Not Found")
+    )
+
+    with patch("marimo._utils.requests.get", return_value=failing_response):
+        outcome = tool.handle(EmptyArgs())
+
+    assert outcome.status == "error"
+    assert outcome.rules_content is None
+    assert "Failed to fetch marimo rules" in outcome.message
+    assert "404 Not Found" in outcome.message
+    assert outcome.source_url == "https://docs.marimo.io/CLAUDE.md"
+    assert len(outcome.next_steps) == 3
+    assert "Check internet connectivity" in outcome.next_steps[0]
+
+
+def test_get_rules_network_error(tool: GetMarimoRules) -> None:
+    """An exception raised by the request itself becomes an error result."""
+    refused = Exception("Connection refused")
+    failing_get = patch("marimo._utils.requests.get", side_effect=refused)
+
+    with failing_get:
+        outcome = tool.handle(EmptyArgs())
+
+    assert outcome.status == "error"
+    assert outcome.rules_content is None
+    assert "Failed to fetch marimo rules" in outcome.message
+    assert "Connection refused" in outcome.message
+    assert len(outcome.next_steps) == 3
+
+
+def test_get_rules_timeout(tool: GetMarimoRules) -> None:
+    """A timeout-style exception from the request becomes an error result."""
+    timed_out = Exception("Request timeout")
+    failing_get = patch("marimo._utils.requests.get", side_effect=timed_out)
+
+    with failing_get:
+        outcome = tool.handle(EmptyArgs())
+
+    assert outcome.status == "error"
+    assert outcome.rules_content is None
+    assert "Request timeout" in outcome.message
diff --git a/tests/_mcp/server/test_main.py b/tests/_mcp/server/test_main.py
deleted file mode 100644
index 31a95c9791a..00000000000
--- a/tests/_mcp/server/test_main.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright 2024 Marimo. All rights reserved.
-import pytest
-
-pytest.importorskip("mcp", reason="MCP requires Python 3.10+")
-
-# TODO: Currently researching best practices for how to test MCP Servers in memory
-# Need to investigate:
-# - How to create an in-memory MCP server instance for testing
-# - Best practices for mocking MCP client-server communication
-# - Testing MCP protocol compliance and tool execution
-# - Integration testing patterns for MCP servers
diff --git a/tests/_mcp/server/test_mcp_server.py b/tests/_mcp/server/test_mcp_server.py
new file mode 100644
index 00000000000..67b58fd9d52
--- /dev/null
+++ b/tests/_mcp/server/test_mcp_server.py
@@ -0,0 +1,101 @@
+# Copyright 2024 Marimo. All rights reserved.
+import pytest
+
+from marimo._mcp.server.lifespan import mcp_server_lifespan
+
+pytest.importorskip("mcp", reason="MCP requires Python 3.10+")
+
+from starlette.applications import Starlette
+from starlette.authentication import AuthCredentials, SimpleUser
+from starlette.middleware import Middleware
+from starlette.middleware.authentication import AuthenticationMiddleware
+from starlette.requests import HTTPConnection
+from starlette.testclient import TestClient
+
+from marimo._mcp.server.main import setup_mcp_server
+from marimo._server.api.middleware import AuthBackend
+from tests._server.mocks import get_mock_session_manager
+
+
+def create_test_app() -> Starlette:
+    """Build a Starlette app with auth middleware and the MCP server set up."""
+    # Authentication middleware configured with should_authenticate=False.
+    auth_layer = Middleware(
+        AuthenticationMiddleware,
+        backend=AuthBackend(should_authenticate=False),
+    )
+    test_app = Starlette(middleware=[auth_layer])
+
+    # The session manager is placed on app state before the MCP setup runs.
+    test_app.state.session_manager = get_mock_session_manager()
+    setup_mcp_server(test_app)
+    return test_app
+
+
+def test_mcp_server_starts_up():
+    """Test that MCP server can be set up and routes are registered.
+
+    No request is issued here; the test only inspects what
+    ``setup_mcp_server`` left on the application object.
+    """
+    app = create_test_app()
+
+    # setup_mcp_server stores the server instance on the app state
+    assert hasattr(app.state, "mcp")
+
+    # Verify a route containing /mcp was registered on the app
+    assert any("/mcp" in str(route.path) for route in app.routes)
+
+
+async def test_mcp_server_requires_edit_scope():
+ """Test that MCP server validates 'edit' scope is present."""
+ # NOTE(review): `app` is not referenced again in this test — confirm
+ # whether this call is still needed.
+ app = create_test_app()
+
+ # Mock a request without edit scope
+ class MockAuthBackend:
+ async def authenticate(self, conn: HTTPConnection):
+ del conn
+ # Return user without edit scope
+ return AuthCredentials(scopes=["read"]), SimpleUser("test_user")
+
+ # Create app with authentication that doesn't include edit scope
+ app_no_edit = Starlette(
+ middleware=[
+ Middleware(
+ AuthenticationMiddleware,
+ backend=MockAuthBackend(),
+ ),
+ ],
+ )
+ app_no_edit.state.session_manager = get_mock_session_manager()
+ setup_mcp_server(app_no_edit)
+
+ # Server-side exceptions are not re-raised by the client; only the
+ # response status code is inspected below.
+ client = TestClient(app_no_edit, raise_server_exceptions=False)
+
+ # Try to access MCP endpoint without edit scope
+ response = client.get("/mcp/server")
+ assert response.status_code == 403
+
+ # Mock a request with edit scope
+ class MockAuthBackendWithEdit:
+ async def authenticate(self, conn: HTTPConnection):
+ del conn
+ # Return user with edit scope
+ return AuthCredentials(scopes=["edit"]), SimpleUser("test_user")
+
+ # Create app with edit scope
+ app_with_edit = Starlette(
+ middleware=[
+ Middleware(
+ AuthenticationMiddleware,
+ backend=MockAuthBackendWithEdit(),
+ ),
+ ],
+ )
+
+ # NOTE(review): unlike the first app, the session manager here is assigned
+ # after setup_mcp_server() and after entering the lifespan — confirm the
+ # ordering is intentional.
+ setup_mcp_server(app_with_edit)
+ async with mcp_server_lifespan(app_with_edit):
+ app_with_edit.state.session_manager = get_mock_session_manager()
+
+ client_with_edit = TestClient(app_with_edit)
+
+ # Access should not be forbidden (may get other status codes based on MCP protocol)
+ response = client_with_edit.get("/mcp/server")
+ assert response.status_code != 403