diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index 42bb7c279d4..d1b2cd70f97 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -17,7 +17,7 @@ ) from marimo import _loggers -from marimo._ai._tools.types import ToolGuidelines +from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._config.config import CopilotMode from marimo._server.ai.tools.types import ( @@ -26,6 +26,7 @@ ValidationFunction, ) from marimo._server.api.deps import AppStateBase +from marimo._server.model import ConnectionState from marimo._server.sessions import Session, SessionManager from marimo._types.ids import SessionId from marimo._utils.case import to_snake_case @@ -80,6 +81,38 @@ def get_session(self, session_id: SessionId) -> Session: ) return session_manager.sessions[session_id] + def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: + """ + Get active sessions from the app state. + + This follows the logic from marimo/_server/api/endpoints/home.py + """ + import os + + UNSAVED_NOTEBOOK_MESSAGE = ( + "(unsaved notebook - save to disk to get file path)" + ) + files: list[MarimoNotebookInfo] = [] + for session_id, session in self.session_manager.sessions.items(): + state = session.connection_state() + if ( + state == ConnectionState.OPEN + or state == ConnectionState.ORPHANED + ): + full_file_path = session.app_file_manager.path + filename = session.app_file_manager.filename + basename = os.path.basename(filename) if filename else None + files.append( + MarimoNotebookInfo( + name=(basename or "new notebook"), + # file path should be absolute path for agent-based edit tools + path=(full_file_path or UNSAVED_NOTEBOOK_MESSAGE), + session_id=session_id, + ) + ) + # Return most recent notebooks first (reverse chronological order) + return files[::-1] + class ToolBase(Generic[ArgsT, OutT], ABC): """ diff --git a/marimo/_ai/_tools/tools/notebooks.py b/marimo/_ai/_tools/tools/notebooks.py index ca9fe41eee5..c4127e7b21c 100644 --- a/marimo/_ai/_tools/tools/notebooks.py +++ b/marimo/_ai/_tools/tools/notebooks.py @@ -2,47 +2,31 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional from marimo._ai._tools.base import ToolBase from marimo._ai._tools.types import ( EmptyArgs, + MarimoNotebookInfo, SuccessResult, ToolGuidelines, ) -from marimo._server.model import ConnectionState -from marimo._server.models.home import MarimoFile -from marimo._server.sessions import SessionManager -from marimo._types.ids import SessionId -from marimo._utils.paths import pretty_path - - -@dataclass -class NotebookInfo: - name: str - path: str - session_id: Optional[SessionId] = None - initialization_id: Optional[str] = None @dataclass class SummaryInfo: total_notebooks: int - total_sessions: int active_connections: int @dataclass class GetActiveNotebooksData: summary: SummaryInfo - notebooks: list[NotebookInfo] + notebooks: list[MarimoNotebookInfo] def _default_active_notebooks_data() -> GetActiveNotebooksData: return GetActiveNotebooksData( - summary=SummaryInfo( - total_notebooks=0, total_sessions=0, active_connections=0 - ), + summary=SummaryInfo(total_notebooks=0, active_connections=0), notebooks=[], ) @@ -73,69 +57,19 @@ def handle(self, args: EmptyArgs) -> GetActiveNotebooksOutput: del args context = self.context session_manager = context.session_manager - active_files = self._get_active_sessions_internal(session_manager) - - # Build notebooks list - notebooks: list[NotebookInfo] = [] - for file_info in active_files: - notebooks.append( - NotebookInfo( - name=file_info.name, - path=file_info.path, - session_id=file_info.session_id, - initialization_id=file_info.initialization_id, - ) - ) - - # Build summary statistics + notebooks = context.get_active_sessions_internal() + summary: SummaryInfo = SummaryInfo( - total_notebooks=len(active_files), - total_sessions=len(session_manager.sessions), + total_notebooks=len(notebooks), active_connections=session_manager.get_active_connection_count(), ) - # Build data object data = GetActiveNotebooksData(summary=summary, notebooks=notebooks) - # Return a success result with summary statistics and notebook details return GetActiveNotebooksOutput( data=data, next_steps=[ "Use the `get_lightweight_cell_map` tool to get the content of a notebook", - "Use the `get_cell_runtime_data` tool to get the code, errors, and variables of a cell if you already have the cell id", + "Use the `get_notebook_errors` tool to help debug errors in the notebook", ], ) - - # helper methods - - def _get_active_sessions_internal( - self, session_manager: SessionManager - ) -> list[MarimoFile]: - """ - Get active sessions from the app state. - - This replicates the logic from marimo/_server/api/endpoints/home.py - """ - import os - - files: list[MarimoFile] = [] - for session_id, session in session_manager.sessions.items(): - state = session.connection_state() - if ( - state == ConnectionState.OPEN - or state == ConnectionState.ORPHANED - ): - filename = session.app_file_manager.filename - basename = os.path.basename(filename) if filename else None - files.append( - MarimoFile( - name=(basename or "new notebook"), - path=( - pretty_path(filename) if filename else session_id - ), - session_id=session_id, - initialization_id=session.initialization_id, - ) - ) - # Return most recent notebooks first (reverse chronological order) - return files[::-1] diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 75b15cf430f..99e5c3b1bf1 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import Any, Literal, Optional +from marimo._types.ids import SessionId + # helper classes StatusValue = Literal["success", "error", "warning"] @@ -32,3 +34,10 @@ class ToolGuidelines: prerequisites: Optional[list[str]] = None side_effects: Optional[list[str]] = None additional_info: Optional[str] = None + + +@dataclass +class MarimoNotebookInfo: + name: str + path: str + session_id: SessionId diff --git a/marimo/_mcp/server/_prompts/prompts/notebooks.py b/marimo/_mcp/server/_prompts/prompts/notebooks.py index 67deb975129..c4bd7f5a55e 100644 --- a/marimo/_mcp/server/_prompts/prompts/notebooks.py +++ b/marimo/_mcp/server/_prompts/prompts/notebooks.py @@ -22,12 +22,10 @@ def handle(self) -> list[PromptMessage]: """ from mcp.types import PromptMessage, TextContent - session_manager = self.context.session_manager + context = self.context + notebooks = context.get_active_sessions_internal() - # Get all active sessions - sessions = session_manager.sessions - - if not sessions: + if len(notebooks) == 0: return [ PromptMessage( role="user", @@ -38,34 +36,36 @@ def handle(self) -> list[PromptMessage]: ) ] - # Create a message for each session - messages: list[PromptMessage] = [] - for session_id, session in sessions.items(): - # Get file path if available - maybe_file_path = session.app_file_manager.filename - - # Create actionable message for this session - if maybe_file_path: - message = ( - f"Notebook session ID: {session_id}\n" - f"Notebook file path: {maybe_file_path}\n\n" - f"Use this session_id when calling MCP tools that require it. " - f"You can also edit the notebook directly by modifying the file at the path above." - ) - else: - message = ( - f"Notebook session ID: {session_id}\n\n" - f"Use this session_id when calling MCP tools that require it." - ) + session_messages: list[PromptMessage] = [] + for active_file in notebooks: + session_message = ( + f"Notebook session ID: {active_file.session_id}\n" + f"Notebook file path: {active_file.path}\n\n" + ) - messages.append( + session_messages.append( PromptMessage( role="user", content=TextContent( type="text", - text=message, + text=session_message, ), ) ) - return messages + action_message = ( + f"Use {'this session_id' if len(notebooks) == 1 else 'these session_ids'} when calling marimo MCP tools that require it." + f"You can also edit {'this notebook' if len(notebooks) == 1 else 'these notebooks'} directly by modifying the files at the paths above." + ) + + session_messages.append( + PromptMessage( + role="user", + content=TextContent( + type="text", + text=action_message, + ), + ) + ) + + return session_messages diff --git a/tests/_ai/tools/tools/test_notebooks.py b/tests/_ai/tools/tools/test_notebooks.py index 7230f4d3088..a364fd4285b 100644 --- a/tests/_ai/tools/tools/test_notebooks.py +++ b/tests/_ai/tools/tools/test_notebooks.py @@ -1,119 +1,174 @@ from __future__ import annotations import os -from unittest.mock import Mock, patch +from dataclasses import dataclass +from unittest.mock import Mock + +import pytest from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools.notebooks import GetActiveNotebooks +from marimo._ai._tools.types import EmptyArgs, MarimoNotebookInfo from marimo._server.model import ConnectionState +from marimo._types.ids import SessionId + + +@dataclass +class MockAppFileManager: + filename: str | None + path: str | None +@dataclass class MockSession: - def __init__( - self, connection_state, filename=None, session_id="test_session" - ): - self._connection_state = connection_state - self.app_file_manager = Mock() - self.app_file_manager.filename = filename - self.initialization_id = f"init_{session_id}" - - def connection_state(self): + _connection_state: ConnectionState + app_file_manager: MockAppFileManager + + def connection_state(self) -> ConnectionState: return self._connection_state +@dataclass class MockSessionManager: - def __init__(self, sessions=None): - self.sessions = sessions or {} + sessions: dict[str, MockSession] + + def get_active_connection_count(self) -> int: + return len( + [ + s + for s in self.sessions.values() + if s.connection_state() + in (ConnectionState.OPEN, ConnectionState.ORPHANED) + ] + ) + + +@pytest.fixture +def tool() -> GetActiveNotebooks: + """Create a GetActiveNotebooks tool instance.""" + return GetActiveNotebooks(ToolContext()) + + +@pytest.fixture +def mock_context() -> Mock: + """Create a mock ToolContext.""" + context = Mock(spec=ToolContext) + context.get_active_sessions_internal = ( + ToolContext.get_active_sessions_internal + ) + return context -def test_get_active_sessions_internal_empty(): - tool = GetActiveNotebooks(ToolContext()) - result = tool._get_active_sessions_internal(MockSessionManager()) +def test_get_active_sessions_internal_empty(mock_context: Mock): + """Test get_active_sessions_internal with no sessions.""" + mock_context.session_manager = MockSessionManager(sessions={}) + + result = mock_context.get_active_sessions_internal(mock_context) + assert result == [] -def test_get_active_sessions_internal_open_session(): - tool = GetActiveNotebooks(ToolContext()) +def test_get_active_sessions_internal_open_session(mock_context: Mock): + """Test get_active_sessions_internal with one open session.""" session = MockSession( - connection_state=ConnectionState.OPEN, - filename="/path/to/notebook.py", - session_id="session1", + _connection_state=ConnectionState.OPEN, + app_file_manager=MockAppFileManager( + filename="/path/to/notebook.py", + path=os.path.abspath("/path/to/notebook.py"), + ), + ) + mock_context.session_manager = MockSessionManager( + sessions={"session1": session} ) - session_manager = MockSessionManager({"session1": session}) - with patch( - "marimo._ai._tools.tools.notebooks.pretty_path" - ) as mock_pretty_path: - mock_pretty_path.return_value = "notebook.py" - result = tool._get_active_sessions_internal(session_manager) + result = mock_context.get_active_sessions_internal(mock_context) assert len(result) == 1 assert result[0].name == "notebook.py" - assert result[0].path == "notebook.py" + assert result[0].path == os.path.abspath("/path/to/notebook.py") assert result[0].session_id == "session1" - assert result[0].initialization_id == "init_session1" -def test_get_active_sessions_internal_orphaned_session(): - tool = GetActiveNotebooks(ToolContext()) +def test_get_active_sessions_internal_orphaned_session(mock_context: Mock): + """Test get_active_sessions_internal with orphaned session.""" session = MockSession( - connection_state=ConnectionState.ORPHANED, - filename="/path/to/test.py", - session_id="session2", + _connection_state=ConnectionState.ORPHANED, + app_file_manager=MockAppFileManager( + filename="/path/to/test.py", + path=os.path.abspath("/path/to/test.py"), + ), + ) + mock_context.session_manager = MockSessionManager( + sessions={"session2": session} ) - session_manager = MockSessionManager({"session2": session}) - with patch( - "marimo._ai._tools.tools.notebooks.pretty_path" - ) as mock_pretty_path: - mock_pretty_path.return_value = "test.py" - result = tool._get_active_sessions_internal(session_manager) + result = mock_context.get_active_sessions_internal(mock_context) assert len(result) == 1 assert result[0].name == "test.py" -def test_get_active_sessions_internal_closed_session(): - tool = GetActiveNotebooks(ToolContext()) +def test_get_active_sessions_internal_closed_session(mock_context: Mock): + """Test get_active_sessions_internal filters out closed sessions.""" session = MockSession( - connection_state=ConnectionState.CLOSED, - filename="/path/to/closed.py", - session_id="session3", + _connection_state=ConnectionState.CLOSED, + app_file_manager=MockAppFileManager( + filename="/path/to/closed.py", + path=os.path.abspath("/path/to/closed.py"), + ), + ) + mock_context.session_manager = MockSessionManager( + sessions={"session3": session} ) - session_manager = MockSessionManager({"session3": session}) - result = tool._get_active_sessions_internal(session_manager) + result = mock_context.get_active_sessions_internal(mock_context) + assert result == [] -def test_get_active_sessions_internal_no_filename(): - tool = GetActiveNotebooks(ToolContext()) +def test_get_active_sessions_internal_no_filename(mock_context: Mock): + """Test get_active_sessions_internal with unsaved notebook.""" session = MockSession( - connection_state=ConnectionState.OPEN, filename=None, session_id="s4" + _connection_state=ConnectionState.OPEN, + app_file_manager=MockAppFileManager(filename=None, path=None), ) - session_manager = MockSessionManager({"s4": session}) + mock_context.session_manager = MockSessionManager(sessions={"s4": session}) + + result = mock_context.get_active_sessions_internal(mock_context) - result = tool._get_active_sessions_internal(session_manager) assert len(result) == 1 assert result[0].name == "new notebook" - assert result[0].path == "s4" + assert ( + result[0].path == "(unsaved notebook - save to disk to get file path)" + ) assert result[0].session_id == "s4" -def test_get_active_sessions_internal_multiple_sessions(): - tool = GetActiveNotebooks(ToolContext()) +def test_get_active_sessions_internal_multiple_sessions(mock_context: Mock): + """Test get_active_sessions_internal with multiple sessions of different states.""" sessions = { - "s1": MockSession(ConnectionState.OPEN, "/path/first.py", "s1"), - "s2": MockSession(ConnectionState.CLOSED, "/path/closed.py", "s2"), - "s3": MockSession(ConnectionState.ORPHANED, "/path/third.py", "s3"), + "s1": MockSession( + ConnectionState.OPEN, + MockAppFileManager( + "/path/first.py", os.path.abspath("/path/first.py") + ), + ), + "s2": MockSession( + ConnectionState.CLOSED, + MockAppFileManager( + "/path/closed.py", os.path.abspath("/path/closed.py") + ), + ), + "s3": MockSession( + ConnectionState.ORPHANED, + MockAppFileManager( + "/path/third.py", os.path.abspath("/path/third.py") + ), + ), } - session_manager = MockSessionManager(sessions) + mock_context.session_manager = MockSessionManager(sessions=sessions) - with patch( - "marimo._ai._tools.tools.notebooks.pretty_path" - ) as mock_pretty_path: - mock_pretty_path.side_effect = lambda x: os.path.basename(x) - result = tool._get_active_sessions_internal(session_manager) + result = mock_context.get_active_sessions_internal(mock_context) assert len(result) == 2 session_ids = [f.session_id for f in result] @@ -122,25 +177,33 @@ def test_get_active_sessions_internal_multiple_sessions(): assert "s2" not in session_ids -def test_get_active_notebooks_handle(): +def test_get_active_notebooks_handle(tool: GetActiveNotebooks): """Test GetActiveNotebooks.handle() end-to-end.""" - tool = GetActiveNotebooks(ToolContext()) session = MockSession( - ConnectionState.OPEN, "/test/notebook.py", "session1" + ConnectionState.OPEN, + MockAppFileManager( + "/test/notebook.py", os.path.abspath("/test/notebook.py") + ), ) - session_manager = MockSessionManager({"session1": session}) - session_manager.get_active_connection_count = Mock(return_value=1) + session_manager = MockSessionManager(sessions={"session1": session}) + # Mock the context context = Mock(spec=ToolContext) context.session_manager = session_manager + context.get_active_sessions_internal = Mock( + return_value=[ + MarimoNotebookInfo( + name="notebook.py", + path="/test/notebook.py", + session_id=SessionId("session1"), + ) + ] + ) tool.context = context - from marimo._ai._tools.types import EmptyArgs - result = tool.handle(EmptyArgs()) assert result.status == "success" assert result.data.summary.total_notebooks == 1 - assert result.data.summary.total_sessions == 1 assert result.data.summary.active_connections == 1 assert len(result.data.notebooks) == 1