From a3de1807bb1b661f8b48d2ef4c5259e037951d6d Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Wed, 29 Oct 2025 19:45:46 +0100 Subject: [PATCH 1/8] dry error handing in tools, create error prompt for easy injection --- marimo/_ai/_tools/base.py | 150 +++++++++++- marimo/_ai/_tools/tools/cells.py | 91 +------ marimo/_ai/_tools/tools/errors.py | 119 ++------- marimo/_ai/_tools/types.py | 17 +- marimo/_mcp/server/_prompts/prompts/errors.py | 84 +++++++ marimo/_mcp/server/_prompts/registry.py | 2 + tests/_ai/tools/tools/test_cells.py | 49 ---- tests/_ai/tools/tools/test_errors_tool.py | 225 ++++++++++++++---- 8 files changed, 458 insertions(+), 279 deletions(-) create mode 100644 marimo/_mcp/server/_prompts/prompts/errors.py diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index d1b2cd70f97..3e634f24d0d 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -17,9 +17,17 @@ ) from marimo import _loggers -from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines +from marimo._ai._tools.types import ( + MarimoCellErrors, + MarimoErrorDetail, + MarimoNotebookInfo, + ToolGuidelines, +) from marimo._ai._tools.utils.exceptions import ToolExecutionError +from marimo._ai._tools.utils.output_cleaning import clean_output from marimo._config.config import CopilotMode +from marimo._messaging.cell_output import CellChannel +from marimo._messaging.ops import CellOp from marimo._server.ai.tools.types import ( FunctionArgs, ToolDefinition, @@ -28,7 +36,7 @@ from marimo._server.api.deps import AppStateBase from marimo._server.model import ConnectionState from marimo._server.sessions import Session, SessionManager -from marimo._types.ids import SessionId +from marimo._types.ids import CellId_t, SessionId from marimo._utils.case import to_snake_case from marimo._utils.dataclass_to_openapi import PythonTypeToOpenAPI from marimo._utils.parse_dataclass import parse_raw @@ -81,6 +89,18 @@ def get_session(self, session_id: SessionId) -> Session: ) return session_manager.sessions[session_id] + def get_cell_ops(self, session_id: SessionId, cell_id: CellId_t) -> CellOp: + session_view = self.get_session(session_id).session_view + if cell_id not in session_view.cell_operations: + raise ToolExecutionError( + f"Cell operation not found for cell {cell_id}", + code="CELL_OPERATION_NOT_FOUND", + is_retryable=False, + suggested_fix="Try again with a valid cell ID.", + meta={"cell_id": cell_id}, + ) + return session_view.cell_operations[cell_id] + def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: """ Get active sessions from the app state. @@ -113,6 +133,132 @@ def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: # Return most recent notebooks first (reverse chronological order) return files[::-1] + def get_notebook_errors( + self, session_id: SessionId, include_stderr: bool = False + ) -> list[MarimoCellErrors]: + """ + Get all errors in the current notebook session, organized by cell. + + Args: + session_id: The session ID of the notebook. + include_stderr: Whether to include stderr errors. + + Returns: + A list of MarimoCellErrors in the order of the cells in the notebook. + """ + session = self.get_session(session_id) + session_view = session.session_view + cell_errors_map: dict[CellId_t, list[MarimoErrorDetail]] = {} + notebook_errors: list[MarimoCellErrors] = [] + + for cell_id, cell_op in session_view.cell_operations.items(): + errors = self.get_cell_errors( + session_id, + cell_id, + maybe_cell_op=cell_op, + include_stderr=include_stderr, + ) + if len(errors) > 0: + cell_errors_map[cell_id] = errors + + # Use cell_manager to get cells in the correct notebook order + cell_manager = session.app_file_manager.app.cell_manager + for cell_data in cell_manager.cell_data(): + cell_id = cell_data.cell_id + if cell_id in cell_errors_map: + notebook_errors.append( + MarimoCellErrors( + cell_id=cell_id, + errors=cell_errors_map[cell_id], + ) + ) + + return notebook_errors + + def get_cell_errors( + self, + session_id: SessionId, + cell_id: CellId_t, + maybe_cell_op: Optional[CellOp] = None, + include_stderr: bool = False, + ) -> list[MarimoErrorDetail]: + """ + Get all errors for a given cell. + + Args: + session_id: The session ID of the notebook. + cell_id: The ID of the cell. + maybe_cell_op: The cell operation. + include_stderr: Whether to include stderr errors. + + Returns: + A list of MarimoErrorDetails for the cell with STDERR errors if include_stderr is True. + """ + errors: list[MarimoErrorDetail] = [] + cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id) + + if ( + cell_op.output + and cell_op.output.channel == CellChannel.MARIMO_ERROR + ): + items = cell_op.output.data + + if not isinstance(items, list): + # no errors + return errors + + for err in items: + # TODO: filter out noisy useless errors + # like "An ancestor raised an exception..." + if isinstance(err, dict): + errors.append( + MarimoErrorDetail( + type=err.get("type", "UnknownError"), + message=err.get("msg", str(err)), + traceback=err.get("traceback", []), + ) + ) + else: + # Fallback for rich error objects + err_type: str = getattr(err, "type", type(err).__name__) + describe_fn: Optional[Any] = getattr(err, "describe", None) + message_val = ( + describe_fn() if callable(describe_fn) else str(err) + ) + message: str = str(message_val) + tb: list[str] = getattr(err, "traceback", []) or [] + errors.append( + MarimoErrorDetail( + type=err_type, + message=message, + traceback=tb, + ) + ) + + if cell_op.console and include_stderr: + console_outputs = ( + cell_op.console + if isinstance(cell_op.console, list) + else [cell_op.console] + ) + stderr_messages: list[str] = [] + for console in console_outputs: + if console.channel == CellChannel.STDERR: + stderr_messages.append(str(console.data)) + cleaned_stderr_messages = clean_output(stderr_messages) + errors.extend( + [ + MarimoErrorDetail( + type="STDERR", + message=message, + traceback=[], + ) + for message in cleaned_stderr_messages + ] + ) + + return errors + class ToolBase(Generic[ArgsT, OutT], ABC): """ diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index f812cd20e60..90941ef81e4 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -6,7 +6,11 @@ from typing import TYPE_CHECKING, Any, Optional from marimo._ai._tools.base import ToolBase -from marimo._ai._tools.types import SuccessResult, ToolGuidelines +from marimo._ai._tools.types import ( + MarimoErrorDetail, + SuccessResult, + ToolGuidelines, +) from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._ai._tools.utils.output_cleaning import clean_output from marimo._ast.models import CellData @@ -56,12 +60,6 @@ class ErrorDetail: traceback: list[str] -@dataclass -class CellErrors: - has_errors: bool - error_details: Optional[list[ErrorDetail]] - - @dataclass class CellRuntimeMetadata: # String form of the runtime state (see marimo._ast.cell.RuntimeStateType); @@ -79,7 +77,7 @@ class GetCellRuntimeDataData: session_id: str cell_id: str code: Optional[str] = None - errors: Optional[CellErrors] = None + errors: Optional[list[MarimoErrorDetail]] = None metadata: Optional[CellRuntimeMetadata] = None variables: Optional[CellVariables] = None @@ -267,7 +265,9 @@ def handle(self, args: GetCellRuntimeDataArgs) -> GetCellRuntimeDataOutput: cell_code = cell_data.code # Get cell errors from session view with actual error details - cell_errors = self._get_cell_errors(session, cell_id) + cell_errors = context.get_cell_errors( + session_id, cell_id, include_stderr=True + ) # Get cell runtime metadata cell_metadata = self._get_cell_metadata(session, cell_id) @@ -307,79 +307,6 @@ def _get_cell_data( ) return cell_data - def _get_cell_errors( - self, session: Session, cell_id: CellId_t - ) -> CellErrors: - """Get cell errors from session view with actual error details.""" - from marimo._messaging.cell_output import CellChannel - - # Get cell operation from session view - session_view = session.session_view - cell_op = session_view.cell_operations.get(cell_id) - - if cell_op is None: - # No operations recorded for this cell - return CellErrors(has_errors=False, error_details=None) - - # Check for actual error details in the output - has_errors = False - error_details = [] - if ( - cell_op.output - and cell_op.output.channel == CellChannel.MARIMO_ERROR - ): - has_errors = True - # Extract actual error objects - errors = cell_op.output.data - if isinstance(errors, list): - for error in errors: - if hasattr(error, "type") and hasattr(error, "describe"): - # Rich Error object - error_detail = ErrorDetail( - type=error.type, - message=error.describe(), - traceback=getattr(error, "traceback", []), - ) - error_details.append(error_detail) - elif isinstance(error, dict): - # Dict-based error - dict_error_detail = ErrorDetail( - type=error.get("type", "UnknownError"), - message=error.get("msg", str(error)), - traceback=error.get("traceback", []), - ) - error_details.append(dict_error_detail) - else: - # Fallback for other error types - fallback_error_detail = ErrorDetail( - type=type(error).__name__, - message=str(error), - traceback=[], - ) - error_details.append(fallback_error_detail) - - # Check console outputs for STDERR (includes print statements to stderr, warnings, etc.) - if cell_op.console: - console_outputs = ( - cell_op.console - if isinstance(cell_op.console, list) - else [cell_op.console] - ) - for console_output in console_outputs: - if console_output.channel == CellChannel.STDERR: - has_errors = True - stderr_error_detail = ErrorDetail( - type="STDERR", - message=str(console_output.data), - traceback=[], - ) - error_details.append(stderr_error_detail) - - return CellErrors( - has_errors=has_errors, - error_details=error_details if error_details else None, - ) - def _get_cell_metadata( self, session: Session, cell_id: CellId_t ) -> CellRuntimeMetadata: diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py index 26f31d112f4..a4e8e34b377 100644 --- a/marimo/_ai/_tools/tools/errors.py +++ b/marimo/_ai/_tools/tools/errors.py @@ -2,16 +2,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Optional -from marimo import _loggers from marimo._ai._tools.base import ToolBase -from marimo._ai._tools.tools.cells import ErrorDetail -from marimo._ai._tools.types import SuccessResult, ToolGuidelines -from marimo._server.sessions import Session -from marimo._types.ids import CellId_t, SessionId - -LOGGER = _loggers.marimo_logger() +from marimo._ai._tools.types import ( + MarimoCellErrors, + SuccessResult, + ToolGuidelines, +) +from marimo._types.ids import SessionId @dataclass @@ -19,17 +17,12 @@ class GetNotebookErrorsArgs: session_id: SessionId -@dataclass -class CellErrorsSummary: - cell_id: CellId_t - errors: list[ErrorDetail] = field(default_factory=list) - - @dataclass class GetNotebookErrorsOutput(SuccessResult): has_errors: bool = False total_errors: int = 0 - cells: list[CellErrorsSummary] = field(default_factory=list) + total_cells_with_errors: int = 0 + cells: list[MarimoCellErrors] = field(default_factory=list) class GetNotebookErrors( @@ -42,7 +35,7 @@ class GetNotebookErrors( session_id: The session ID of the notebook. Returns: - A success result containing per-cell error details and totals. + A success result containing notebook errors organized by cell. """ guidelines = ToolGuidelines( @@ -56,15 +49,20 @@ class GetNotebookErrors( ) def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: - session = self.context.get_session(args.session_id) - summaries = self._collect_errors(session) + context = self.context + session_id = args.session_id + summaries = context.get_notebook_errors( + session_id, include_stderr=True + ) - total_errors = sum(len(s.errors) for s in summaries) + total_errors = self._get_total_errors_count_without_stderr(summaries) + total_cells_with_errors = len(summaries) has_errors = total_errors > 0 return GetNotebookErrorsOutput( has_errors=has_errors, total_errors=total_errors, + total_cells_with_errors=total_cells_with_errors, cells=summaries, next_steps=( [ @@ -76,80 +74,9 @@ def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: ), ) - # helpers - def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: - from marimo._messaging.cell_output import CellChannel - - session_view = session.session_view - - summaries: list[CellErrorsSummary] = [] - for cell_id, cell_op in session_view.cell_operations.items(): - errors: list[ErrorDetail] = [] - - # Collect structured marimo errors from output - if ( - cell_op.output - and cell_op.output.channel == CellChannel.MARIMO_ERROR - ): - items = cell_op.output.data - if isinstance(items, list): - for err in items: - if isinstance(err, dict): - errors.append( - ErrorDetail( - type=err.get("type", "UnknownError"), - message=err.get("msg", str(err)), - traceback=err.get("traceback", []), - ) - ) - else: - # Fallback for rich error objects - err_type: str = getattr( - err, "type", type(err).__name__ - ) - describe_fn: Optional[Any] = getattr( - err, "describe", None - ) - message_val = ( - describe_fn() - if callable(describe_fn) - else str(err) - ) - message: str = str(message_val) - tb: list[str] = getattr(err, "traceback", []) or [] - errors.append( - ErrorDetail( - type=err_type, - message=message, - traceback=tb, - ) - ) - - # Collect stderr messages as error details - if cell_op.console: - console_outputs = ( - cell_op.console - if isinstance(cell_op.console, list) - else [cell_op.console] - ) - for console in console_outputs: - if console.channel == CellChannel.STDERR: - errors.append( - ErrorDetail( - type="STDERR", - message=str(console.data), - traceback=[], - ) - ) - - if errors: - summaries.append( - CellErrorsSummary( - cell_id=cell_id, - errors=errors, - ) - ) - - # Sort by cell_id for stable output - summaries.sort(key=lambda s: s.cell_id) - return summaries + def _get_total_errors_count_without_stderr( + self, summaries: list[MarimoCellErrors] + ) -> int: + return sum( + len([e for e in s.errors if e.type != "STDERR"]) for s in summaries + ) diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 99e5c3b1bf1..7202adf9d95 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -1,10 +1,10 @@ # Copyright 2025 Marimo. All rights reserved. from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Literal, Optional -from marimo._types.ids import SessionId +from marimo._types.ids import CellId_t, SessionId # helper classes StatusValue = Literal["success", "error", "warning"] @@ -41,3 +41,16 @@ class MarimoNotebookInfo: name: str path: str session_id: SessionId + + +@dataclass +class MarimoCellErrors: + cell_id: CellId_t + errors: list[MarimoErrorDetail] = field(default_factory=list) + + +@dataclass +class MarimoErrorDetail: + type: str + message: str + traceback: list[str] diff --git a/marimo/_mcp/server/_prompts/prompts/errors.py b/marimo/_mcp/server/_prompts/prompts/errors.py new file mode 100644 index 00000000000..cd48de7a467 --- /dev/null +++ b/marimo/_mcp/server/_prompts/prompts/errors.py @@ -0,0 +1,84 @@ +# Copyright 2024 Marimo. All rights reserved. +"""MCP Prompts for notebook error information.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from marimo._mcp.server._prompts.base import PromptBase + +if TYPE_CHECKING: + from mcp.types import PromptMessage + + +class ErrorsSummary(PromptBase): + """Get error summaries for all active notebooks.""" + + def handle(self) -> list[PromptMessage]: + """Generate prompt messages summarizing errors in active notebooks. + + Returns: + List of PromptMessage objects, one per notebook with errors. + """ + from mcp.types import PromptMessage, TextContent + + context = self.context + notebooks = context.get_active_sessions_internal() + + if len(notebooks) == 0: + return [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text="No active marimo notebook sessions found.", + ), + ) + ] + + error_messages: list[PromptMessage] = [] + + for notebook in notebooks: + session_id = notebook.session_id + notebook_errors = context.get_notebook_errors(session_id) + + if len(notebook_errors) == 0: + continue + + error_lines = [ + f"Notebook: {notebook.name} (session: {notebook.session_id})", + f"Path: {notebook.path}", + f"Cells with errors: {len(notebook_errors)}", + f"Total errors: {sum(len(cell_errors.errors) for cell_errors in notebook_errors)}", + ] + + for cell_errors in notebook_errors: + error_lines.append(f"**Cell {cell_errors.cell_id}**:") + all_cell_errors = [ + f" • {error.type} - {error.message}" + for error in cell_errors.errors + ] + error_lines.extend(all_cell_errors) + + error_messages.append( + PromptMessage( + role="user", + content=TextContent( + type="text", + text="\n".join(error_lines), + ), + ) + ) + + if len(error_messages) == 0: + return [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text="No errors found in any active notebooks.", + ), + ) + ] + + return error_messages diff --git a/marimo/_mcp/server/_prompts/registry.py b/marimo/_mcp/server/_prompts/registry.py index d60e8814701..403234fa7fb 100644 --- a/marimo/_mcp/server/_prompts/registry.py +++ b/marimo/_mcp/server/_prompts/registry.py @@ -2,8 +2,10 @@ """Registry of all supported MCP prompts.""" from marimo._mcp.server._prompts.base import PromptBase +from marimo._mcp.server._prompts.prompts.errors import ErrorsSummary from marimo._mcp.server._prompts.prompts.notebooks import ActiveNotebooks SUPPORTED_MCP_PROMPTS: list[type[PromptBase]] = [ ActiveNotebooks, + ErrorsSummary, ] diff --git a/tests/_ai/tools/tools/test_cells.py b/tests/_ai/tools/tools/test_cells.py index 638c7f4fec3..da3a4cef895 100644 --- a/tests/_ai/tools/tools/test_cells.py +++ b/tests/_ai/tools/tools/test_cells.py @@ -7,7 +7,6 @@ from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools.cells import ( - CellErrors, CellRuntimeMetadata, CellVariables, GetCellOutputs, @@ -76,54 +75,6 @@ def test_is_markdown_cell(): assert tool._is_markdown_cell("print('x')") is False -def test_get_cell_errors_no_cell_op(): - tool = GetCellRuntimeData(ToolContext()) - session = MockSession(MockSessionView()) - - result = tool._get_cell_errors(session, CellId_t("missing")) - assert result == CellErrors(has_errors=False, error_details=None) - - -def test_get_cell_errors_with_marimo_error(): - tool = GetCellRuntimeData(ToolContext()) - error = MockError( - "NameError", "name 'x' is not defined", ["line1", "line2"] - ) - output = MockOutput(CellChannel.MARIMO_ERROR, [error]) - cell_op = MockCellOp(output=output) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "NameError" - - -def test_get_cell_errors_with_stderr(): - tool = GetCellRuntimeData(ToolContext()) - console_output = MockConsoleOutput(CellChannel.STDERR, "warn") - cell_op = MockCellOp(console=[console_output]) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "STDERR" - - -def test_get_cell_errors_dict_error(): - tool = GetCellRuntimeData(ToolContext()) - dict_error = {"type": "ValueError", "msg": "invalid", "traceback": ["tb1"]} - output = MockOutput(CellChannel.MARIMO_ERROR, [dict_error]) - cell_op = MockCellOp(output=output) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "ValueError" - - def test_get_cell_metadata_basic(): tool = GetCellRuntimeData(ToolContext()) cell_op = MockCellOp(status="idle") diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 2a503b2430f..7bda02577d5 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -3,13 +3,15 @@ from dataclasses import dataclass from unittest.mock import Mock +import pytest + from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools.errors import ( GetNotebookErrors, GetNotebookErrorsArgs, ) -from marimo._messaging.cell_output import CellChannel -from marimo._types.ids import SessionId +from marimo._ai._tools.types import MarimoCellErrors, MarimoErrorDetail +from marimo._types.ids import CellId_t, SessionId @dataclass @@ -30,9 +32,15 @@ class MockConsoleOutput: data: object +@dataclass +class MockUpdateCellIdsRequest: + cell_ids: list[str] + + @dataclass class MockSessionView: cell_operations: dict | None = None + cell_ids: MockUpdateCellIdsRequest | None = None def __post_init__(self) -> None: if self.cell_operations is None: @@ -50,63 +58,184 @@ class MockSession: app_file_manager: DummyAppFileManager -def test_collect_errors_none() -> None: - tool = GetNotebookErrors(ToolContext()) +@pytest.fixture +def tool() -> GetNotebookErrors: + """Create a GetNotebookErrors tool instance.""" + return GetNotebookErrors(ToolContext()) - # Empty session view - session = MockSession( - session_view=MockSessionView(), - app_file_manager=DummyAppFileManager(app=Mock()), - ) - summaries = tool._collect_errors(session) # type: ignore[arg-type] - assert summaries == [] +@pytest.fixture +def mock_context() -> Mock: + """Create a mock ToolContext.""" + context = Mock(spec=ToolContext) + return context + +def test_get_notebook_errors_empty_session(mock_context: Mock) -> None: + """Test get_notebook_errors with no errors.""" + mock_context.get_notebook_errors.return_value = [] -def test_collect_errors_marimo_and_stderr() -> None: tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is False + assert result.total_errors == 0 + assert result.total_cells_with_errors == 0 + assert result.cells == [] + assert result.next_steps is not None + assert "No errors detected" in result.next_steps + + +def test_get_notebook_errors_marimo_error_only(mock_context: Mock) -> None: + """Test get_notebook_errors with MARIMO_ERROR only.""" + marimo_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="bad value", + traceback=["line 1"], + ) + ], + ) + ] + mock_context.get_notebook_errors.return_value = marimo_errors - # Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only - err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]} - c1 = MockCellOp( - output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]), - console=[MockConsoleOutput(CellChannel.STDERR, "warn")], - ) - c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")]) + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is True + assert result.total_errors == 1 + assert result.total_cells_with_errors == 1 + assert len(result.cells) == 1 + assert result.cells[0].cell_id == CellId_t("c1") + assert result.cells[0].errors[0].type == "ValueError" + assert result.next_steps is not None + assert "get_cell_runtime_data" in result.next_steps[0] + + +def test_get_notebook_errors_stderr_only(mock_context: Mock) -> None: + """Test get_notebook_errors with STDERR only.""" + stderr_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c2"), + errors=[ + MarimoErrorDetail( + type="STDERR", + message="warning message", + traceback=[], + ) + ], + ) + ] + mock_context.get_notebook_errors.return_value = stderr_errors - session = MockSession( - session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}), - app_file_manager=DummyAppFileManager(app=Mock()), - ) + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is False + assert result.total_errors == 0 + assert result.total_cells_with_errors == 1 + assert result.cells[0].errors[0].type == "STDERR" + + +def test_get_notebook_errors_mixed_errors(mock_context: Mock) -> None: + """Test get_notebook_errors with both MARIMO_ERROR and STDERR.""" + mixed_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="bad value", + traceback=["line 1"], + ), + MarimoErrorDetail( + type="STDERR", + message="warn", + traceback=[], + ), + ], + ) + ] + mock_context.get_notebook_errors.return_value = mixed_errors + + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is True + assert result.total_errors == 1 + assert result.total_cells_with_errors == 1 + assert len(result.cells[0].errors) == 2 + + +def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: + """Test get_notebook_errors with errors in multiple cells.""" + multiple_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="error in c1", + traceback=[], + ) + ], + ), + MarimoCellErrors( + cell_id=CellId_t("c2"), + errors=[ + MarimoErrorDetail( + type="TypeError", + message="error in c2", + traceback=[], + ) + ], + ), + MarimoCellErrors( + cell_id=CellId_t("c3"), + errors=[ + MarimoErrorDetail( + type="STDERR", + message="stderr in c3", + traceback=[], + ) + ], + ), + ] + mock_context.get_notebook_errors.return_value = multiple_errors + + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - summaries = tool._collect_errors(session) # type: ignore[arg-type] + assert result.has_errors is True + assert result.total_errors == 2 + assert result.total_cells_with_errors == 3 + assert len(result.cells) == 3 - # Sorted by cell_id: c1 then c2 - assert len(summaries) == 2 - assert summaries[0].cell_id == "c1" - assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR - assert summaries[0].errors[0].type == "ValueError" - assert summaries[1].cell_id == "c2" - assert len(summaries[1].errors) == 1 - assert summaries[1].errors[0].type == "STDERR" +def test_get_notebook_errors_respects_session_id(mock_context: Mock) -> None: + """Test that get_notebook_errors passes the correct session_id.""" + session_id = SessionId("test-session-123") + mock_context.get_notebook_errors.return_value = [] -def test_handle_integration_uses_context_get_session() -> None: tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context - c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")]) - session = MockSession( - session_view=MockSessionView(cell_operations={"c1": c1}), - app_file_manager=DummyAppFileManager(app=Mock()), - ) + tool.handle(GetNotebookErrorsArgs(session_id=session_id)) - # Mock ToolContext.get_session - context = Mock(spec=ToolContext) - context.get_session.return_value = session - tool.context = context # type: ignore[assignment] - - out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - assert out.has_errors is True - assert out.total_errors == 1 - assert len(out.cells) == 1 - assert out.cells[0].cell_id == "c1" + # Verify the session_id was passed correctly with include_stderr=True + mock_context.get_notebook_errors.assert_called_once_with( + session_id, include_stderr=True + ) From bea3afa2efaa690696ce3ce9f6a4d1421eb416a7 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 00:50:57 +0100 Subject: [PATCH 2/8] remove unused dataclass --- marimo/_ai/_tools/tools/cells.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index 90941ef81e4..9b35c83e9d6 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -53,13 +53,6 @@ class GetLightweightCellMapOutput(SuccessResult): preview_lines: int = 3 -@dataclass -class ErrorDetail: - type: str - message: str - traceback: list[str] - - @dataclass class CellRuntimeMetadata: # String form of the runtime state (see marimo._ast.cell.RuntimeStateType); From 83ad76cdf769515f54a6c1b90279b1632c2b5e70 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 03:02:04 +0100 Subject: [PATCH 3/8] decouple stderr from get_cell_errors, add get_cell_console_outputs to ToolContext, update GetCellOutputs to use it --- marimo/_ai/_tools/base.py | 159 +++++++++--------- marimo/_ai/_tools/tools/cells.py | 78 +++------ marimo/_ai/_tools/tools/errors.py | 15 +- marimo/_ai/_tools/types.py | 6 + marimo/_mcp/server/_prompts/prompts/errors.py | 2 +- 5 files changed, 115 insertions(+), 145 deletions(-) diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index 3e634f24d0d..7b64d8aec5a 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -21,6 +21,7 @@ MarimoCellErrors, MarimoErrorDetail, MarimoNotebookInfo, + MarimoCellConsoleOutputs, ToolGuidelines, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError @@ -134,32 +135,31 @@ def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: return files[::-1] def get_notebook_errors( - self, session_id: SessionId, include_stderr: bool = False + self, session_id: SessionId, include_stderr: bool ) -> list[MarimoCellErrors]: """ Get all errors in the current notebook session, organized by cell. - - Args: - session_id: The session ID of the notebook. - include_stderr: Whether to include stderr errors. - - Returns: - A list of MarimoCellErrors in the order of the cells in the notebook. """ session = self.get_session(session_id) session_view = session.session_view - cell_errors_map: dict[CellId_t, list[MarimoErrorDetail]] = {} + cell_errors_map: dict[CellId_t, MarimoCellErrors] = {} notebook_errors: list[MarimoCellErrors] = [] + stderr: list[str] = [] for cell_id, cell_op in session_view.cell_operations.items(): errors = self.get_cell_errors( session_id, cell_id, maybe_cell_op=cell_op, - include_stderr=include_stderr, ) - if len(errors) > 0: - cell_errors_map[cell_id] = errors + if include_stderr: + stderr = self.get_cell_console_outputs(cell_op).stderr + if errors: + cell_errors_map[cell_id] = MarimoCellErrors( + cell_id=cell_id, + errors=errors, + stderr=stderr, + ) # Use cell_manager to get cells in the correct notebook order cell_manager = session.app_file_manager.app.cell_manager @@ -167,10 +167,7 @@ def get_notebook_errors( cell_id = cell_data.cell_id if cell_id in cell_errors_map: notebook_errors.append( - MarimoCellErrors( - cell_id=cell_id, - errors=cell_errors_map[cell_id], - ) + cell_errors_map[cell_id] ) return notebook_errors @@ -180,85 +177,85 @@ def get_cell_errors( session_id: SessionId, cell_id: CellId_t, maybe_cell_op: Optional[CellOp] = None, - include_stderr: bool = False, ) -> list[MarimoErrorDetail]: """ Get all errors for a given cell. - - Args: - session_id: The session ID of the notebook. - cell_id: The ID of the cell. - maybe_cell_op: The cell operation. - include_stderr: Whether to include stderr errors. - - Returns: - A list of MarimoErrorDetails for the cell with STDERR errors if include_stderr is True. """ errors: list[MarimoErrorDetail] = [] cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id) - if ( - cell_op.output - and cell_op.output.channel == CellChannel.MARIMO_ERROR - ): - items = cell_op.output.data - - if not isinstance(items, list): - # no errors - return errors - - for err in items: - # TODO: filter out noisy useless errors - # like "An ancestor raised an exception..." - if isinstance(err, dict): - errors.append( - MarimoErrorDetail( - type=err.get("type", "UnknownError"), - message=err.get("msg", str(err)), - traceback=err.get("traceback", []), - ) - ) - else: - # Fallback for rich error objects - err_type: str = getattr(err, "type", type(err).__name__) - describe_fn: Optional[Any] = getattr(err, "describe", None) - message_val = ( - describe_fn() if callable(describe_fn) else str(err) - ) - message: str = str(message_val) - tb: list[str] = getattr(err, "traceback", []) or [] - errors.append( - MarimoErrorDetail( - type=err_type, - message=message, - traceback=tb, - ) - ) + if not cell_op.output or cell_op.output.channel != CellChannel.MARIMO_ERROR: + return errors - if cell_op.console and include_stderr: - console_outputs = ( - cell_op.console - if isinstance(cell_op.console, list) - else [cell_op.console] - ) - stderr_messages: list[str] = [] - for console in console_outputs: - if console.channel == CellChannel.STDERR: - stderr_messages.append(str(console.data)) - cleaned_stderr_messages = clean_output(stderr_messages) - errors.extend( - [ + items = cell_op.output.data + + if not isinstance(items, list): + # no errors + return errors + + for err in items: + # TODO: filter out noisy useless errors + # like "An ancestor raised an exception..." + if isinstance(err, dict): + errors.append( MarimoErrorDetail( - type="STDERR", + type=err.get("type", "UnknownError"), + message=err.get("msg", str(err)), + traceback=err.get("traceback", []), + ) + ) + else: + # Fallback for rich error objects + err_type: str = getattr(err, "type", type(err).__name__) + describe_fn: Optional[Any] = getattr(err, "describe", None) + message_val = ( + describe_fn() if callable(describe_fn) else str(err) + ) + message: str = str(message_val) + tb: list[str] = getattr(err, "traceback", []) or [] + errors.append( + MarimoErrorDetail( + type=err_type, message=message, - traceback=[], + traceback=tb, ) - for message in cleaned_stderr_messages - ] - ) + ) return errors + def get_cell_console_outputs( + self, cell_op: CellOp + ) -> MarimoCellConsoleOutputs: + """ + Get the console outputs for a given cell operation. + """ + stdout_messages: list[str] = [] + stderr_messages: list[str] = [] + + if cell_op.console is None: + return MarimoCellConsoleOutputs(stdout=[], stderr=[]) + + console_outputs = ( + cell_op.console + if isinstance(cell_op.console, list) + else [cell_op.console] + ) + for output in console_outputs: + if output is None: + continue + elif output.channel == CellChannel.STDOUT: + stdout_messages.append(str(output.data)) + elif output.channel == CellChannel.STDERR: + stderr_messages.append(str(output.data)) + + cleaned_stdout_messages = clean_output(stdout_messages) + cleaned_stderr_messages = clean_output(stderr_messages) + + return MarimoCellConsoleOutputs( + stdout=cleaned_stdout_messages, + stderr=cleaned_stderr_messages + ) + class ToolBase(Generic[ArgsT, OutT], ABC): """ diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index 9b35c83e9d6..454eb7b815e 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -10,11 +10,10 @@ MarimoErrorDetail, SuccessResult, ToolGuidelines, + MarimoCellConsoleOutputs, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError -from marimo._ai._tools.utils.output_cleaning import clean_output from marimo._ast.models import CellData -from marimo._messaging.cell_output import CellChannel from marimo._messaging.errors import Error from marimo._messaging.ops import CellOp, VariableValue from marimo._types.ids import CellId_t, SessionId @@ -93,13 +92,11 @@ class GetCellRuntimeDataOutput(SuccessResult): @dataclass -class CellOutputData: - """Visual and console output from a cell execution.""" +class CellVisualOutput: + """Visual from a cell execution.""" visual_output: Optional[str] = None visual_mimetype: Optional[str] = None - stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) @dataclass @@ -110,7 +107,8 @@ class GetCellOutputArgs: @dataclass class GetCellOutputOutput(SuccessResult): - data: CellOutputData = field(default_factory=CellOutputData) + visual_output: CellVisualOutput = field(default_factory=CellVisualOutput) + console_outputs: MarimoCellConsoleOutputs = field(default_factory=MarimoCellConsoleOutputs) class GetLightweightCellMap( @@ -259,7 +257,7 @@ def handle(self, args: GetCellRuntimeDataArgs) -> GetCellRuntimeDataOutput: # Get cell errors from session view with actual error details cell_errors = context.get_cell_errors( - session_id, cell_id, include_stderr=True + session_id, cell_id ) # Get cell runtime metadata @@ -346,16 +344,12 @@ def _get_cell_variables( class GetCellOutputs(ToolBase[GetCellOutputArgs, GetCellOutputOutput]): """Get cell execution output including visual display and console streams. - Returns comprehensive output data for a single cell: - - Visual output (HTML, charts, tables, etc.) with mimetype - - Console stdout and stderr messages - Args: session_id: The session ID of the notebook from get_active_notebooks cell_id: The specific cell ID from get_lightweight_cell_map Returns: - A success result containing all output data from the cell execution. + Visual output (HTML, charts, tables, etc.) with mimetype and console streams (stdout/stderr). """ guidelines = ToolGuidelines( @@ -370,23 +364,29 @@ class GetCellOutputs(ToolBase[GetCellOutputArgs, GetCellOutputOutput]): ) def handle(self, args: GetCellOutputArgs) -> GetCellOutputOutput: - session = self.context.get_session(args.session_id) + context = self.context + session = context.get_session(args.session_id) session_view = session.session_view cell_id = args.cell_id - maybe_cell_op = session_view.cell_operations.get(cell_id) + cell_op = session_view.cell_operations.get(cell_id) - visual_output, visual_mimetype = self._get_visual_output(maybe_cell_op) - stdout_messages, stderr_messages = self._get_console_outputs( - maybe_cell_op - ) + if cell_op is None: + raise ToolExecutionError( + f"Cell {cell_id} not found in session {args.session_id}", + code="CELL_NOT_FOUND", + is_retryable=False, + suggested_fix="Use get_lightweight_cell_map to find valid cell IDs", + ) + + visual_output, visual_mimetype = self._get_visual_output(cell_op) + console_outputs = context.get_cell_console_outputs(cell_op) return GetCellOutputOutput( - data=CellOutputData( + visual_output=CellVisualOutput( visual_output=visual_output, visual_mimetype=visual_mimetype, - stdout=stdout_messages, - stderr=stderr_messages, ), + console_outputs=console_outputs, next_steps=[ "Review visual_output to see what was displayed to the user", "Check stdout/stderr for print statements and warnings", @@ -394,14 +394,14 @@ def handle(self, args: GetCellOutputArgs) -> GetCellOutputOutput: ) def _get_visual_output( - self, maybe_cell_op: Optional[CellOp] + self, cell_op: CellOp ) -> tuple[Optional[str], Optional[str]]: visual_output = None visual_mimetype = None - if maybe_cell_op and maybe_cell_op.output: - data = maybe_cell_op.output.data + if cell_op.output: + data = cell_op.output.data visual_output = self._get_str_output_data(data) - visual_mimetype = maybe_cell_op.output.mimetype + visual_mimetype = cell_op.output.mimetype return visual_output, visual_mimetype def _get_str_output_data( @@ -411,29 +411,3 @@ def _get_str_output_data( return data else: return str(data) - - def _get_console_outputs( - self, maybe_cell_op: Optional[CellOp] - ) -> tuple[list[str], list[str]]: - stdout_messages: list[str] = [] - stderr_messages: list[str] = [] - if maybe_cell_op is None or maybe_cell_op.console is None: - return stdout_messages, stderr_messages - - console_outputs = ( - maybe_cell_op.console - if isinstance(maybe_cell_op.console, list) - else [maybe_cell_op.console] - ) - for output in console_outputs: - if output is None: - continue - elif output.channel == CellChannel.STDOUT: - stdout_messages.append(str(output.data)) - elif output.channel == CellChannel.STDERR: - stderr_messages.append(str(output.data)) - - cleaned_stdout_messages = clean_output(stdout_messages) - cleaned_stderr_messages = clean_output(stderr_messages) - - return cleaned_stdout_messages, cleaned_stderr_messages diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py index a4e8e34b377..f7183366e12 100644 --- a/marimo/_ai/_tools/tools/errors.py +++ b/marimo/_ai/_tools/tools/errors.py @@ -51,19 +51,19 @@ class GetNotebookErrors( def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: context = self.context session_id = args.session_id - summaries = context.get_notebook_errors( + notebook_errors = context.get_notebook_errors( session_id, include_stderr=True ) - total_errors = self._get_total_errors_count_without_stderr(summaries) - total_cells_with_errors = len(summaries) + total_errors = sum(len(c.errors) for c in notebook_errors) + total_cells_with_errors = len(notebook_errors) has_errors = total_errors > 0 return GetNotebookErrorsOutput( has_errors=has_errors, total_errors=total_errors, total_cells_with_errors=total_cells_with_errors, - cells=summaries, + cells=notebook_errors, next_steps=( [ "Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues", @@ -73,10 +73,3 @@ def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: else ["No errors detected"] ), ) - - def _get_total_errors_count_without_stderr( - self, summaries: list[MarimoCellErrors] - ) -> int: - return sum( - len([e for e in s.errors if e.type != "STDERR"]) for s in summaries - ) diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 7202adf9d95..6af2331cec1 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -47,6 +47,7 @@ class MarimoNotebookInfo: class MarimoCellErrors: cell_id: CellId_t errors: list[MarimoErrorDetail] = field(default_factory=list) + stderr: list[str] = field(default_factory=list) @dataclass @@ -54,3 +55,8 @@ class MarimoErrorDetail: type: str message: str traceback: list[str] + +@dataclass +class MarimoCellConsoleOutputs: + stdout: list[str] = field(default_factory=list) + stderr: list[str] = field(default_factory=list) \ No newline at end of file diff --git a/marimo/_mcp/server/_prompts/prompts/errors.py b/marimo/_mcp/server/_prompts/prompts/errors.py index cd48de7a467..2166b804139 100644 --- a/marimo/_mcp/server/_prompts/prompts/errors.py +++ b/marimo/_mcp/server/_prompts/prompts/errors.py @@ -40,7 +40,7 @@ def handle(self) -> list[PromptMessage]: for notebook in notebooks: session_id = notebook.session_id - notebook_errors = context.get_notebook_errors(session_id) + notebook_errors = context.get_notebook_errors(session_id, include_stderr=False) if len(notebook_errors) == 0: continue From fabc20f1d24f07125f75a7003153c6e301a7583d Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 03:04:21 +0100 Subject: [PATCH 4/8] improve comment --- marimo/_ai/_tools/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index 7b64d8aec5a..1c8cb64b6a9 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -139,6 +139,8 @@ def get_notebook_errors( ) -> list[MarimoCellErrors]: """ Get all errors in the current notebook session, organized by cell. + + Optionally include stderr messages foreach cell. """ session = self.get_session(session_id) session_view = session.session_view From 36d68119d3e982df5be23c9ed7f4cebf2c96c802 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 03:05:50 +0100 Subject: [PATCH 5/8] update tests to reflect changes --- tests/_ai/tools/test_base.py | 91 +++++++++++++++++++++++ tests/_ai/tools/tools/test_cells.py | 22 ------ tests/_ai/tools/tools/test_errors_tool.py | 78 ++----------------- 3 files changed, 99 insertions(+), 92 deletions(-) diff --git a/tests/_ai/tools/test_base.py b/tests/_ai/tools/test_base.py index 60edc7b7c7b..4857ea9d1f1 100644 --- a/tests/_ai/tools/test_base.py +++ b/tests/_ai/tools/test_base.py @@ -8,6 +8,7 @@ from marimo._ai._tools.base import ToolBase, ToolContext from marimo._ai._tools.utils.exceptions import ToolExecutionError +from marimo._messaging import msgspec_encoder @dataclass @@ -125,3 +126,93 @@ def test_as_backend_tool() -> None: is_valid, msg = validator({"invalid": "field"}) assert is_valid is False assert "Invalid arguments" in msg + +# test ToolContext methods + +def test_get_notebook_errors_orders_by_cell_manager(): + """Test errors follow cell_manager order, not alphabetical.""" + from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel + from marimo._types.ids import CellId_t, SessionId + + context = ToolContext() + + # Mock error cell_op + error_op = Mock() + error_op.output = Mock() + error_op.output.channel = CellChannel.MARIMO_ERROR + error_op.output.data = [{"type": "Error", "msg": "test", "traceback": []}] + error_op.console = None + + # Mock session with cells c1, c2, c3 + session = Mock() + session.session_view.cell_operations = { + CellId_t("c1"): error_op, + CellId_t("c2"): error_op, + CellId_t("c3"): error_op, + } + + # Cell manager returns in order: c3, c2, c1 (not alphabetical) + cell_data = [Mock(cell_id=CellId_t("c3")), Mock(cell_id=CellId_t("c2")), Mock(cell_id=CellId_t("c1"))] + session.app_file_manager.app.cell_manager.cell_data.return_value = cell_data + + context.get_session = Mock(return_value=session) + + errors = context.get_notebook_errors(SessionId("test"), include_stderr=False) + + # Should be c3, c2, c1 (not c1, c2, c3) + assert errors[0].cell_id == CellId_t("c3") + assert errors[1].cell_id == CellId_t("c2") + assert errors[2].cell_id == CellId_t("c1") + + +def test_get_cell_errors_extracts_from_output(): + """Test get_cell_errors extracts error details from cell output.""" + from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel + from marimo._types.ids import CellId_t, SessionId + + context = ToolContext() + + # Mock cell_op with error + cell_op = Mock() + cell_op.output = Mock() + cell_op.output.channel = CellChannel.MARIMO_ERROR + cell_op.output.data = [ + {"type": "ValueError", "msg": "bad value", "traceback": ["line 1"]} + ] + + errors = context.get_cell_errors(SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op) + + assert len(errors) == 1 + assert errors[0].type == "ValueError" + assert errors[0].message == "bad value" + assert errors[0].traceback == ["line 1"] + + +def test_get_cell_console_outputs_separates_stdout_stderr(): + """Test get_cell_console_outputs separates stdout and stderr.""" + from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel + + context = ToolContext() + + # Mock cell_op with stdout and stderr + stdout_output = Mock() + stdout_output.channel = CellChannel.STDOUT + stdout_output.data = "hello" + + stderr_output = Mock() + stderr_output.channel = CellChannel.STDERR + stderr_output.data = "warning" + + cell_op = Mock() + cell_op.console = [stdout_output, stderr_output] + + result = context.get_cell_console_outputs(cell_op) + + assert len(result.stdout) == 1 + assert "hello" in result.stdout[0] + assert len(result.stderr) == 1 + assert "warning" in result.stderr[0] + diff --git a/tests/_ai/tools/tools/test_cells.py b/tests/_ai/tools/tools/test_cells.py index da3a4cef895..d15579d1773 100644 --- a/tests/_ai/tools/tools/test_cells.py +++ b/tests/_ai/tools/tools/test_cells.py @@ -190,25 +190,3 @@ def test_get_visual_output_no_output(): visual_output, mimetype = tool._get_visual_output(cell_op) # type: ignore[arg-type] assert visual_output is None assert mimetype is None - - -def test_get_console_outputs_with_stdout_stderr(): - tool = GetCellOutputs(ToolContext()) - console = [ - MockConsoleOutput(CellChannel.STDOUT, "hello"), - MockConsoleOutput(CellChannel.STDERR, "warning"), - ] - cell_op = MockCellOp(console=console) - - stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type] - assert stdout == ["hello"] - assert stderr == ["warning"] - - -def test_get_console_outputs_no_console(): - tool = GetCellOutputs(ToolContext()) - cell_op = MockCellOp(console=None) - - stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type] - assert stdout == [] - assert stderr == [] diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 7bda02577d5..2b8684a8110 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -119,65 +119,6 @@ def test_get_notebook_errors_marimo_error_only(mock_context: Mock) -> None: assert "get_cell_runtime_data" in result.next_steps[0] -def test_get_notebook_errors_stderr_only(mock_context: Mock) -> None: - """Test get_notebook_errors with STDERR only.""" - stderr_errors = [ - MarimoCellErrors( - cell_id=CellId_t("c2"), - errors=[ - MarimoErrorDetail( - type="STDERR", - message="warning message", - traceback=[], - ) - ], - ) - ] - mock_context.get_notebook_errors.return_value = stderr_errors - - tool = GetNotebookErrors(ToolContext()) - tool.context = mock_context - - result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - - assert result.has_errors is False - assert result.total_errors == 0 - assert result.total_cells_with_errors == 1 - assert result.cells[0].errors[0].type == "STDERR" - - -def test_get_notebook_errors_mixed_errors(mock_context: Mock) -> None: - """Test get_notebook_errors with both MARIMO_ERROR and STDERR.""" - mixed_errors = [ - MarimoCellErrors( - cell_id=CellId_t("c1"), - errors=[ - MarimoErrorDetail( - type="ValueError", - message="bad value", - traceback=["line 1"], - ), - MarimoErrorDetail( - type="STDERR", - message="warn", - traceback=[], - ), - ], - ) - ] - mock_context.get_notebook_errors.return_value = mixed_errors - - tool = GetNotebookErrors(ToolContext()) - tool.context = mock_context - - result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - - assert result.has_errors is True - assert result.total_errors == 1 - assert result.total_cells_with_errors == 1 - assert len(result.cells[0].errors) == 2 - - def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: """Test get_notebook_errors with errors in multiple cells.""" multiple_errors = [ @@ -190,6 +131,7 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: traceback=[], ) ], + stderr=[], ), MarimoCellErrors( cell_id=CellId_t("c2"), @@ -198,18 +140,14 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: type="TypeError", message="error in c2", traceback=[], - ) - ], - ), - MarimoCellErrors( - cell_id=CellId_t("c3"), - errors=[ + ), MarimoErrorDetail( - type="STDERR", - message="stderr in c3", + type="ValueError", + message="error in c2", traceback=[], ) ], + stderr=[], ), ] mock_context.get_notebook_errors.return_value = multiple_errors @@ -220,9 +158,9 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) assert result.has_errors is True - assert result.total_errors == 2 - assert result.total_cells_with_errors == 3 - assert len(result.cells) == 3 + assert result.total_errors == 3 + assert result.total_cells_with_errors == 2 + assert len(result.cells) == 2 def test_get_notebook_errors_respects_session_id(mock_context: Mock) -> None: From fd6daebff5953abbac0a7ab2c03e56a3a0d84225 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 03:06:43 +0100 Subject: [PATCH 6/8] update prompt tests to check for essential data in prompts only --- .../server/prompts/test_errors_prompts.py | 67 +++++++++++++++ .../server/prompts/test_notebooks_prompts.py | 84 ++++++------------- 2 files changed, 91 insertions(+), 60 deletions(-) create mode 100644 tests/_mcp/server/prompts/test_errors_prompts.py diff --git a/tests/_mcp/server/prompts/test_errors_prompts.py b/tests/_mcp/server/prompts/test_errors_prompts.py new file mode 100644 index 00000000000..b595a4ff37b --- /dev/null +++ b/tests/_mcp/server/prompts/test_errors_prompts.py @@ -0,0 +1,67 @@ +import pytest + +pytest.importorskip("mcp", reason="MCP requires Python 3.10+") + +from unittest.mock import Mock + +from marimo._ai._tools.types import ( + MarimoCellErrors, + MarimoErrorDetail, + MarimoNotebookInfo, +) +from marimo._mcp.server._prompts.prompts.errors import ErrorsSummary +from marimo._types.ids import CellId_t, SessionId + + +def test_errors_summary_includes_essential_data(): + """Test that essential data (notebook info, cell IDs, error messages) are included.""" + context = Mock() + context.get_active_sessions_internal.return_value = [ + MarimoNotebookInfo( + name="notebook.py", + path="/path/to/notebook.py", + session_id=SessionId("session_1"), + ) + ] + context.get_notebook_errors.return_value = [ + MarimoCellErrors( + cell_id=CellId_t("cell_1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="invalid value", + traceback=["line 1"], + ) + ], + ), + MarimoCellErrors( + cell_id=CellId_t("cell_2"), + errors=[ + MarimoErrorDetail( + type="TypeError", + message="wrong type", + traceback=["line 2"], + ) + ], + ), + ] + + prompt = ErrorsSummary(context=context) + messages = prompt.handle() + text = "\n".join( + msg.content.text # type: ignore[attr-defined] + for msg in messages + if hasattr(msg.content, "text") + ) + + # Essential data must be present + assert "notebook.py" in text + assert "session_1" in text + assert "/path/to/notebook.py" in text + assert "cell_1" in text + assert "cell_2" in text + assert "ValueError" in text + assert "invalid value" in text + assert "TypeError" in text + assert "wrong type" in text + diff --git a/tests/_mcp/server/prompts/test_notebooks_prompts.py b/tests/_mcp/server/prompts/test_notebooks_prompts.py index 45f044d01de..212265eb98f 100644 --- a/tests/_mcp/server/prompts/test_notebooks_prompts.py +++ b/tests/_mcp/server/prompts/test_notebooks_prompts.py @@ -4,73 +4,37 @@ from unittest.mock import Mock +from marimo._ai._tools.types import MarimoNotebookInfo from marimo._mcp.server._prompts.prompts.notebooks import ActiveNotebooks +from marimo._types.ids import SessionId -def test_active_notebooks_metadata(): - """Test that name and description are properly set.""" - prompt = ActiveNotebooks(context=Mock()) - assert prompt.name == "active_notebooks" - assert ( - prompt.description - == "Get current active notebooks and their session IDs and file paths." - ) - - -def test_active_notebooks_no_sessions(): - """Test output when no sessions are active.""" +def test_active_notebooks_includes_session_ids_and_paths(): + """Test that essential data (session IDs and file paths) are included.""" context = Mock() - context.get_active_sessions_internal.return_value = [] - - prompt = ActiveNotebooks(context=context) - messages = prompt.handle() - - assert len(messages) == 1 - assert messages[0].role == "user" - assert messages[0].content.type == "text" - assert ( - "No active marimo notebook sessions found" in messages[0].content.text - ) - - -def test_active_notebooks_with_sessions(): - """Test output with active sessions.""" - context = Mock() - - # Mock active session objects - active_session1 = Mock() - active_session1.session_id = "session_1" - active_session1.path = "/path/to/notebook.py" - - active_session2 = Mock() - active_session2.session_id = "session_2" - active_session2.path = None - context.get_active_sessions_internal.return_value = [ - active_session1, - active_session2, + MarimoNotebookInfo( + name="notebook.py", + path="/path/to/notebook.py", + session_id=SessionId("session_1"), + ), + MarimoNotebookInfo( + name="other.py", + path="/other/path.py", + session_id=SessionId("session_2"), + ), ] prompt = ActiveNotebooks(context=context) messages = prompt.handle() - - assert len(messages) == 3 # 2 sessions + 1 action message - - # Check first message (with file path) - assert messages[0].role == "user" - assert messages[0].content.type == "text" - assert "session_1" in messages[0].content.text - assert "/path/to/notebook.py" in messages[0].content.text - - # Check second message (without file path) - assert messages[1].role == "user" - assert messages[1].content.type == "text" - assert "session_2" in messages[1].content.text - - # Check action message - assert messages[2].role == "user" - assert messages[2].content.type == "text" - assert ( - "Use these session_ids when calling marimo MCP tools" - in messages[2].content.text + text = "\n".join( + msg.content.text # type: ignore[attr-defined] + for msg in messages + if hasattr(msg.content, "text") ) + + # Essential data must be present + assert "session_1" in text + assert "session_2" in text + assert "/path/to/notebook.py" in text + assert "/other/path.py" in text From 8dcb1d25e0714dae36115cf71aa7e6eeddce2f6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Oct 2025 02:08:01 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- marimo/_ai/_tools/base.py | 16 +++--- marimo/_ai/_tools/tools/cells.py | 10 ++-- marimo/_ai/_tools/types.py | 3 +- marimo/_mcp/server/_prompts/prompts/errors.py | 4 +- tests/_ai/tools/test_base.py | 56 ++++++++++++------- tests/_ai/tools/tools/test_errors_tool.py | 2 +- .../server/prompts/test_errors_prompts.py | 1 - 7 files changed, 54 insertions(+), 38 deletions(-) diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index 1c8cb64b6a9..8491412cac8 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -18,10 +18,10 @@ from marimo import _loggers from marimo._ai._tools.types import ( + MarimoCellConsoleOutputs, MarimoCellErrors, MarimoErrorDetail, MarimoNotebookInfo, - MarimoCellConsoleOutputs, ToolGuidelines, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError @@ -168,9 +168,7 @@ def get_notebook_errors( for cell_data in cell_manager.cell_data(): cell_id = cell_data.cell_id if cell_id in cell_errors_map: - notebook_errors.append( - cell_errors_map[cell_id] - ) + notebook_errors.append(cell_errors_map[cell_id]) return notebook_errors @@ -186,7 +184,10 @@ def get_cell_errors( errors: list[MarimoErrorDetail] = [] cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id) - if not cell_op.output or cell_op.output.channel != CellChannel.MARIMO_ERROR: + if ( + not cell_op.output + or cell_op.output.channel != CellChannel.MARIMO_ERROR + ): return errors items = cell_op.output.data @@ -233,7 +234,7 @@ def get_cell_console_outputs( """ stdout_messages: list[str] = [] stderr_messages: list[str] = [] - + if cell_op.console is None: return MarimoCellConsoleOutputs(stdout=[], stderr=[]) @@ -254,8 +255,7 @@ def get_cell_console_outputs( cleaned_stderr_messages = clean_output(stderr_messages) return MarimoCellConsoleOutputs( - stdout=cleaned_stdout_messages, - stderr=cleaned_stderr_messages + stdout=cleaned_stdout_messages, stderr=cleaned_stderr_messages ) diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index 454eb7b815e..d12097eb87e 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -7,10 +7,10 @@ from marimo._ai._tools.base import ToolBase from marimo._ai._tools.types import ( + MarimoCellConsoleOutputs, MarimoErrorDetail, SuccessResult, ToolGuidelines, - MarimoCellConsoleOutputs, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._ast.models import CellData @@ -108,7 +108,9 @@ class GetCellOutputArgs: @dataclass class GetCellOutputOutput(SuccessResult): visual_output: CellVisualOutput = field(default_factory=CellVisualOutput) - console_outputs: MarimoCellConsoleOutputs = field(default_factory=MarimoCellConsoleOutputs) + console_outputs: MarimoCellConsoleOutputs = field( + default_factory=MarimoCellConsoleOutputs + ) class GetLightweightCellMap( @@ -256,9 +258,7 @@ def handle(self, args: GetCellRuntimeDataArgs) -> GetCellRuntimeDataOutput: cell_code = cell_data.code # Get cell errors from session view with actual error details - cell_errors = context.get_cell_errors( - session_id, cell_id - ) + cell_errors = context.get_cell_errors(session_id, cell_id) # Get cell runtime metadata cell_metadata = self._get_cell_metadata(session, cell_id) diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 6af2331cec1..9b3e68b0c97 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -56,7 +56,8 @@ class MarimoErrorDetail: message: str traceback: list[str] + @dataclass class MarimoCellConsoleOutputs: stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) \ No newline at end of file + stderr: list[str] = field(default_factory=list) diff --git a/marimo/_mcp/server/_prompts/prompts/errors.py b/marimo/_mcp/server/_prompts/prompts/errors.py index 2166b804139..d59e452bcef 100644 --- a/marimo/_mcp/server/_prompts/prompts/errors.py +++ b/marimo/_mcp/server/_prompts/prompts/errors.py @@ -40,7 +40,9 @@ def handle(self) -> list[PromptMessage]: for notebook in notebooks: session_id = notebook.session_id - notebook_errors = context.get_notebook_errors(session_id, include_stderr=False) + notebook_errors = context.get_notebook_errors( + session_id, include_stderr=False + ) if len(notebook_errors) == 0: continue diff --git a/tests/_ai/tools/test_base.py b/tests/_ai/tools/test_base.py index 4857ea9d1f1..9c877ce725e 100644 --- a/tests/_ai/tools/test_base.py +++ b/tests/_ai/tools/test_base.py @@ -127,23 +127,26 @@ def test_as_backend_tool() -> None: assert is_valid is False assert "Invalid arguments" in msg + # test ToolContext methods + def test_get_notebook_errors_orders_by_cell_manager(): """Test errors follow cell_manager order, not alphabetical.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel from marimo._types.ids import CellId_t, SessionId context = ToolContext() - + # Mock error cell_op error_op = Mock() error_op.output = Mock() error_op.output.channel = CellChannel.MARIMO_ERROR error_op.output.data = [{"type": "Error", "msg": "test", "traceback": []}] error_op.console = None - + # Mock session with cells c1, c2, c3 session = Mock() session.session_view.cell_operations = { @@ -151,15 +154,23 @@ def test_get_notebook_errors_orders_by_cell_manager(): CellId_t("c2"): error_op, CellId_t("c3"): error_op, } - + # Cell manager returns in order: c3, c2, c1 (not alphabetical) - cell_data = [Mock(cell_id=CellId_t("c3")), Mock(cell_id=CellId_t("c2")), Mock(cell_id=CellId_t("c1"))] - session.app_file_manager.app.cell_manager.cell_data.return_value = cell_data - + cell_data = [ + Mock(cell_id=CellId_t("c3")), + Mock(cell_id=CellId_t("c2")), + Mock(cell_id=CellId_t("c1")), + ] + session.app_file_manager.app.cell_manager.cell_data.return_value = ( + cell_data + ) + context.get_session = Mock(return_value=session) - - errors = context.get_notebook_errors(SessionId("test"), include_stderr=False) - + + errors = context.get_notebook_errors( + SessionId("test"), include_stderr=False + ) + # Should be c3, c2, c1 (not c1, c2, c3) assert errors[0].cell_id == CellId_t("c3") assert errors[1].cell_id == CellId_t("c2") @@ -169,11 +180,12 @@ def test_get_notebook_errors_orders_by_cell_manager(): def test_get_cell_errors_extracts_from_output(): """Test get_cell_errors extracts error details from cell output.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel from marimo._types.ids import CellId_t, SessionId - + context = ToolContext() - + # Mock cell_op with error cell_op = Mock() cell_op.output = Mock() @@ -181,9 +193,11 @@ def test_get_cell_errors_extracts_from_output(): cell_op.output.data = [ {"type": "ValueError", "msg": "bad value", "traceback": ["line 1"]} ] - - errors = context.get_cell_errors(SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op) - + + errors = context.get_cell_errors( + SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op + ) + assert len(errors) == 1 assert errors[0].type == "ValueError" assert errors[0].message == "bad value" @@ -193,26 +207,26 @@ def test_get_cell_errors_extracts_from_output(): def test_get_cell_console_outputs_separates_stdout_stderr(): """Test get_cell_console_outputs separates stdout and stderr.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel - + context = ToolContext() - + # Mock cell_op with stdout and stderr stdout_output = Mock() stdout_output.channel = CellChannel.STDOUT stdout_output.data = "hello" - + stderr_output = Mock() stderr_output.channel = CellChannel.STDERR stderr_output.data = "warning" - + cell_op = Mock() cell_op.console = [stdout_output, stderr_output] - + result = context.get_cell_console_outputs(cell_op) - + assert len(result.stdout) == 1 assert "hello" in result.stdout[0] assert len(result.stderr) == 1 assert "warning" in result.stderr[0] - diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 2b8684a8110..2bb8f6578fe 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -145,7 +145,7 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: type="ValueError", message="error in c2", traceback=[], - ) + ), ], stderr=[], ), diff --git a/tests/_mcp/server/prompts/test_errors_prompts.py b/tests/_mcp/server/prompts/test_errors_prompts.py index b595a4ff37b..622b3580129 100644 --- a/tests/_mcp/server/prompts/test_errors_prompts.py +++ b/tests/_mcp/server/prompts/test_errors_prompts.py @@ -64,4 +64,3 @@ def test_errors_summary_includes_essential_data(): assert "invalid value" in text assert "TypeError" in text assert "wrong type" in text - From dbd0ccb0a208da4e0918f2d12f57b94f5c6855c8 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Thu, 30 Oct 2025 03:16:24 +0100 Subject: [PATCH 8/8] ran make py-check --- marimo/_ai/_tools/base.py | 16 +++--- marimo/_ai/_tools/tools/cells.py | 10 ++-- marimo/_ai/_tools/types.py | 3 +- marimo/_mcp/server/_prompts/prompts/errors.py | 4 +- tests/_ai/tools/test_base.py | 57 ++++++++++++------- tests/_ai/tools/tools/test_cells.py | 1 - tests/_ai/tools/tools/test_errors_tool.py | 2 +- .../server/prompts/test_errors_prompts.py | 1 - 8 files changed, 54 insertions(+), 40 deletions(-) diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index 1c8cb64b6a9..8491412cac8 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -18,10 +18,10 @@ from marimo import _loggers from marimo._ai._tools.types import ( + MarimoCellConsoleOutputs, MarimoCellErrors, MarimoErrorDetail, MarimoNotebookInfo, - MarimoCellConsoleOutputs, ToolGuidelines, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError @@ -168,9 +168,7 @@ def get_notebook_errors( for cell_data in cell_manager.cell_data(): cell_id = cell_data.cell_id if cell_id in cell_errors_map: - notebook_errors.append( - cell_errors_map[cell_id] - ) + notebook_errors.append(cell_errors_map[cell_id]) return notebook_errors @@ -186,7 +184,10 @@ def get_cell_errors( errors: list[MarimoErrorDetail] = [] cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id) - if not cell_op.output or cell_op.output.channel != CellChannel.MARIMO_ERROR: + if ( + not cell_op.output + or cell_op.output.channel != CellChannel.MARIMO_ERROR + ): return errors items = cell_op.output.data @@ -233,7 +234,7 @@ def get_cell_console_outputs( """ stdout_messages: list[str] = [] stderr_messages: list[str] = [] - + if cell_op.console is None: return MarimoCellConsoleOutputs(stdout=[], stderr=[]) @@ -254,8 +255,7 @@ def get_cell_console_outputs( cleaned_stderr_messages = clean_output(stderr_messages) return MarimoCellConsoleOutputs( - stdout=cleaned_stdout_messages, - stderr=cleaned_stderr_messages + stdout=cleaned_stdout_messages, stderr=cleaned_stderr_messages ) diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index 454eb7b815e..d12097eb87e 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -7,10 +7,10 @@ from marimo._ai._tools.base import ToolBase from marimo._ai._tools.types import ( + MarimoCellConsoleOutputs, MarimoErrorDetail, SuccessResult, ToolGuidelines, - MarimoCellConsoleOutputs, ) from marimo._ai._tools.utils.exceptions import ToolExecutionError from marimo._ast.models import CellData @@ -108,7 +108,9 @@ class GetCellOutputArgs: @dataclass class GetCellOutputOutput(SuccessResult): visual_output: CellVisualOutput = field(default_factory=CellVisualOutput) - console_outputs: MarimoCellConsoleOutputs = field(default_factory=MarimoCellConsoleOutputs) + console_outputs: MarimoCellConsoleOutputs = field( + default_factory=MarimoCellConsoleOutputs + ) class GetLightweightCellMap( @@ -256,9 +258,7 @@ def handle(self, args: GetCellRuntimeDataArgs) -> GetCellRuntimeDataOutput: cell_code = cell_data.code # Get cell errors from session view with actual error details - cell_errors = context.get_cell_errors( - session_id, cell_id - ) + cell_errors = context.get_cell_errors(session_id, cell_id) # Get cell runtime metadata cell_metadata = self._get_cell_metadata(session, cell_id) diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 6af2331cec1..9b3e68b0c97 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -56,7 +56,8 @@ class MarimoErrorDetail: message: str traceback: list[str] + @dataclass class MarimoCellConsoleOutputs: stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) \ No newline at end of file + stderr: list[str] = field(default_factory=list) diff --git a/marimo/_mcp/server/_prompts/prompts/errors.py b/marimo/_mcp/server/_prompts/prompts/errors.py index 2166b804139..d59e452bcef 100644 --- a/marimo/_mcp/server/_prompts/prompts/errors.py +++ b/marimo/_mcp/server/_prompts/prompts/errors.py @@ -40,7 +40,9 @@ def handle(self) -> list[PromptMessage]: for notebook in notebooks: session_id = notebook.session_id - notebook_errors = context.get_notebook_errors(session_id, include_stderr=False) + notebook_errors = context.get_notebook_errors( + session_id, include_stderr=False + ) if len(notebook_errors) == 0: continue diff --git a/tests/_ai/tools/test_base.py b/tests/_ai/tools/test_base.py index 4857ea9d1f1..3afd7485343 100644 --- a/tests/_ai/tools/test_base.py +++ b/tests/_ai/tools/test_base.py @@ -8,7 +8,6 @@ from marimo._ai._tools.base import ToolBase, ToolContext from marimo._ai._tools.utils.exceptions import ToolExecutionError -from marimo._messaging import msgspec_encoder @dataclass @@ -127,23 +126,26 @@ def test_as_backend_tool() -> None: assert is_valid is False assert "Invalid arguments" in msg + # test ToolContext methods + def test_get_notebook_errors_orders_by_cell_manager(): """Test errors follow cell_manager order, not alphabetical.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel from marimo._types.ids import CellId_t, SessionId context = ToolContext() - + # Mock error cell_op error_op = Mock() error_op.output = Mock() error_op.output.channel = CellChannel.MARIMO_ERROR error_op.output.data = [{"type": "Error", "msg": "test", "traceback": []}] error_op.console = None - + # Mock session with cells c1, c2, c3 session = Mock() session.session_view.cell_operations = { @@ -151,15 +153,23 @@ def test_get_notebook_errors_orders_by_cell_manager(): CellId_t("c2"): error_op, CellId_t("c3"): error_op, } - + # Cell manager returns in order: c3, c2, c1 (not alphabetical) - cell_data = [Mock(cell_id=CellId_t("c3")), Mock(cell_id=CellId_t("c2")), Mock(cell_id=CellId_t("c1"))] - session.app_file_manager.app.cell_manager.cell_data.return_value = cell_data - + cell_data = [ + Mock(cell_id=CellId_t("c3")), + Mock(cell_id=CellId_t("c2")), + Mock(cell_id=CellId_t("c1")), + ] + session.app_file_manager.app.cell_manager.cell_data.return_value = ( + cell_data + ) + context.get_session = Mock(return_value=session) - - errors = context.get_notebook_errors(SessionId("test"), include_stderr=False) - + + errors = context.get_notebook_errors( + SessionId("test"), include_stderr=False + ) + # Should be c3, c2, c1 (not c1, c2, c3) assert errors[0].cell_id == CellId_t("c3") assert errors[1].cell_id == CellId_t("c2") @@ -169,11 +179,12 @@ def test_get_notebook_errors_orders_by_cell_manager(): def test_get_cell_errors_extracts_from_output(): """Test get_cell_errors extracts error details from cell output.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel from marimo._types.ids import CellId_t, SessionId - + context = ToolContext() - + # Mock cell_op with error cell_op = Mock() cell_op.output = Mock() @@ -181,9 +192,11 @@ def test_get_cell_errors_extracts_from_output(): cell_op.output.data = [ {"type": "ValueError", "msg": "bad value", "traceback": ["line 1"]} ] - - errors = context.get_cell_errors(SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op) - + + errors = context.get_cell_errors( + SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op + ) + assert len(errors) == 1 assert errors[0].type == "ValueError" assert errors[0].message == "bad value" @@ -193,26 +206,26 @@ def test_get_cell_errors_extracts_from_output(): def test_get_cell_console_outputs_separates_stdout_stderr(): """Test get_cell_console_outputs separates stdout and stderr.""" from unittest.mock import Mock + from marimo._messaging.cell_output import CellChannel - + context = ToolContext() - + # Mock cell_op with stdout and stderr stdout_output = Mock() stdout_output.channel = CellChannel.STDOUT stdout_output.data = "hello" - + stderr_output = Mock() stderr_output.channel = CellChannel.STDERR stderr_output.data = "warning" - + cell_op = Mock() cell_op.console = [stdout_output, stderr_output] - + result = context.get_cell_console_outputs(cell_op) - + assert len(result.stdout) == 1 assert "hello" in result.stdout[0] assert len(result.stderr) == 1 assert "warning" in result.stderr[0] - diff --git a/tests/_ai/tools/tools/test_cells.py b/tests/_ai/tools/tools/test_cells.py index d15579d1773..02137bc3a6c 100644 --- a/tests/_ai/tools/tools/test_cells.py +++ b/tests/_ai/tools/tools/test_cells.py @@ -13,7 +13,6 @@ GetCellRuntimeData, GetLightweightCellMap, ) -from marimo._messaging.cell_output import CellChannel from marimo._messaging.ops import VariableValue from marimo._server.sessions import Session from marimo._types.ids import CellId_t, SessionId diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 2b8684a8110..2bb8f6578fe 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -145,7 +145,7 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: type="ValueError", message="error in c2", traceback=[], - ) + ), ], stderr=[], ), diff --git a/tests/_mcp/server/prompts/test_errors_prompts.py b/tests/_mcp/server/prompts/test_errors_prompts.py index b595a4ff37b..622b3580129 100644 --- a/tests/_mcp/server/prompts/test_errors_prompts.py +++ b/tests/_mcp/server/prompts/test_errors_prompts.py @@ -64,4 +64,3 @@ def test_errors_summary_includes_essential_data(): assert "invalid value" in text assert "TypeError" in text assert "wrong type" in text -