Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions marimo/_ai/_tools/tools/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional

from marimo import _loggers
from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.tools.cells import ErrorDetail
from marimo._ai._tools.types import SuccessResult
from marimo._server.sessions import Session
from marimo._types.ids import SessionId

LOGGER = _loggers.marimo_logger()


@dataclass
class GetNotebookErrorsArgs:
session_id: SessionId


@dataclass
class CellErrorsSummary:
cell_id: str
cell_name: str
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i thikn we have strongly typed string for CellId_t and maybe CellName (?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually maybe I should get rid of cell name since its optional and in my opinion doesn't provide a lot of value? Cell ID is necessary for next steps so I'll fix the typing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea that sounds good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

errors: list[ErrorDetail] = field(default_factory=list)


@dataclass
class GetNotebookErrorsOutput(SuccessResult):
has_errors: bool = False
total_errors: int = 0
cells: list[CellErrorsSummary] = field(default_factory=list)


class GetNotebookErrors(
ToolBase[GetNotebookErrorsArgs, GetNotebookErrorsOutput]
):
"""
Get all errors in the current notebook session, organized by cell.

Args:
session_id: The session ID of the notebook.

Returns:
A success result containing per-cell error details and totals.
"""

def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput:
session = self.context.get_session(args.session_id)
summaries = self._collect_errors(session)

total_errors = sum(len(s.errors) for s in summaries)
has_errors = total_errors > 0

return GetNotebookErrorsOutput(
has_errors=has_errors,
total_errors=total_errors,
cells=summaries,
next_steps=(
[
"Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues",
"Re-run the notebook after addressing the errors",
]
if has_errors
else ["No errors detected"]
),
)

# helpers
def _collect_errors(self, session: Session) -> list[CellErrorsSummary]:
from marimo._messaging.cell_output import CellChannel

cell_manager = session.app_file_manager.app.cell_manager
session_view = session.session_view

summaries: list[CellErrorsSummary] = []
for cell_id, cell_op in session_view.cell_operations.items():
# Resolve cell name when possible
cell_data = cell_manager.get_cell_data(cell_id)
cell_name = cell_data.name if cell_data else str(cell_id)

errors: list[ErrorDetail] = []

# Collect structured marimo errors from output
if (
cell_op.output
and cell_op.output.channel == CellChannel.MARIMO_ERROR
):
items = cell_op.output.data
if isinstance(items, list):
for err in items:
if isinstance(err, dict):
errors.append(
ErrorDetail(
type=err.get("type", "UnknownError"),
message=err.get("msg", str(err)),
traceback=err.get("traceback", []),
)
)
else:
# Fallback for rich error objects
err_type: str = getattr(
err, "type", type(err).__name__
)
describe_fn: Optional[Any] = getattr(
err, "describe", None
)
message_val = (
describe_fn()
if callable(describe_fn)
else str(err)
)
message: str = str(message_val)
tb: list[str] = getattr(err, "traceback", []) or []
errors.append(
ErrorDetail(
type=err_type,
message=message,
traceback=tb,
)
)

# Collect stderr messages as error details
if cell_op.console:
console_outputs = (
cell_op.console
if isinstance(cell_op.console, list)
else [cell_op.console]
)
for console in console_outputs:
if console.channel == CellChannel.STDERR:
errors.append(
ErrorDetail(
type="STDERR",
message=str(console.data),
traceback=[],
)
)

if errors:
summaries.append(
CellErrorsSummary(
cell_id=str(cell_id),
cell_name=cell_name,
errors=errors,
)
)

# Sort by cell_id for stable output
summaries.sort(key=lambda s: s.cell_id)
return summaries
2 changes: 2 additions & 0 deletions marimo/_ai/_tools/tools_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
GetLightweightCellMap,
)
from marimo._ai._tools.tools.datasource import GetDatabaseTables
from marimo._ai._tools.tools.errors import GetNotebookErrors
from marimo._ai._tools.tools.notebooks import GetActiveNotebooks
from marimo._ai._tools.tools.tables_and_variables import GetTablesAndVariables

Expand All @@ -16,4 +17,5 @@
GetLightweightCellMap,
GetTablesAndVariables,
GetDatabaseTables,
GetNotebookErrors,
]
134 changes: 134 additions & 0 deletions tests/_ai/tools/tools/test_errors_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

from dataclasses import dataclass
from unittest.mock import Mock

from marimo._ai._tools.base import ToolContext
from marimo._ai._tools.tools.errors import (
GetNotebookErrors,
GetNotebookErrorsArgs,
)
from marimo._messaging.cell_output import CellChannel
from marimo._types.ids import SessionId


@dataclass
class MockCellOp:
output: object | None = None
console: object | None = None


@dataclass
class MockOutput:
channel: object
data: object


@dataclass
class MockConsoleOutput:
channel: object
data: object


@dataclass
class MockSessionView:
cell_operations: dict | None = None

def __post_init__(self) -> None:
if self.cell_operations is None:
self.cell_operations = {}


@dataclass
class DummyCellData:
name: str


class DummyCellManager:
def __init__(self, names: dict[str, str]) -> None:
self._names = names

def get_cell_data(self, cell_id: str) -> DummyCellData:
return DummyCellData(self._names.get(cell_id, cell_id))


@dataclass
class DummyAppFileManager:
app: object


@dataclass
class MockSession:
session_view: MockSessionView
app_file_manager: DummyAppFileManager


def test_collect_errors_none() -> None:
tool = GetNotebookErrors(ToolContext())

# Empty session view
session = MockSession(
session_view=MockSessionView(),
app_file_manager=DummyAppFileManager(
app=Mock(cell_manager=DummyCellManager({}))
),
)

summaries = tool._collect_errors(session) # type: ignore[arg-type]
assert summaries == []


def test_collect_errors_marimo_and_stderr() -> None:
tool = GetNotebookErrors(ToolContext())

# Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only
err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]}
c1 = MockCellOp(
output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]),
console=[MockConsoleOutput(CellChannel.STDERR, "warn")],
)
c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")])

names = {"c1": "Cell 1", "c2": "Cell 2"}
session = MockSession(
session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}),
app_file_manager=DummyAppFileManager(
app=Mock(cell_manager=DummyCellManager(names))
),
)

summaries = tool._collect_errors(session) # type: ignore[arg-type]

# Sorted by cell_id: c1 then c2
assert len(summaries) == 2
assert summaries[0].cell_id == "c1"
assert summaries[0].cell_name == "Cell 1"
assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR
assert summaries[0].errors[0].type == "ValueError"
assert summaries[1].cell_id == "c2"
assert summaries[1].cell_name == "Cell 2"
assert len(summaries[1].errors) == 1
assert summaries[1].errors[0].type == "STDERR"


def test_handle_integration_uses_context_get_session() -> None:
tool = GetNotebookErrors(ToolContext())

c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")])
session = MockSession(
session_view=MockSessionView(cell_operations={"c1": c1}),
app_file_manager=DummyAppFileManager(
app=Mock(cell_manager=DummyCellManager({"c1": "Cell 1"}))
),
)

# Mock ToolContext.get_session
context = Mock(spec=ToolContext)
context.get_session.return_value = session
tool.context = context # type: ignore[assignment]

out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1")))
assert out.has_errors is True
assert out.total_errors == 1
assert len(out.cells) == 1
assert out.cells[0].cell_id == "c1"
Loading