Skip to content
150 changes: 148 additions & 2 deletions marimo/_ai/_tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@
)

from marimo import _loggers
from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines
from marimo._ai._tools.types import (
MarimoCellErrors,
MarimoErrorDetail,
MarimoNotebookInfo,
ToolGuidelines,
)
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._ai._tools.utils.output_cleaning import clean_output
from marimo._config.config import CopilotMode
from marimo._messaging.cell_output import CellChannel
from marimo._messaging.ops import CellOp
from marimo._server.ai.tools.types import (
FunctionArgs,
ToolDefinition,
Expand All @@ -28,7 +36,7 @@
from marimo._server.api.deps import AppStateBase
from marimo._server.model import ConnectionState
from marimo._server.sessions import Session, SessionManager
from marimo._types.ids import SessionId
from marimo._types.ids import CellId_t, SessionId
from marimo._utils.case import to_snake_case
from marimo._utils.dataclass_to_openapi import PythonTypeToOpenAPI
from marimo._utils.parse_dataclass import parse_raw
Expand Down Expand Up @@ -81,6 +89,18 @@ def get_session(self, session_id: SessionId) -> Session:
)
return session_manager.sessions[session_id]

def get_cell_ops(self, session_id: SessionId, cell_id: CellId_t) -> CellOp:
session_view = self.get_session(session_id).session_view
if cell_id not in session_view.cell_operations:
raise ToolExecutionError(
f"Cell operation not found for cell {cell_id}",
code="CELL_OPERATION_NOT_FOUND",
is_retryable=False,
suggested_fix="Try again with a valid cell ID.",
meta={"cell_id": cell_id},
)
return session_view.cell_operations[cell_id]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe some basic unit tests for these new context calls

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done! added this to test_base.py


def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
"""
Get active sessions from the app state.
Expand Down Expand Up @@ -113,6 +133,132 @@ def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
# Return most recent notebooks first (reverse chronological order)
return files[::-1]

def get_notebook_errors(
self, session_id: SessionId, include_stderr: bool = False
) -> list[MarimoCellErrors]:
"""
Get all errors in the current notebook session, organized by cell.

Args:
session_id: The session ID of the notebook.
include_stderr: Whether to include stderr errors.

Returns:
A list of MarimoCellErrors in the order of the cells in the notebook.
"""
session = self.get_session(session_id)
session_view = session.session_view
cell_errors_map: dict[CellId_t, list[MarimoErrorDetail]] = {}
notebook_errors: list[MarimoCellErrors] = []

for cell_id, cell_op in session_view.cell_operations.items():
errors = self.get_cell_errors(
session_id,
cell_id,
maybe_cell_op=cell_op,
include_stderr=include_stderr,
)
if len(errors) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if errors:

cell_errors_map[cell_id] = errors

# Use cell_manager to get cells in the correct notebook order
cell_manager = session.app_file_manager.app.cell_manager
for cell_data in cell_manager.cell_data():
cell_id = cell_data.cell_id
if cell_id in cell_errors_map:
notebook_errors.append(
MarimoCellErrors(
cell_id=cell_id,
errors=cell_errors_map[cell_id],
)
)

return notebook_errors

def get_cell_errors(
self,
session_id: SessionId,
cell_id: CellId_t,
maybe_cell_op: Optional[CellOp] = None,
include_stderr: bool = False,
) -> list[MarimoErrorDetail]:
"""
Get all errors for a given cell.

Args:
session_id: The session ID of the notebook.
cell_id: The ID of the cell.
maybe_cell_op: The cell operation.
include_stderr: Whether to include stderr errors.

Returns:
A list of MarimoErrorDetails for the cell with STDERR errors if include_stderr is True.
"""
errors: list[MarimoErrorDetail] = []
cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id)

if (
cell_op.output
and cell_op.output.channel == CellChannel.MARIMO_ERROR
):
items = cell_op.output.data

if not isinstance(items, list):
# no errors
return errors

for err in items:
# TODO: filter out noisy useless errors
# like "An ancestor raised an exception..."
if isinstance(err, dict):
errors.append(
MarimoErrorDetail(
type=err.get("type", "UnknownError"),
message=err.get("msg", str(err)),
traceback=err.get("traceback", []),
)
)
else:
# Fallback for rich error objects
err_type: str = getattr(err, "type", type(err).__name__)
describe_fn: Optional[Any] = getattr(err, "describe", None)
message_val = (
describe_fn() if callable(describe_fn) else str(err)
)
message: str = str(message_val)
tb: list[str] = getattr(err, "traceback", []) or []
errors.append(
MarimoErrorDetail(
type=err_type,
message=message,
traceback=tb,
)
)

if cell_op.console and include_stderr:
console_outputs = (
cell_op.console
if isinstance(cell_op.console, list)
else [cell_op.console]
)
stderr_messages: list[str] = []
for console in console_outputs:
if console.channel == CellChannel.STDERR:
stderr_messages.append(str(console.data))
cleaned_stderr_messages = clean_output(stderr_messages)
errors.extend(
[
MarimoErrorDetail(
type="STDERR",
message=message,
traceback=[],
)
for message in cleaned_stderr_messages
]
)

return errors


class ToolBase(Generic[ArgsT, OutT], ABC):
"""
Expand Down
91 changes: 9 additions & 82 deletions marimo/_ai/_tools/tools/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from typing import TYPE_CHECKING, Any, Optional

