Skip to content

Commit 325eb40

Browse files
authored
improvement: DRY GetActiveNotebooks tool and ActiveNotebooks prompt and improve outputs (#6976)
## 📝 Summary <!-- Provide a concise summary of what this pull request is addressing. If this PR fixes any issues, list them here by number (e.g., Fixes #123). --> This moves getting all active notebooks to ToolContext so it can be accessed by both GetActiveNotebooks tool and ActiveNotebooks prompt. ## 🔍 Description of Changes <!-- Detail the specific changes made in this pull request. Explain the problem addressed and how it was resolved. If applicable, provide before and after comparisons, screenshots, or any relevant details to help reviewers understand the changes easily. --> It also does the following: - Add better handling of no filepath (for new notebook that hasn't been saved yet) - Remove unused or redundant notebook info for both tool and prompt - Create a base MarimoNotebookInfo dataclass where sessionId is not Optional - Change next-steps in GetActiveNotebooks from GetCellRuntimeData to GetNotebookErrors since it makes more sense - Change from using filename (or relative filepath) to always using absolute filepath for agent edit notebook tools - Improve GetActiveNotebooks tests ## 📋 Checklist - [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md). - [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable). - [x] I have added tests for the changes made. - [x] I have run the code and verified that it works as expected.
1 parent e6433bc commit 325eb40

File tree

5 files changed

+211
-172
lines changed

5 files changed

+211
-172
lines changed

marimo/_ai/_tools/base.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818

1919
from marimo import _loggers
20-
from marimo._ai._tools.types import ToolGuidelines
20+
from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines
2121
from marimo._ai._tools.utils.exceptions import ToolExecutionError
2222
from marimo._config.config import CopilotMode
2323
from marimo._server.ai.tools.types import (
@@ -26,6 +26,7 @@
2626
ValidationFunction,
2727
)
2828
from marimo._server.api.deps import AppStateBase
29+
from marimo._server.model import ConnectionState
2930
from marimo._server.sessions import Session, SessionManager
3031
from marimo._types.ids import SessionId
3132
from marimo._utils.case import to_snake_case
@@ -80,6 +81,38 @@ def get_session(self, session_id: SessionId) -> Session:
8081
)
8182
return session_manager.sessions[session_id]
8283

84+
def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
85+
"""
86+
Get active sessions from the app state.
87+
88+
This follows the logic from marimo/_server/api/endpoints/home.py
89+
"""
90+
import os
91+
92+
UNSAVED_NOTEBOOK_MESSAGE = (
93+
"(unsaved notebook - save to disk to get file path)"
94+
)
95+
files: list[MarimoNotebookInfo] = []
96+
for session_id, session in self.session_manager.sessions.items():
97+
state = session.connection_state()
98+
if (
99+
state == ConnectionState.OPEN
100+
or state == ConnectionState.ORPHANED
101+
):
102+
full_file_path = session.app_file_manager.path
103+
filename = session.app_file_manager.filename
104+
basename = os.path.basename(filename) if filename else None
105+
files.append(
106+
MarimoNotebookInfo(
107+
name=(basename or "new notebook"),
108+
# file path should be absolute path for agent-based edit tools
109+
path=(full_file_path or UNSAVED_NOTEBOOK_MESSAGE),
110+
session_id=session_id,
111+
)
112+
)
113+
# Return most recent notebooks first (reverse chronological order)
114+
return files[::-1]
115+
83116

