diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py new file mode 100644 index 00000000000..820537d75fb --- /dev/null +++ b/marimo/_ai/_tools/tools/errors.py @@ -0,0 +1,145 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + +from marimo import _loggers +from marimo._ai._tools.base import ToolBase +from marimo._ai._tools.tools.cells import ErrorDetail +from marimo._ai._tools.types import SuccessResult +from marimo._server.sessions import Session +from marimo._types.ids import CellId_t, SessionId + +LOGGER = _loggers.marimo_logger() + + +@dataclass +class GetNotebookErrorsArgs: + session_id: SessionId + + +@dataclass +class CellErrorsSummary: + cell_id: CellId_t + errors: list[ErrorDetail] = field(default_factory=list) + + +@dataclass +class GetNotebookErrorsOutput(SuccessResult): + has_errors: bool = False + total_errors: int = 0 + cells: list[CellErrorsSummary] = field(default_factory=list) + + +class GetNotebookErrors( + ToolBase[GetNotebookErrorsArgs, GetNotebookErrorsOutput] +): + """ + Get all errors in the current notebook session, organized by cell. + + Args: + session_id: The session ID of the notebook. + + Returns: + A success result containing per-cell error details and totals. + """ + + def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: + session = self.context.get_session(args.session_id) + summaries = self._collect_errors(session) + + total_errors = sum(len(s.errors) for s in summaries) + has_errors = total_errors > 0 + + return GetNotebookErrorsOutput( + has_errors=has_errors, + total_errors=total_errors, + cells=summaries, + next_steps=( + [ + "Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues", + "Re-run the notebook after addressing the errors", + ] + if has_errors + else ["No errors detected"] + ), + ) + + # helpers + def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: + from marimo._messaging.cell_output import CellChannel + + session_view = session.session_view + + summaries: list[CellErrorsSummary] = [] + for cell_id, cell_op in session_view.cell_operations.items(): + errors: list[ErrorDetail] = [] + + # Collect structured marimo errors from output + if ( + cell_op.output + and cell_op.output.channel == CellChannel.MARIMO_ERROR + ): + items = cell_op.output.data + if isinstance(items, list): + for err in items: + if isinstance(err, dict): + errors.append( + ErrorDetail( + type=err.get("type", "UnknownError"), + message=err.get("msg", str(err)), + traceback=err.get("traceback", []), + ) + ) + else: + # Fallback for rich error objects + err_type: str = getattr( + err, "type", type(err).__name__ + ) + describe_fn: Optional[Any] = getattr( + err, "describe", None + ) + message_val = ( + describe_fn() + if callable(describe_fn) + else str(err) + ) + message: str = str(message_val) + tb: list[str] = getattr(err, "traceback", []) or [] + errors.append( + ErrorDetail( + type=err_type, + message=message, + traceback=tb, + ) + ) + + # Collect stderr messages as error details + if cell_op.console: + console_outputs = ( + cell_op.console + if isinstance(cell_op.console, list) + else [cell_op.console] + ) + for console in console_outputs: + if console.channel == CellChannel.STDERR: + errors.append( + ErrorDetail( + type="STDERR", + message=str(console.data), + traceback=[], + ) + ) + + if errors: + summaries.append( + CellErrorsSummary( + cell_id=cell_id, + errors=errors, + ) + ) + + # Sort by cell_id for stable output + summaries.sort(key=lambda s: s.cell_id) + return summaries diff --git a/marimo/_ai/_tools/tools_registry.py b/marimo/_ai/_tools/tools_registry.py index e023c718f03..8d6729a05d2 100644 --- a/marimo/_ai/_tools/tools_registry.py +++ b/marimo/_ai/_tools/tools_registry.py @@ -7,6 +7,7 @@ GetLightweightCellMap, ) from marimo._ai._tools.tools.datasource import GetDatabaseTables +from marimo._ai._tools.tools.errors import GetNotebookErrors from marimo._ai._tools.tools.notebooks import GetActiveNotebooks from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables @@ -16,4 +17,5 @@ GetLightweightCellMap, GetTablesAndVariables, GetDatabaseTables, + GetNotebookErrors, ] diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py new file mode 100644 index 00000000000..2a503b2430f --- /dev/null +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from dataclasses import dataclass +from unittest.mock import Mock + +from marimo._ai._tools.base import ToolContext +from marimo._ai._tools.tools.errors import ( + GetNotebookErrors, + GetNotebookErrorsArgs, +) +from marimo._messaging.cell_output import CellChannel +from marimo._types.ids import SessionId + + +@dataclass +class MockCellOp: + output: object | None = None + console: object | None = None + + +@dataclass +class MockOutput: + channel: object + data: object + + +@dataclass +class MockConsoleOutput: + channel: object + data: object + + +@dataclass +class MockSessionView: + cell_operations: dict | None = None + + def __post_init__(self) -> None: + if self.cell_operations is None: + self.cell_operations = {} + + +@dataclass +class DummyAppFileManager: + app: object + + +@dataclass +class MockSession: + session_view: MockSessionView + app_file_manager: DummyAppFileManager + + +def test_collect_errors_none() -> None: + tool = GetNotebookErrors(ToolContext()) + + # Empty session view + session = MockSession( + session_view=MockSessionView(), + app_file_manager=DummyAppFileManager(app=Mock()), + ) + + summaries = tool._collect_errors(session) # type: ignore[arg-type] + assert summaries == [] + + +def test_collect_errors_marimo_and_stderr() -> None: + tool = GetNotebookErrors(ToolContext()) + + # Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only + err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]} + c1 = MockCellOp( + output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]), + console=[MockConsoleOutput(CellChannel.STDERR, "warn")], + ) + c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")]) + + session = MockSession( + session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}), + app_file_manager=DummyAppFileManager(app=Mock()), + ) + + summaries = tool._collect_errors(session) # type: ignore[arg-type] + + # Sorted by cell_id: c1 then c2 + assert len(summaries) == 2 + assert summaries[0].cell_id == "c1" + assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR + assert summaries[0].errors[0].type == "ValueError" + assert summaries[1].cell_id == "c2" + assert len(summaries[1].errors) == 1 + assert summaries[1].errors[0].type == "STDERR" + + +def test_handle_integration_uses_context_get_session() -> None: + tool = GetNotebookErrors(ToolContext()) + + c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")]) + session = MockSession( + session_view=MockSessionView(cell_operations={"c1": c1}), + app_file_manager=DummyAppFileManager(app=Mock()), + ) + + # Mock ToolContext.get_session + context = Mock(spec=ToolContext) + context.get_session.return_value = session + tool.context = context # type: ignore[assignment] + + out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + assert out.has_errors is True + assert out.total_errors == 1 + assert len(out.cells) == 1 + assert out.cells[0].cell_id == "c1"