from marimo._ai._tools.base import ToolBase
from marimo._ai._tools.types import SuccessResult, ToolGuidelines
from marimo._ai._tools.types import (
MarimoErrorDetail,
SuccessResult,
ToolGuidelines,
)
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._ai._tools.utils.output_cleaning import clean_output
from marimo._ast.models import CellData
Expand Down Expand Up @@ -56,12 +60,6 @@ class ErrorDetail:
traceback: list[str]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this different from MarimoErrorDetail?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope I forgot to remove it. I'll fix that thanks


@dataclass
class CellErrors:
has_errors: bool
error_details: Optional[list[ErrorDetail]]


@dataclass
class CellRuntimeMetadata:
# String form of the runtime state (see marimo._ast.cell.RuntimeStateType);
Expand All @@ -79,7 +77,7 @@ class GetCellRuntimeDataData:
session_id: str
cell_id: str
code: Optional[str] = None
errors: Optional[CellErrors] = None
errors: Optional[list[MarimoErrorDetail]] = None
metadata: Optional[CellRuntimeMetadata] = None
variables: Optional[CellVariables] = None

Expand Down Expand Up @@ -267,7 +265,9 @@ def handle(self, args: GetCellRuntimeDataArgs) -> GetCellRuntimeDataOutput:
cell_code = cell_data.code

# Get cell errors from session view with actual error details
cell_errors = self._get_cell_errors(session, cell_id)
cell_errors = context.get_cell_errors(
session_id, cell_id, include_stderr=True
)

# Get cell runtime metadata
cell_metadata = self._get_cell_metadata(session, cell_id)
Expand Down Expand Up @@ -307,79 +307,6 @@ def _get_cell_data(
)
return cell_data

def _get_cell_errors(
self, session: Session, cell_id: CellId_t
) -> CellErrors:
"""Get cell errors from session view with actual error details."""
from marimo._messaging.cell_output import CellChannel

# Get cell operation from session view
session_view = session.session_view
cell_op = session_view.cell_operations.get(cell_id)

if cell_op is None:
# No operations recorded for this cell
return CellErrors(has_errors=False, error_details=None)

# Check for actual error details in the output
has_errors = False
error_details = []
if (
cell_op.output
and cell_op.output.channel == CellChannel.MARIMO_ERROR
):
has_errors = True
# Extract actual error objects
errors = cell_op.output.data
if isinstance(errors, list):
for error in errors:
if hasattr(error, "type") and hasattr(error, "describe"):
# Rich Error object
error_detail = ErrorDetail(
type=error.type,
message=error.describe(),
traceback=getattr(error, "traceback", []),
)
error_details.append(error_detail)
elif isinstance(error, dict):
# Dict-based error
dict_error_detail = ErrorDetail(
type=error.get("type", "UnknownError"),
message=error.get("msg", str(error)),
traceback=error.get("traceback", []),
)
error_details.append(dict_error_detail)
else:
# Fallback for other error types
fallback_error_detail = ErrorDetail(
type=type(error).__name__,
message=str(error),
traceback=[],
)
error_details.append(fallback_error_detail)

# Check console outputs for STDERR (includes print statements to stderr, warnings, etc.)
if cell_op.console:
console_outputs = (
cell_op.console
if isinstance(cell_op.console, list)
else [cell_op.console]
)
for console_output in console_outputs:
if console_output.channel == CellChannel.STDERR:
has_errors = True
stderr_error_detail = ErrorDetail(
type="STDERR",
message=str(console_output.data),
traceback=[],
)
error_details.append(stderr_error_detail)

return CellErrors(
has_errors=has_errors,
error_details=error_details if error_details else None,
)

def _get_cell_metadata(
self, session: Session, cell_id: CellId_t
) -> CellRuntimeMetadata:
Expand Down
Loading
Loading