84117
class ToolBase(Generic[ArgsT, OutT], ABC):
85118
"""

marimo/_ai/_tools/tools/notebooks.py

Lines changed: 7 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,31 @@
22
from __future__ import annotations
33

44
from dataclasses import dataclass, field
5-
from typing import Optional
65

76
from marimo._ai._tools.base import ToolBase
87
from marimo._ai._tools.types import (
98
EmptyArgs,
9+
MarimoNotebookInfo,
1010
SuccessResult,
1111
ToolGuidelines,
1212
)
13-
from marimo._server.model import ConnectionState
14-
from marimo._server.models.home import MarimoFile
15-
from marimo._server.sessions import SessionManager
16-
from marimo._types.ids import SessionId
17-
from marimo._utils.paths import pretty_path
18-
19-
20-
@dataclass
21-
class NotebookInfo:
22-
name: str
23-
path: str
24-
session_id: Optional[SessionId] = None
25-
initialization_id: Optional[str] = None
2613

2714

2815
@dataclass
2916
class SummaryInfo:
3017
total_notebooks: int
31-
total_sessions: int
3218
active_connections: int
3319

3420

3521
@dataclass
3622
class GetActiveNotebooksData:
3723
summary: SummaryInfo
38-
notebooks: list[NotebookInfo]
24+
notebooks: list[MarimoNotebookInfo]
3925

4026

4127
def _default_active_notebooks_data() -> GetActiveNotebooksData:
4228
return GetActiveNotebooksData(
43-
summary=SummaryInfo(
44-
total_notebooks=0, total_sessions=0, active_connections=0
45-
),
29+
summary=SummaryInfo(total_notebooks=0, active_connections=0),
4630
notebooks=[],
4731
)
4832

@@ -73,69 +57,19 @@ def handle(self, args: EmptyArgs) -> GetActiveNotebooksOutput:
7357
del args
7458
context = self.context
7559
session_manager = context.session_manager
76-
active_files = self._get_active_sessions_internal(session_manager)
77-
78-
# Build notebooks list
79-
notebooks: list[NotebookInfo] = []
80-
for file_info in active_files:
81-
notebooks.append(
82-
NotebookInfo(
83-
name=file_info.name,
84-
path=file_info.path,
85-
session_id=file_info.session_id,
86-
initialization_id=file_info.initialization_id,
87-
)
88-
)
89-
90-
# Build summary statistics
60+
notebooks = context.get_active_sessions_internal()
61+
9162
summary: SummaryInfo = SummaryInfo(
92-
total_notebooks=len(active_files),
93-
total_sessions=len(session_manager.sessions),
63+
total_notebooks=len(notebooks),
9464
active_connections=session_manager.get_active_connection_count(),
9565
)
9666

97-
# Build data object
9867
data = GetActiveNotebooksData(summary=summary, notebooks=notebooks)
9968

100-
# Return a success result with summary statistics and notebook details
10169
return GetActiveNotebooksOutput(
10270
data=data,
10371
next_steps=[
10472
"Use the `get_lightweight_cell_map` tool to get the content of a notebook",
105-
"Use the `get_cell_runtime_data` tool to get the code, errors, and variables of a cell if you already have the cell id",
73+
"Use the `get_notebook_errors` tool to help debug errors in the notebook",
10674
],
10775
)
108-
109-
# helper methods
110-
111-
def _get_active_sessions_internal(
112-
self, session_manager: SessionManager
113-
) -> list[MarimoFile]:
114-
"""
115-
Get active sessions from the app state.
116-
117-
This replicates the logic from marimo/_server/api/endpoints/home.py
118-
"""
119-
import os
120-
121-
files: list[MarimoFile] = []
122-
for session_id, session in session_manager.sessions.items():
123-
state = session.connection_state()
124-
if (
125-
state == ConnectionState.OPEN
126-
or state == ConnectionState.ORPHANED
127-
):
128-
filename = session.app_file_manager.filename
129-
basename = os.path.basename(filename) if filename else None
130-
files.append(
131-
MarimoFile(
132-
name=(basename or "new notebook"),
133-
path=(
134-
pretty_path(filename) if filename else session_id
135-
),
136-
session_id=session_id,
137-
initialization_id=session.initialization_id,
138-
)
139-
)
140-
# Return most recent notebooks first (reverse chronological order)
141-
return files[::-1]

marimo/_ai/_tools/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from dataclasses import dataclass
55
from typing import Any, Literal, Optional
66

7+
from marimo._types.ids import SessionId
8+
79
# helper classes
810
StatusValue = Literal["success", "error", "warning"]
911

@@ -32,3 +34,10 @@ class ToolGuidelines:
3234
prerequisites: Optional[list[str]] = None
3335
side_effects: Optional[list[str]] = None
3436
additional_info: Optional[str] = None
37+
38+
39+
@dataclass
40+
class MarimoNotebookInfo:
41+
name: str
42+
path: str
43+
session_id: SessionId

marimo/_mcp/server/_prompts/prompts/notebooks.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@ def handle(self) -> list[PromptMessage]:
2222
"""
2323
from mcp.types import PromptMessage, TextContent
2424

25-
session_manager = self.context.session_manager
25+
context = self.context
26+
notebooks = context.get_active_sessions_internal()
2627

27-
# Get all active sessions
28-
sessions = session_manager.sessions
29-
30-
if not sessions:
28+
if len(notebooks) == 0:
3129
return [
3230
PromptMessage(
3331
role="user",
@@ -38,34 +36,36 @@ def handle(self) -> list[PromptMessage]:
3836
)
3937
]
4038

41-
# Create a message for each session
42-
messages: list[PromptMessage] = []
43-
for session_id, session in sessions.items():
44-
# Get file path if available
45-
maybe_file_path = session.app_file_manager.filename
46-
47-
# Create actionable message for this session
48-
if maybe_file_path:
49-
message = (
50-
f"Notebook session ID: {session_id}\n"
51-
f"Notebook file path: {maybe_file_path}\n\n"
52-
f"Use this session_id when calling MCP tools that require it. "
53-
f"You can also edit the notebook directly by modifying the file at the path above."
54-
)
55-
else:
56-
message = (
57-
f"Notebook session ID: {session_id}\n\n"
58-
f"Use this session_id when calling MCP tools that require it."
59-
)
39+
session_messages: list[PromptMessage] = []
40+
for active_file in notebooks:
41+
session_message = (
42+
f"Notebook session ID: {active_file.session_id}\n"
43+
f"Notebook file path: {active_file.path}\n\n"
44+
)
6045

61-
messages.append(
46+
session_messages.append(
6247
PromptMessage(
6348
role="user",
6449
content=TextContent(
6550
type="text",
66-
text=message,
51+
text=session_message,
6752
),
6853
)
6954
)
7055

71-
return messages
56+
action_message = (
57+
f"Use {'this session_id' if len(notebooks) == 1 else 'these session_ids'} when calling marimo MCP tools that require it."
58+
f"You can also edit {'this notebook' if len(notebooks) == 1 else 'these notebooks'} directly by modifying the files at the paths above."
59+
)
60+
61+
session_messages.append(
62+
PromptMessage(
63+
role="user",
64+
content=TextContent(
65+
type="text",
66+
text=action_message,
67+
),
68+
)
69+
)
70+
71+
return session_messages

0 commit comments

Comments
 (0)