From 0d58d5ecad2f5a15893e86dfec3b60fd9141a9ae Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 19 Sep 2025 20:14:28 +0200 Subject: [PATCH 1/3] Add GetNotebookErrors tool with tests --- marimo/_ai/_tools/tools/errors.py | 152 +++++++++++++++++++++++++++ marimo/_ai/_tools/tools_registry.py | 2 + tests/_ai/tools/tools/test_errors.py | 134 +++++++++++++++++++++++ 3 files changed, 288 insertions(+) create mode 100644 marimo/_ai/_tools/tools/errors.py create mode 100644 tests/_ai/tools/tools/test_errors.py diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py new file mode 100644 index 00000000000..d1dd3ea27b4 --- /dev/null +++ b/marimo/_ai/_tools/tools/errors.py @@ -0,0 +1,152 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from marimo import _loggers +from marimo._ai._tools.base import ToolBase +from marimo._ai._tools.tools.cells import ErrorDetail +from marimo._ai._tools.types import SuccessResult +from marimo._server.sessions import Session +from marimo._types.ids import SessionId + +LOGGER = _loggers.marimo_logger() + + +@dataclass +class GetNotebookErrorsArgs: + session_id: SessionId + + +@dataclass +class CellErrorsSummary: + cell_id: str + cell_name: str + errors: list[ErrorDetail] = field(default_factory=list) + + +@dataclass +class GetNotebookErrorsOutput(SuccessResult): + has_errors: bool = False + total_errors: int = 0 + cells: list[CellErrorsSummary] = field(default_factory=list) + + +class GetNotebookErrors( + ToolBase[GetNotebookErrorsArgs, GetNotebookErrorsOutput] +): + """ + Get all errors in the current notebook session, organized by cell. + + Args: + session_id: The session ID of the notebook. + + Returns: + A success result containing per-cell error details and totals. + """ + + def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: + session = self.context.get_session(args.session_id) + summaries = self._collect_errors(session) + + total_errors = sum(len(s.errors) for s in summaries) + has_errors = total_errors > 0 + + return GetNotebookErrorsOutput( + has_errors=has_errors, + total_errors=total_errors, + cells=summaries, + next_steps=( + [ + "Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues", + "Re-run the notebook after addressing the errors", + ] + if has_errors + else ["No errors detected"] + ), + ) + + # helpers + def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: + from marimo._messaging.cell_output import CellChannel + + cell_manager = session.app_file_manager.app.cell_manager + session_view = session.session_view + + summaries: list[CellErrorsSummary] = [] + for cell_id, cell_op in session_view.cell_operations.items(): + # Resolve cell name when possible + cell_data = cell_manager.get_cell_data(cell_id) + cell_name = cell_data.name if cell_data else str(cell_id) + + errors: list[ErrorDetail] = [] + + # Collect structured marimo errors from output + if ( + cell_op.output + and cell_op.output.channel == CellChannel.MARIMO_ERROR + ): + items = cell_op.output.data + if isinstance(items, list): + for err in items: + if isinstance(err, dict): + errors.append( + ErrorDetail( + type=err.get("type", "UnknownError"), + message=err.get("msg", str(err)), + traceback=err.get("traceback", []), + ) + ) + else: + # Fallback for rich error objects + err_type: str = getattr( + err, "type", type(err).__name__ + ) + describe_fn: Optional[Any] = getattr( + err, "describe", None + ) + message_val = ( + describe_fn() + if callable(describe_fn) + else str(err) + ) + message: str = str(message_val) + tb: list[str] = getattr(err, "traceback", []) or [] + errors.append( + ErrorDetail( + type=err_type, + message=message, + traceback=tb, + ) + ) + + # Collect stderr messages as error details + if cell_op.console: + console_outputs = ( + cell_op.console + if isinstance(cell_op.console, list) + else [cell_op.console] + ) + for console in console_outputs: + if console.channel == CellChannel.STDERR: + errors.append( + ErrorDetail( + type="STDERR", + message=str(console.data), + traceback=[], + ) + ) + + if errors: + summaries.append( + CellErrorsSummary( + cell_id=str(cell_id), + cell_name=cell_name, + errors=errors, + ) + ) + + # Sort by cell_id for stable output + summaries.sort(key=lambda s: s.cell_id) + return summaries diff --git a/marimo/_ai/_tools/tools_registry.py b/marimo/_ai/_tools/tools_registry.py index e023c718f03..8d6729a05d2 100644 --- a/marimo/_ai/_tools/tools_registry.py +++ b/marimo/_ai/_tools/tools_registry.py @@ -7,6 +7,7 @@ GetLightweightCellMap, ) from marimo._ai._tools.tools.datasource import GetDatabaseTables +from marimo._ai._tools.tools.errors import GetNotebookErrors from marimo._ai._tools.tools.notebooks import GetActiveNotebooks from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables @@ -16,4 +17,5 @@ GetLightweightCellMap, GetTablesAndVariables, GetDatabaseTables, + GetNotebookErrors, ] diff --git a/tests/_ai/tools/tools/test_errors.py b/tests/_ai/tools/tools/test_errors.py new file mode 100644 index 00000000000..92b6146a9f8 --- /dev/null +++ b/tests/_ai/tools/tools/test_errors.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from dataclasses import dataclass +from unittest.mock import Mock + +from marimo._ai._tools.base import ToolContext +from marimo._ai._tools.tools.errors import ( + GetNotebookErrors, + GetNotebookErrorsArgs, +) +from marimo._messaging.cell_output import CellChannel +from marimo._types.ids import SessionId + + +@dataclass +class MockCellOp: + output: object | None = None + console: object | None = None + + +@dataclass +class MockOutput: + channel: object + data: object + + +@dataclass +class MockConsoleOutput: + channel: object + data: object + + +@dataclass +class MockSessionView: + cell_operations: dict | None = None + + def __post_init__(self) -> None: + if self.cell_operations is None: + self.cell_operations = {} + + +@dataclass +class DummyCellData: + name: str + + +class DummyCellManager: + def __init__(self, names: dict[str, str]) -> None: + self._names = names + + def get_cell_data(self, cell_id: str) -> DummyCellData: + return DummyCellData(self._names.get(cell_id, cell_id)) + + +@dataclass +class DummyAppFileManager: + app: object + + +@dataclass +class MockSession: + session_view: MockSessionView + app_file_manager: DummyAppFileManager + + +def test_collect_errors_none() -> None: + tool = GetNotebookErrors(ToolContext()) + + # Empty session view + session = MockSession( + session_view=MockSessionView(), + app_file_manager=DummyAppFileManager( + app=Mock(cell_manager=DummyCellManager({})) + ), + ) + + summaries = tool._collect_errors(session) # type: ignore[arg-type] + assert summaries == [] + + +def test_collect_errors_marimo_and_stderr() -> None: + tool = GetNotebookErrors(ToolContext()) + + # Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only + err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]} + c1 = MockCellOp( + output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]), + console=[MockConsoleOutput(CellChannel.STDERR, "warn")], + ) + c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")]) + + names = {"c1": "Cell 1", "c2": "Cell 2"} + session = MockSession( + session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}), + app_file_manager=DummyAppFileManager( + app=Mock(cell_manager=DummyCellManager(names)) + ), + ) + + summaries = tool._collect_errors(session) # type: ignore[arg-type] + + # Sorted by cell_id: c1 then c2 + assert len(summaries) == 2 + assert summaries[0].cell_id == "c1" + assert summaries[0].cell_name == "Cell 1" + assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR + assert summaries[0].errors[0].type == "ValueError" + assert summaries[1].cell_id == "c2" + assert summaries[1].cell_name == "Cell 2" + assert len(summaries[1].errors) == 1 + assert summaries[1].errors[0].type == "STDERR" + + +def test_handle_integration_uses_context_get_session() -> None: + tool = GetNotebookErrors(ToolContext()) + + c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")]) + session = MockSession( + session_view=MockSessionView(cell_operations={"c1": c1}), + app_file_manager=DummyAppFileManager( + app=Mock(cell_manager=DummyCellManager({"c1": "Cell 1"})) + ), + ) + + # Mock ToolContext.get_session + context = Mock(spec=ToolContext) + context.get_session.return_value = session + tool.context = context # type: ignore[assignment] + + out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + assert out.has_errors is True + assert out.total_errors == 1 + assert len(out.cells) == 1 + assert out.cells[0].cell_id == "c1" From 42dde78a2382f9cd70eb6ae5cca99554d87914eb Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 19 Sep 2025 20:28:42 +0200 Subject: [PATCH 2/3] Rename test_errors.py -> test_errors_tool.py to avoid naming conflict --- tests/_ai/tools/tools/{test_errors.py => test_errors_tool.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/_ai/tools/tools/{test_errors.py => test_errors_tool.py} (100%) diff --git a/tests/_ai/tools/tools/test_errors.py b/tests/_ai/tools/tools/test_errors_tool.py similarity index 100% rename from tests/_ai/tools/tools/test_errors.py rename to tests/_ai/tools/tools/test_errors_tool.py From 017aa5d8bdc23a86cd8ccf40443b45bc8da44120 Mon Sep 17 00:00:00 2001 From: Joaquin Coromina Date: Fri, 19 Sep 2025 21:34:11 +0200 Subject: [PATCH 3/3] Update cell_id type to CellId_t and removed cell_name --- marimo/_ai/_tools/tools/errors.py | 13 +++-------- tests/_ai/tools/tools/test_errors_tool.py | 28 +++-------------------- 2 files changed, 6 insertions(+), 35 deletions(-) diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py index d1dd3ea27b4..820537d75fb 100644 --- a/marimo/_ai/_tools/tools/errors.py +++ b/marimo/_ai/_tools/tools/errors.py @@ -9,7 +9,7 @@ from marimo._ai._tools.tools.cells import ErrorDetail from marimo._ai._tools.types import SuccessResult from marimo._server.sessions import Session -from marimo._types.ids import SessionId +from marimo._types.ids import CellId_t, SessionId LOGGER = _loggers.marimo_logger() @@ -21,8 +21,7 @@ class GetNotebookErrorsArgs: @dataclass class CellErrorsSummary: - cell_id: str - cell_name: str + cell_id: CellId_t errors: list[ErrorDetail] = field(default_factory=list) @@ -71,15 +70,10 @@ def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: from marimo._messaging.cell_output import CellChannel - cell_manager = session.app_file_manager.app.cell_manager session_view = session.session_view summaries: list[CellErrorsSummary] = [] for cell_id, cell_op in session_view.cell_operations.items(): - # Resolve cell name when possible - cell_data = cell_manager.get_cell_data(cell_id) - cell_name = cell_data.name if cell_data else str(cell_id) - errors: list[ErrorDetail] = [] # Collect structured marimo errors from output @@ -141,8 +135,7 @@ def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: if errors: summaries.append( CellErrorsSummary( - cell_id=str(cell_id), - cell_name=cell_name, + cell_id=cell_id, errors=errors, ) ) diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 92b6146a9f8..2a503b2430f 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -39,19 +39,6 @@ def __post_init__(self) -> None: self.cell_operations = {} -@dataclass -class DummyCellData: - name: str - - -class DummyCellManager: - def __init__(self, names: dict[str, str]) -> None: - self._names = names - - def get_cell_data(self, cell_id: str) -> DummyCellData: - return DummyCellData(self._names.get(cell_id, cell_id)) - - @dataclass class DummyAppFileManager: app: object @@ -69,9 +56,7 @@ def test_collect_errors_none() -> None: # Empty session view session = MockSession( session_view=MockSessionView(), - app_file_manager=DummyAppFileManager( - app=Mock(cell_manager=DummyCellManager({})) - ), + app_file_manager=DummyAppFileManager(app=Mock()), ) summaries = tool._collect_errors(session) # type: ignore[arg-type] @@ -89,12 +74,9 @@ def test_collect_errors_marimo_and_stderr() -> None: ) c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")]) - names = {"c1": "Cell 1", "c2": "Cell 2"} session = MockSession( session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}), - app_file_manager=DummyAppFileManager( - app=Mock(cell_manager=DummyCellManager(names)) - ), + app_file_manager=DummyAppFileManager(app=Mock()), ) summaries = tool._collect_errors(session) # type: ignore[arg-type] @@ -102,11 +84,9 @@ def test_collect_errors_marimo_and_stderr() -> None: # Sorted by cell_id: c1 then c2 assert len(summaries) == 2 assert summaries[0].cell_id == "c1" - assert summaries[0].cell_name == "Cell 1" assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR assert summaries[0].errors[0].type == "ValueError" assert summaries[1].cell_id == "c2" - assert summaries[1].cell_name == "Cell 2" assert len(summaries[1].errors) == 1 assert summaries[1].errors[0].type == "STDERR" @@ -117,9 +97,7 @@ def test_handle_integration_uses_context_get_session() -> None: c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")]) session = MockSession( session_view=MockSessionView(cell_operations={"c1": c1}), - app_file_manager=DummyAppFileManager( - app=Mock(cell_manager=DummyCellManager({"c1": "Cell 1"})) - ), + app_file_manager=DummyAppFileManager(app=Mock()), ) # Mock ToolContext.get_session