Skip to content
149 changes: 147 additions & 2 deletions marimo/_ai/_tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@
)

from marimo import _loggers
from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines
from marimo._ai._tools.types import (
MarimoCellConsoleOutputs,
MarimoCellErrors,
MarimoErrorDetail,
MarimoNotebookInfo,
ToolGuidelines,
)
from marimo._ai._tools.utils.exceptions import ToolExecutionError
from marimo._ai._tools.utils.output_cleaning import clean_output
from marimo._config.config import CopilotMode
from marimo._messaging.cell_output import CellChannel
from marimo._messaging.ops import CellOp
from marimo._server.ai.tools.types import (
FunctionArgs,
ToolDefinition,
Expand All @@ -28,7 +37,7 @@
from marimo._server.api.deps import AppStateBase
from marimo._server.model import ConnectionState
from marimo._server.sessions import Session, SessionManager
from marimo._types.ids import SessionId
from marimo._types.ids import CellId_t, SessionId
from marimo._utils.case import to_snake_case
from marimo._utils.dataclass_to_openapi import PythonTypeToOpenAPI
from marimo._utils.parse_dataclass import parse_raw
Expand Down Expand Up @@ -81,6 +90,18 @@ def get_session(self, session_id: SessionId) -> Session:
)
return session_manager.sessions[session_id]

def get_cell_ops(self, session_id: SessionId, cell_id: CellId_t) -> CellOp:
session_view = self.get_session(session_id).session_view
if cell_id not in session_view.cell_operations:
raise ToolExecutionError(
f"Cell operation not found for cell {cell_id}",
code="CELL_OPERATION_NOT_FOUND",
is_retryable=False,
suggested_fix="Try again with a valid cell ID.",
meta={"cell_id": cell_id},
)
return session_view.cell_operations[cell_id]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe some basic unit tests for these new context calls

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done! added this to test_base.py


def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
"""
Get active sessions from the app state.
Expand Down Expand Up @@ -113,6 +134,130 @@ def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]:
# Return most recent notebooks first (reverse chronological order)
return files[::-1]

def get_notebook_errors(
self, session_id: SessionId, include_stderr: bool
) -> list[MarimoCellErrors]:
"""
Get all errors in the current notebook session, organized by cell.

Optionally include stderr messages foreach cell.
"""
session = self.get_session(session_id)
session_view = session.session_view
cell_errors_map: dict[CellId_t, MarimoCellErrors] = {}
notebook_errors: list[MarimoCellErrors] = []
stderr: list[str] = []

for cell_id, cell_op in session_view.cell_operations.items():
errors = self.get_cell_errors(
session_id,
cell_id,
maybe_cell_op=cell_op,
)
if include_stderr:
stderr = self.get_cell_console_outputs(cell_op).stderr
if errors:
cell_errors_map[cell_id] = MarimoCellErrors(
cell_id=cell_id,
errors=errors,
stderr=stderr,
)

# Use cell_manager to get cells in the correct notebook order
cell_manager = session.app_file_manager.app.cell_manager
for cell_data in cell_manager.cell_data():
cell_id = cell_data.cell_id
if cell_id in cell_errors_map:
notebook_errors.append(cell_errors_map[cell_id])

return notebook_errors

def get_cell_errors(
self,
session_id: SessionId,
cell_id: CellId_t,
maybe_cell_op: Optional[CellOp] = None,
) -> list[MarimoErrorDetail]:
"""
Get all errors for a given cell.
"""
errors: list[MarimoErrorDetail] = []
cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id)

if (
not cell_op.output
or cell_op.output.channel != CellChannel.MARIMO_ERROR
):
return errors

items = cell_op.output.data

if not isinstance(items, list):
# no errors
return errors

for err in items:
# TODO: filter out noisy useless errors
# like "An ancestor raised an exception..."
if isinstance(err, dict):
errors.append(
MarimoErrorDetail(
type=err.get("type", "UnknownError"),
message=err.get("msg", str(err)),
traceback=err.get("traceback", []),
)
)
else:
# Fallback for rich error objects
err_type: str = getattr(err, "type", type(err).__name__)
describe_fn: Optional[Any] = getattr(err, "describe", None)
message_val = (
describe_fn() if callable(describe_fn) else str(err)
)
message: str = str(message_val)
tb: list[str] = getattr(err, "traceback", []) or []
errors.append(
MarimoErrorDetail(
type=err_type,
message=message,
traceback=tb,
)
)

return errors

def get_cell_console_outputs(
self, cell_op: CellOp
) -> MarimoCellConsoleOutputs:
"""
Get the console outputs for a given cell operation.
"""
stdout_messages: list[str] = []
stderr_messages: list[str] = []

if cell_op.console is None:
return MarimoCellConsoleOutputs(stdout=[], stderr=[])

console_outputs = (
cell_op.console
if isinstance(cell_op.console, list)
else [cell_op.console]
)
for output in console_outputs:
if output is None:
continue
elif output.channel == CellChannel.STDOUT:
stdout_messages.append(str(output.data))
elif output.channel == CellChannel.STDERR:
stderr_messages.append(str(output.data))

cleaned_stdout_messages = clean_output(stdout_messages)
cleaned_stderr_messages = clean_output(stderr_messages)

return MarimoCellConsoleOutputs(
stdout=cleaned_stdout_messages, stderr=cleaned_stderr_messages
)


class ToolBase(Generic[ArgsT, OutT], ABC):
"""
Expand Down
Loading
Loading