Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion marimo/_ai/_tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)

from marimo import _loggers
from marimo._ai._tools.types import ToolGuidelines
from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._config.config import CopilotMode
from marimo._server.ai.tools.types import (
Expand All @@ -26,6 +26,7 @@
ValidationFunction,
)
from marimo._server.api.deps import AppStateBase
from marimo._server.model import ConnectionState
from marimo._server.sessions import Session, SessionManager
from marimo._types.ids import SessionId
from marimo._utils.case import to_snake_case
Expand Down Expand Up @@ -80,6 +81,38 @@ def get_session(self, session_id: SessionId) -> Session:
)
return session_manager.sessions[session_id]

def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
"""
Get active sessions from the app state.

This follows the logic from marimo/_server/api/endpoints/home.py
"""
import os

UNSAVED_NOTEBOOK_MESSAGE = (
"(unsaved notebook - save to disk to get file path)"
)
files: list[MarimoNotebookInfo] = []
for session_id, session in self.session_manager.sessions.items():
state = session.connection_state()
if (
state == ConnectionState.OPEN
or state == ConnectionState.ORPHANED
):
full_file_path = session.app_file_manager.path
filename = session.app_file_manager.filename
basename = os.path.basename(filename) if filename else None
files.append(
MarimoNotebookInfo(
name=(basename or "new notebook"),
# file path should be absolute path for agent-based edit tools
path=(full_file_path or UNSAVED_NOTEBOOK_MESSAGE),
session_id=session_id,
)
)
# Return most recent notebooks first (reverse chronological order)
return files[::-1]


class ToolBase(Generic[ArgsT, OutT], ABC):
"""
Expand Down
80 changes: 7 additions & 73 deletions marimo/_ai/_tools/tools/notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,31 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import (
EmptyArgs,
MarimoNotebookInfo,
SuccessResult,
ToolGuidelines,
)
from marimo._server.model import ConnectionState
from marimo._server.models.home import MarimoFile
from marimo._server.sessions import SessionManager
from marimo._types.ids import SessionId
from marimo._utils.paths import pretty_path


@dataclass
class NotebookInfo:
name: str
path: str
session_id: Optional[SessionId] = None
initialization_id: Optional[str] = None


@dataclass
class SummaryInfo:
total_notebooks: int
total_sessions: int
active_connections: int


@dataclass
class GetActiveNotebooksData:
summary: SummaryInfo
notebooks: list[NotebookInfo]
notebooks: list[MarimoNotebookInfo]


def _default_active_notebooks_data() -> GetActiveNotebooksData:
return GetActiveNotebooksData(
summary=SummaryInfo(
total_notebooks=0, total_sessions=0, active_connections=0
),
summary=SummaryInfo(total_notebooks=0, active_connections=0),
notebooks=[],
)

Expand Down Expand Up @@ -73,69 +57,19 @@ def handle(self, args: EmptyArgs) -> GetActiveNotebooksOutput:
del args
context = self.context
session_manager = context.session_manager
active_files = self._get_active_sessions_internal(session_manager)

# Build notebooks list
notebooks: list[NotebookInfo] = []
for file_info in active_files:
notebooks.append(
NotebookInfo(
name=file_info.name,
path=file_info.path,
session_id=file_info.session_id,
initialization_id=file_info.initialization_id,
)
)

# Build summary statistics
notebooks = context.get_active_sessions_internal()

summary: SummaryInfo = SummaryInfo(
total_notebooks=len(active_files),
total_sessions=len(session_manager.sessions),
total_notebooks=len(notebooks),
active_connections=session_manager.get_active_connection_count(),
)

# Build data object
data = GetActiveNotebooksData(summary=summary, notebooks=notebooks)

# Return a success result with summary statistics and notebook details
return GetActiveNotebooksOutput(
data=data,
next_steps=[
"Use the `get_lightweight_cell_map` tool to get the content of a notebook",
"Use the `get_cell_runtime_data` tool to get the code, errors, and variables of a cell if you already have the cell id",
"Use the `get_notebook_errors` tool to help debug errors in the notebook",
],
)

# helper methods

def _get_active_sessions_internal(
self, session_manager: SessionManager
) -> list[MarimoFile]:
"""
Get active sessions from the app state.

This replicates the logic from marimo/_server/api/endpoints/home.py
"""
import os

files: list[MarimoFile] = []
for session_id, session in session_manager.sessions.items():
state = session.connection_state()
if (
state == ConnectionState.OPEN
or state == ConnectionState.ORPHANED
):
filename = session.app_file_manager.filename
basename = os.path.basename(filename) if filename else None
files.append(
MarimoFile(
name=(basename or "new notebook"),
path=(
pretty_path(filename) if filename else session_id
),
session_id=session_id,
initialization_id=session.initialization_id,
)
)
# Return most recent notebooks first (reverse chronological order)
return files[::-1]
9 changes: 9 additions & 0 deletions marimo/_ai/_tools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from dataclasses import dataclass
from typing import Any, Literal, Optional

from marimo._types.ids import SessionId

# helper classes
StatusValue = Literal["success", "error", "warning"]

Expand Down Expand Up @@ -32,3 +34,10 @@ class ToolGuidelines:
prerequisites: Optional[list[str]] = None
side_effects: Optional[list[str]] = None
additional_info: Optional[str] = None


@dataclass
class MarimoNotebookInfo:
name: str
path: str
session_id: SessionId
54 changes: 27 additions & 27 deletions marimo/_mcp/server/_prompts/prompts/notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ def handle(self) -> list[PromptMessage]:
"""
from mcp.types import PromptMessage, TextContent

session_manager = self.context.session_manager
context = self.context
notebooks = context.get_active_sessions_internal()

# Get all active sessions
sessions = session_manager.sessions

if not sessions:
if len(notebooks) == 0:
return [
PromptMessage(
role="user",
Expand All @@ -38,34 +36,36 @@ def handle(self) -> list[PromptMessage]:
)
]

# Create a message for each session
messages: list[PromptMessage] = []
for session_id, session in sessions.items():
# Get file path if available
maybe_file_path = session.app_file_manager.filename

# Create actionable message for this session
if maybe_file_path:
message = (
f"Notebook session ID: {session_id}\n"
f"Notebook file path: {maybe_file_path}\n\n"
f"Use this session_id when calling MCP tools that require it. "
f"You can also edit the notebook directly by modifying the file at the path above."
)
else:
message = (
f"Notebook session ID: {session_id}\n\n"
f"Use this session_id when calling MCP tools that require it."
)
session_messages: list[PromptMessage] = []
for active_file in notebooks:
session_message = (
f"Notebook session ID: {active_file.session_id}\n"
f"Notebook file path: {active_file.path}\n\n"
)

messages.append(
session_messages.append(
PromptMessage(
role="user",
content=TextContent(
type="text",
text=message,
text=session_message,
),
)
)

return messages
action_message = (
f"Use {'this session_id' if len(notebooks) == 1 else 'these session_ids'} when calling marimo MCP tools that require it."
f"You can also edit {'this notebook' if len(notebooks) == 1 else 'these notebooks'} directly by modifying the files at the paths above."
)

session_messages.append(
PromptMessage(
role="user",
content=TextContent(
type="text",
text=action_message,
),
)
)

return session_messages
Loading
Loading