Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions marimo/_ai/_tools/tools/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional

from marimo import _loggers
from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.tools.cells import ErrorDetail
from marimo._ai._tools.types import SuccessResult
from marimo._server.sessions import Session
from marimo._types.ids import CellId_t, SessionId

LOGGER = _loggers.marimo_logger()


@dataclass
class GetNotebookErrorsArgs:
session_id: SessionId


@dataclass
class CellErrorsSummary:
cell_id: CellId_t
errors: list[ErrorDetail] = field(default_factory=list)


@dataclass
class GetNotebookErrorsOutput(SuccessResult):
has_errors: bool = False
total_errors: int = 0
cells: list[CellErrorsSummary] = field(default_factory=list)


class GetNotebookErrors(
ToolBase[GetNotebookErrorsArgs, GetNotebookErrorsOutput]
):
"""
Get all errors in the current notebook session, organized by cell.

Args:
session_id: The session ID of the notebook.

Returns:
A success result containing per-cell error details and totals.
"""

def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput:
session = self.context.get_session(args.session_id)
summaries = self._collect_errors(session)

total_errors = sum(len(s.errors) for s in summaries)
has_errors = total_errors > 0

return GetNotebookErrorsOutput(
has_errors=has_errors,
total_errors=total_errors,
cells=summaries,
next_steps=(
[
"Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues",
"Re-run the notebook after addressing the errors",
]
if has_errors
else ["No errors detected"]
),
)

# helpers
def _collect_errors(self, session: Session) -> list[CellErrorsSummary]:
from marimo._messaging.cell_output import CellChannel

session_view = session.session_view

summaries: list[CellErrorsSummary] = []
for cell_id, cell_op in session_view.cell_operations.items():
errors: list[ErrorDetail] = []

# Collect structured marimo errors from output
if (
cell_op.output
and cell_op.output.channel == CellChannel.MARIMO_ERROR
):
items = cell_op.output.data
if isinstance(items, list):
for err in items:
if isinstance(err, dict):
errors.append(
ErrorDetail(
type=err.get("type", "UnknownError"),
message=err.get("msg", str(err)),
traceback=err.get("traceback", []),
)
)
else:
# Fallback for rich error objects
err_type: str = getattr(
err, "type", type(err).__name__
)
describe_fn: Optional[Any] = getattr(
err, "describe", None
)
message_val = (
describe_fn()
if callable(describe_fn)
else str(err)
)
message: str = str(message_val)
tb: list[str] = getattr(err, "traceback", []) or []
errors.append(
ErrorDetail(
type=err_type,
message=message,
traceback=tb,
)
)

# Collect stderr messages as error details
if cell_op.console:
console_outputs = (
cell_op.console
if isinstance(cell_op.console, list)
else [cell_op.console]
)
for console in console_outputs:
if console.channel == CellChannel.STDERR:
errors.append(
ErrorDetail(
type="STDERR",
message=str(console.data),
traceback=[],
)
)

if errors:
summaries.append(
CellErrorsSummary(
cell_id=cell_id,
errors=errors,
)
)

# Sort by cell_id for stable output
summaries.sort(key=lambda s: s.cell_id)
return summaries
2 changes: 2 additions & 0 deletions marimo/_ai/_tools/tools_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
GetLightweightCellMap,
)
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.errors import GetNotebookErrors
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables

Expand All @@ -16,4 +17,5 @@
GetLightweightCellMap,
GetTablesAndVariables,
GetDatabaseTables,
GetNotebookErrors,
]
112 changes: 112 additions & 0 deletions tests/_ai/tools/tools/test_errors_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import annotations

from dataclasses import dataclass
from unittest.mock import Mock

from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools.errors import (
GetNotebookErrors,
GetNotebookErrorsArgs,
)
from marimo._messaging.cell_output import CellChannel
from marimo._types.ids import SessionId


@dataclass
class MockCellOp:
output: object | None = None
console: object | None = None


@dataclass
class MockOutput:
channel: object
data: object


@dataclass
class MockConsoleOutput:
channel: object
data: object


@dataclass
class MockSessionView:
cell_operations: dict | None = None

def __post_init__(self) -> None:
if self.cell_operations is None:
self.cell_operations = {}


@dataclass
class DummyAppFileManager:
app: object


@dataclass
class MockSession:
session_view: MockSessionView
app_file_manager: DummyAppFileManager


def test_collect_errors_none() -> None:
tool = GetNotebookErrors(ToolContext())

# Empty session view
session = MockSession(
session_view=MockSessionView(),
app_file_manager=DummyAppFileManager(app=Mock()),
)

summaries = tool._collect_errors(session) # type: ignore[arg-type]
assert summaries == []


def test_collect_errors_marimo_and_stderr() -> None:
tool = GetNotebookErrors(ToolContext())

# Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only
err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]}
c1 = MockCellOp(
output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]),
console=[MockConsoleOutput(CellChannel.STDERR, "warn")],
)
c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")])

session = MockSession(
session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}),
app_file_manager=DummyAppFileManager(app=Mock()),
)

summaries = tool._collect_errors(session) # type: ignore[arg-type]

# Sorted by cell_id: c1 then c2
assert len(summaries) == 2
assert summaries[0].cell_id == "c1"
assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR
assert summaries[0].errors[0].type == "ValueError"
assert summaries[1].cell_id == "c2"
assert len(summaries[1].errors) == 1
assert summaries[1].errors[0].type == "STDERR"


def test_handle_integration_uses_context_get_session() -> None:
tool = GetNotebookErrors(ToolContext())

c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")])
session = MockSession(
session_view=MockSessionView(cell_operations={"c1": c1}),
app_file_manager=DummyAppFileManager(app=Mock()),
)

# Mock ToolContext.get_session
context = Mock(spec=ToolContext)
context.get_session.return_value = session
tool.context = context # type: ignore[assignment]

out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1")))
assert out.has_errors is True
assert out.total_errors == 1
assert len(out.cells) == 1
assert out.cells[0].cell_id == "c1"
Loading