diff --git a/marimo/_ai/_tools/base.py b/marimo/_ai/_tools/base.py index d1b2cd70f97..8491412cac8 100644 --- a/marimo/_ai/_tools/base.py +++ b/marimo/_ai/_tools/base.py @@ -17,9 +17,18 @@ ) from marimo import _loggers -from marimo._ai._tools.types import MarimoNotebookInfo, ToolGuidelines +from marimo._ai._tools.types import ( + MarimoCellConsoleOutputs, + MarimoCellErrors, + MarimoErrorDetail, + MarimoNotebookInfo, + ToolGuidelines, +) from marimo._ai._tools.utils.exceptions import ToolExecutionError +from marimo._ai._tools.utils.output_cleaning import clean_output from marimo._config.config import CopilotMode +from marimo._messaging.cell_output import CellChannel +from marimo._messaging.ops import CellOp from marimo._server.ai.tools.types import ( FunctionArgs, ToolDefinition, @@ -28,7 +37,7 @@ from marimo._server.api.deps import AppStateBase from marimo._server.model import ConnectionState from marimo._server.sessions import Session, SessionManager -from marimo._types.ids import SessionId +from marimo._types.ids import CellId_t, SessionId from marimo._utils.case import to_snake_case from marimo._utils.dataclass_to_openapi import PythonTypeToOpenAPI from marimo._utils.parse_dataclass import parse_raw @@ -81,6 +90,18 @@ def get_session(self, session_id: SessionId) -> Session: ) return session_manager.sessions[session_id] + def get_cell_ops(self, session_id: SessionId, cell_id: CellId_t) -> CellOp: + session_view = self.get_session(session_id).session_view + if cell_id not in session_view.cell_operations: + raise ToolExecutionError( + f"Cell operation not found for cell {cell_id}", + code="CELL_OPERATION_NOT_FOUND", + is_retryable=False, + suggested_fix="Try again with a valid cell ID.", + meta={"cell_id": cell_id}, + ) + return session_view.cell_operations[cell_id] + def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: """ Get active sessions from the app state. 
@@ -113,6 +134,130 @@ def get_active_sessions_internal(self) -> list[MarimoNotebookInfo]: # Return most recent notebooks first (reverse chronological order) return files[::-1] + def get_notebook_errors( + self, session_id: SessionId, include_stderr: bool + ) -> list[MarimoCellErrors]: + """ + Get all errors in the current notebook session, organized by cell. + + Optionally include stderr messages for each cell. + """ + session = self.get_session(session_id) + session_view = session.session_view + cell_errors_map: dict[CellId_t, MarimoCellErrors] = {} + notebook_errors: list[MarimoCellErrors] = [] + stderr: list[str] = [] + + for cell_id, cell_op in session_view.cell_operations.items(): + errors = self.get_cell_errors( + session_id, + cell_id, + maybe_cell_op=cell_op, + ) + if include_stderr: + stderr = self.get_cell_console_outputs(cell_op).stderr + if errors: + cell_errors_map[cell_id] = MarimoCellErrors( + cell_id=cell_id, + errors=errors, + stderr=stderr, + ) + + # Use cell_manager to get cells in the correct notebook order + cell_manager = session.app_file_manager.app.cell_manager + for cell_data in cell_manager.cell_data(): + cell_id = cell_data.cell_id + if cell_id in cell_errors_map: + notebook_errors.append(cell_errors_map[cell_id]) + + return notebook_errors + + def get_cell_errors( + self, + session_id: SessionId, + cell_id: CellId_t, + maybe_cell_op: Optional[CellOp] = None, + ) -> list[MarimoErrorDetail]: + """ + Get all errors for a given cell. + """ + errors: list[MarimoErrorDetail] = [] + cell_op = maybe_cell_op or self.get_cell_ops(session_id, cell_id) + + if ( + not cell_op.output + or cell_op.output.channel != CellChannel.MARIMO_ERROR + ): + return errors + + items = cell_op.output.data + + if not isinstance(items, list): + # no errors + return errors + + for err in items: + # TODO: filter out noisy useless errors + # like "An ancestor raised an exception..." 
+ if isinstance(err, dict): + errors.append( + MarimoErrorDetail( + type=err.get("type", "UnknownError"), + message=err.get("msg", str(err)), + traceback=err.get("traceback", []), + ) + ) + else: + # Fallback for rich error objects + err_type: str = getattr(err, "type", type(err).__name__) + describe_fn: Optional[Any] = getattr(err, "describe", None) + message_val = ( + describe_fn() if callable(describe_fn) else str(err) + ) + message: str = str(message_val) + tb: list[str] = getattr(err, "traceback", []) or [] + errors.append( + MarimoErrorDetail( + type=err_type, + message=message, + traceback=tb, + ) + ) + + return errors + + def get_cell_console_outputs( + self, cell_op: CellOp + ) -> MarimoCellConsoleOutputs: + """ + Get the console outputs for a given cell operation. + """ + stdout_messages: list[str] = [] + stderr_messages: list[str] = [] + + if cell_op.console is None: + return MarimoCellConsoleOutputs(stdout=[], stderr=[]) + + console_outputs = ( + cell_op.console + if isinstance(cell_op.console, list) + else [cell_op.console] + ) + for output in console_outputs: + if output is None: + continue + elif output.channel == CellChannel.STDOUT: + stdout_messages.append(str(output.data)) + elif output.channel == CellChannel.STDERR: + stderr_messages.append(str(output.data)) + + cleaned_stdout_messages = clean_output(stdout_messages) + cleaned_stderr_messages = clean_output(stderr_messages) + + return MarimoCellConsoleOutputs( + stdout=cleaned_stdout_messages, stderr=cleaned_stderr_messages + ) + class ToolBase(Generic[ArgsT, OutT], ABC): """ diff --git a/marimo/_ai/_tools/tools/cells.py b/marimo/_ai/_tools/tools/cells.py index f812cd20e60..d12097eb87e 100644 --- a/marimo/_ai/_tools/tools/cells.py +++ b/marimo/_ai/_tools/tools/cells.py @@ -6,11 +6,14 @@ from typing import TYPE_CHECKING, Any, Optional from marimo._ai._tools.base import ToolBase -from marimo._ai._tools.types import SuccessResult, ToolGuidelines +from marimo._ai._tools.types import ( + 
MarimoCellConsoleOutputs, + MarimoErrorDetail, + SuccessResult, + ToolGuidelines, +) from marimo._ai._tools.utils.exceptions import ToolExecutionError -from marimo._ai._tools.utils.output_cleaning import clean_output from marimo._ast.models import CellData -from marimo._messaging.cell_output import CellChannel from marimo._messaging.errors import Error from marimo._messaging.ops import CellOp, VariableValue from marimo._types.ids import CellId_t, SessionId @@ -49,19 +52,6 @@ class GetLightweightCellMapOutput(SuccessResult): preview_lines: int = 3 -@dataclass -class ErrorDetail: - type: str - message: str - traceback: list[str] - - -@dataclass -class CellErrors: - has_errors: bool - error_details: Optional[list[ErrorDetail]] - - @dataclass class CellRuntimeMetadata: # String form of the runtime state (see marimo._ast.cell.RuntimeStateType); @@ -79,7 +69,7 @@ class GetCellRuntimeDataData: session_id: str cell_id: str code: Optional[str] = None - errors: Optional[CellErrors] = None + errors: Optional[list[MarimoErrorDetail]] = None metadata: Optional[CellRuntimeMetadata] = None variables: Optional[CellVariables] = None @@ -102,13 +92,11 @@ class GetCellRuntimeDataOutput(SuccessResult): @dataclass -class CellOutputData: - """Visual and console output from a cell execution.""" +class CellVisualOutput: + """Visual from a cell execution.""" visual_output: Optional[str] = None visual_mimetype: Optional[str] = None - stdout: list[str] = field(default_factory=list) - stderr: list[str] = field(default_factory=list) @dataclass @@ -119,7 +107,10 @@ class GetCellOutputArgs: @dataclass class GetCellOutputOutput(SuccessResult): - data: CellOutputData = field(default_factory=CellOutputData) + visual_output: CellVisualOutput = field(default_factory=CellVisualOutput) + console_outputs: MarimoCellConsoleOutputs = field( + default_factory=MarimoCellConsoleOutputs + ) class GetLightweightCellMap( @@ -267,7 +258,7 @@ def handle(self, args: GetCellRuntimeDataArgs) -> 
GetCellRuntimeDataOutput: cell_code = cell_data.code # Get cell errors from session view with actual error details - cell_errors = self._get_cell_errors(session, cell_id) + cell_errors = context.get_cell_errors(session_id, cell_id) # Get cell runtime metadata cell_metadata = self._get_cell_metadata(session, cell_id) @@ -307,79 +298,6 @@ def _get_cell_data( ) return cell_data - def _get_cell_errors( - self, session: Session, cell_id: CellId_t - ) -> CellErrors: - """Get cell errors from session view with actual error details.""" - from marimo._messaging.cell_output import CellChannel - - # Get cell operation from session view - session_view = session.session_view - cell_op = session_view.cell_operations.get(cell_id) - - if cell_op is None: - # No operations recorded for this cell - return CellErrors(has_errors=False, error_details=None) - - # Check for actual error details in the output - has_errors = False - error_details = [] - if ( - cell_op.output - and cell_op.output.channel == CellChannel.MARIMO_ERROR - ): - has_errors = True - # Extract actual error objects - errors = cell_op.output.data - if isinstance(errors, list): - for error in errors: - if hasattr(error, "type") and hasattr(error, "describe"): - # Rich Error object - error_detail = ErrorDetail( - type=error.type, - message=error.describe(), - traceback=getattr(error, "traceback", []), - ) - error_details.append(error_detail) - elif isinstance(error, dict): - # Dict-based error - dict_error_detail = ErrorDetail( - type=error.get("type", "UnknownError"), - message=error.get("msg", str(error)), - traceback=error.get("traceback", []), - ) - error_details.append(dict_error_detail) - else: - # Fallback for other error types - fallback_error_detail = ErrorDetail( - type=type(error).__name__, - message=str(error), - traceback=[], - ) - error_details.append(fallback_error_detail) - - # Check console outputs for STDERR (includes print statements to stderr, warnings, etc.) 
- if cell_op.console: - console_outputs = ( - cell_op.console - if isinstance(cell_op.console, list) - else [cell_op.console] - ) - for console_output in console_outputs: - if console_output.channel == CellChannel.STDERR: - has_errors = True - stderr_error_detail = ErrorDetail( - type="STDERR", - message=str(console_output.data), - traceback=[], - ) - error_details.append(stderr_error_detail) - - return CellErrors( - has_errors=has_errors, - error_details=error_details if error_details else None, - ) - def _get_cell_metadata( self, session: Session, cell_id: CellId_t ) -> CellRuntimeMetadata: @@ -426,16 +344,12 @@ def _get_cell_variables( class GetCellOutputs(ToolBase[GetCellOutputArgs, GetCellOutputOutput]): """Get cell execution output including visual display and console streams. - Returns comprehensive output data for a single cell: - - Visual output (HTML, charts, tables, etc.) with mimetype - - Console stdout and stderr messages - Args: session_id: The session ID of the notebook from get_active_notebooks cell_id: The specific cell ID from get_lightweight_cell_map Returns: - A success result containing all output data from the cell execution. + Visual output (HTML, charts, tables, etc.) with mimetype and console streams (stdout/stderr). 
""" guidelines = ToolGuidelines( @@ -450,23 +364,29 @@ class GetCellOutputs(ToolBase[GetCellOutputArgs, GetCellOutputOutput]): ) def handle(self, args: GetCellOutputArgs) -> GetCellOutputOutput: - session = self.context.get_session(args.session_id) + context = self.context + session = context.get_session(args.session_id) session_view = session.session_view cell_id = args.cell_id - maybe_cell_op = session_view.cell_operations.get(cell_id) + cell_op = session_view.cell_operations.get(cell_id) - visual_output, visual_mimetype = self._get_visual_output(maybe_cell_op) - stdout_messages, stderr_messages = self._get_console_outputs( - maybe_cell_op - ) + if cell_op is None: + raise ToolExecutionError( + f"Cell {cell_id} not found in session {args.session_id}", + code="CELL_NOT_FOUND", + is_retryable=False, + suggested_fix="Use get_lightweight_cell_map to find valid cell IDs", + ) + + visual_output, visual_mimetype = self._get_visual_output(cell_op) + console_outputs = context.get_cell_console_outputs(cell_op) return GetCellOutputOutput( - data=CellOutputData( + visual_output=CellVisualOutput( visual_output=visual_output, visual_mimetype=visual_mimetype, - stdout=stdout_messages, - stderr=stderr_messages, ), + console_outputs=console_outputs, next_steps=[ "Review visual_output to see what was displayed to the user", "Check stdout/stderr for print statements and warnings", @@ -474,14 +394,14 @@ def handle(self, args: GetCellOutputArgs) -> GetCellOutputOutput: ) def _get_visual_output( - self, maybe_cell_op: Optional[CellOp] + self, cell_op: CellOp ) -> tuple[Optional[str], Optional[str]]: visual_output = None visual_mimetype = None - if maybe_cell_op and maybe_cell_op.output: - data = maybe_cell_op.output.data + if cell_op.output: + data = cell_op.output.data visual_output = self._get_str_output_data(data) - visual_mimetype = maybe_cell_op.output.mimetype + visual_mimetype = cell_op.output.mimetype return visual_output, visual_mimetype def _get_str_output_data( @@ -491,29 
+411,3 @@ def _get_str_output_data( return data else: return str(data) - - def _get_console_outputs( - self, maybe_cell_op: Optional[CellOp] - ) -> tuple[list[str], list[str]]: - stdout_messages: list[str] = [] - stderr_messages: list[str] = [] - if maybe_cell_op is None or maybe_cell_op.console is None: - return stdout_messages, stderr_messages - - console_outputs = ( - maybe_cell_op.console - if isinstance(maybe_cell_op.console, list) - else [maybe_cell_op.console] - ) - for output in console_outputs: - if output is None: - continue - elif output.channel == CellChannel.STDOUT: - stdout_messages.append(str(output.data)) - elif output.channel == CellChannel.STDERR: - stderr_messages.append(str(output.data)) - - cleaned_stdout_messages = clean_output(stdout_messages) - cleaned_stderr_messages = clean_output(stderr_messages) - - return cleaned_stdout_messages, cleaned_stderr_messages diff --git a/marimo/_ai/_tools/tools/errors.py b/marimo/_ai/_tools/tools/errors.py index 26f31d112f4..f7183366e12 100644 --- a/marimo/_ai/_tools/tools/errors.py +++ b/marimo/_ai/_tools/tools/errors.py @@ -2,16 +2,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Optional -from marimo import _loggers from marimo._ai._tools.base import ToolBase -from marimo._ai._tools.tools.cells import ErrorDetail -from marimo._ai._tools.types import SuccessResult, ToolGuidelines -from marimo._server.sessions import Session -from marimo._types.ids import CellId_t, SessionId - -LOGGER = _loggers.marimo_logger() +from marimo._ai._tools.types import ( + MarimoCellErrors, + SuccessResult, + ToolGuidelines, +) +from marimo._types.ids import SessionId @dataclass @@ -19,17 +17,12 @@ class GetNotebookErrorsArgs: session_id: SessionId -@dataclass -class CellErrorsSummary: - cell_id: CellId_t - errors: list[ErrorDetail] = field(default_factory=list) - - @dataclass class GetNotebookErrorsOutput(SuccessResult): has_errors: bool = False total_errors: int = 0 - 
cells: list[CellErrorsSummary] = field(default_factory=list) + total_cells_with_errors: int = 0 + cells: list[MarimoCellErrors] = field(default_factory=list) class GetNotebookErrors( @@ -42,7 +35,7 @@ class GetNotebookErrors( session_id: The session ID of the notebook. Returns: - A success result containing per-cell error details and totals. + A success result containing notebook errors organized by cell. """ guidelines = ToolGuidelines( @@ -56,16 +49,21 @@ class GetNotebookErrors( ) def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: - session = self.context.get_session(args.session_id) - summaries = self._collect_errors(session) + context = self.context + session_id = args.session_id + notebook_errors = context.get_notebook_errors( + session_id, include_stderr=True + ) - total_errors = sum(len(s.errors) for s in summaries) + total_errors = sum(len(c.errors) for c in notebook_errors) + total_cells_with_errors = len(notebook_errors) has_errors = total_errors > 0 return GetNotebookErrorsOutput( has_errors=has_errors, total_errors=total_errors, - cells=summaries, + total_cells_with_errors=total_cells_with_errors, + cells=notebook_errors, next_steps=( [ "Use get_cell_runtime_data to inspect the impacted cells to fix syntax/runtime issues", @@ -75,81 +73,3 @@ def handle(self, args: GetNotebookErrorsArgs) -> GetNotebookErrorsOutput: else ["No errors detected"] ), ) - - # helpers - def _collect_errors(self, session: Session) -> list[CellErrorsSummary]: - from marimo._messaging.cell_output import CellChannel - - session_view = session.session_view - - summaries: list[CellErrorsSummary] = [] - for cell_id, cell_op in session_view.cell_operations.items(): - errors: list[ErrorDetail] = [] - - # Collect structured marimo errors from output - if ( - cell_op.output - and cell_op.output.channel == CellChannel.MARIMO_ERROR - ): - items = cell_op.output.data - if isinstance(items, list): - for err in items: - if isinstance(err, dict): - errors.append( - 
ErrorDetail( - type=err.get("type", "UnknownError"), - message=err.get("msg", str(err)), - traceback=err.get("traceback", []), - ) - ) - else: - # Fallback for rich error objects - err_type: str = getattr( - err, "type", type(err).__name__ - ) - describe_fn: Optional[Any] = getattr( - err, "describe", None - ) - message_val = ( - describe_fn() - if callable(describe_fn) - else str(err) - ) - message: str = str(message_val) - tb: list[str] = getattr(err, "traceback", []) or [] - errors.append( - ErrorDetail( - type=err_type, - message=message, - traceback=tb, - ) - ) - - # Collect stderr messages as error details - if cell_op.console: - console_outputs = ( - cell_op.console - if isinstance(cell_op.console, list) - else [cell_op.console] - ) - for console in console_outputs: - if console.channel == CellChannel.STDERR: - errors.append( - ErrorDetail( - type="STDERR", - message=str(console.data), - traceback=[], - ) - ) - - if errors: - summaries.append( - CellErrorsSummary( - cell_id=cell_id, - errors=errors, - ) - ) - - # Sort by cell_id for stable output - summaries.sort(key=lambda s: s.cell_id) - return summaries diff --git a/marimo/_ai/_tools/types.py b/marimo/_ai/_tools/types.py index 99e5c3b1bf1..9b3e68b0c97 100644 --- a/marimo/_ai/_tools/types.py +++ b/marimo/_ai/_tools/types.py @@ -1,10 +1,10 @@ # Copyright 2025 Marimo. All rights reserved. 
from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Literal, Optional -from marimo._types.ids import SessionId +from marimo._types.ids import CellId_t, SessionId # helper classes StatusValue = Literal["success", "error", "warning"] @@ -41,3 +41,23 @@ class MarimoNotebookInfo: name: str path: str session_id: SessionId + + +@dataclass +class MarimoCellErrors: + cell_id: CellId_t + errors: list[MarimoErrorDetail] = field(default_factory=list) + stderr: list[str] = field(default_factory=list) + + +@dataclass +class MarimoErrorDetail: + type: str + message: str + traceback: list[str] + + +@dataclass +class MarimoCellConsoleOutputs: + stdout: list[str] = field(default_factory=list) + stderr: list[str] = field(default_factory=list) diff --git a/marimo/_mcp/server/_prompts/prompts/errors.py b/marimo/_mcp/server/_prompts/prompts/errors.py new file mode 100644 index 00000000000..d59e452bcef --- /dev/null +++ b/marimo/_mcp/server/_prompts/prompts/errors.py @@ -0,0 +1,86 @@ +# Copyright 2024 Marimo. All rights reserved. +"""MCP Prompts for notebook error information.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from marimo._mcp.server._prompts.base import PromptBase + +if TYPE_CHECKING: + from mcp.types import PromptMessage + + +class ErrorsSummary(PromptBase): + """Get error summaries for all active notebooks.""" + + def handle(self) -> list[PromptMessage]: + """Generate prompt messages summarizing errors in active notebooks. + + Returns: + List of PromptMessage objects, one per notebook with errors. 
+ """ + from mcp.types import PromptMessage, TextContent + + context = self.context + notebooks = context.get_active_sessions_internal() + + if len(notebooks) == 0: + return [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text="No active marimo notebook sessions found.", + ), + ) + ] + + error_messages: list[PromptMessage] = [] + + for notebook in notebooks: + session_id = notebook.session_id + notebook_errors = context.get_notebook_errors( + session_id, include_stderr=False + ) + + if len(notebook_errors) == 0: + continue + + error_lines = [ + f"Notebook: {notebook.name} (session: {notebook.session_id})", + f"Path: {notebook.path}", + f"Cells with errors: {len(notebook_errors)}", + f"Total errors: {sum(len(cell_errors.errors) for cell_errors in notebook_errors)}", + ] + + for cell_errors in notebook_errors: + error_lines.append(f"**Cell {cell_errors.cell_id}**:") + all_cell_errors = [ + f" • {error.type} - {error.message}" + for error in cell_errors.errors + ] + error_lines.extend(all_cell_errors) + + error_messages.append( + PromptMessage( + role="user", + content=TextContent( + type="text", + text="\n".join(error_lines), + ), + ) + ) + + if len(error_messages) == 0: + return [ + PromptMessage( + role="user", + content=TextContent( + type="text", + text="No errors found in any active notebooks.", + ), + ) + ] + + return error_messages diff --git a/marimo/_mcp/server/_prompts/registry.py b/marimo/_mcp/server/_prompts/registry.py index d60e8814701..403234fa7fb 100644 --- a/marimo/_mcp/server/_prompts/registry.py +++ b/marimo/_mcp/server/_prompts/registry.py @@ -2,8 +2,10 @@ """Registry of all supported MCP prompts.""" from marimo._mcp.server._prompts.base import PromptBase +from marimo._mcp.server._prompts.prompts.errors import ErrorsSummary from marimo._mcp.server._prompts.prompts.notebooks import ActiveNotebooks SUPPORTED_MCP_PROMPTS: list[type[PromptBase]] = [ ActiveNotebooks, + ErrorsSummary, ] diff --git a/tests/_ai/tools/test_base.py 
b/tests/_ai/tools/test_base.py index 60edc7b7c7b..3afd7485343 100644 --- a/tests/_ai/tools/test_base.py +++ b/tests/_ai/tools/test_base.py @@ -125,3 +125,107 @@ def test_as_backend_tool() -> None: is_valid, msg = validator({"invalid": "field"}) assert is_valid is False assert "Invalid arguments" in msg + + +# test ToolContext methods + + +def test_get_notebook_errors_orders_by_cell_manager(): + """Test errors follow cell_manager order, not alphabetical.""" + from unittest.mock import Mock + + from marimo._messaging.cell_output import CellChannel + from marimo._types.ids import CellId_t, SessionId + + context = ToolContext() + + # Mock error cell_op + error_op = Mock() + error_op.output = Mock() + error_op.output.channel = CellChannel.MARIMO_ERROR + error_op.output.data = [{"type": "Error", "msg": "test", "traceback": []}] + error_op.console = None + + # Mock session with cells c1, c2, c3 + session = Mock() + session.session_view.cell_operations = { + CellId_t("c1"): error_op, + CellId_t("c2"): error_op, + CellId_t("c3"): error_op, + } + + # Cell manager returns in order: c3, c2, c1 (not alphabetical) + cell_data = [ + Mock(cell_id=CellId_t("c3")), + Mock(cell_id=CellId_t("c2")), + Mock(cell_id=CellId_t("c1")), + ] + session.app_file_manager.app.cell_manager.cell_data.return_value = ( + cell_data + ) + + context.get_session = Mock(return_value=session) + + errors = context.get_notebook_errors( + SessionId("test"), include_stderr=False + ) + + # Should be c3, c2, c1 (not c1, c2, c3) + assert errors[0].cell_id == CellId_t("c3") + assert errors[1].cell_id == CellId_t("c2") + assert errors[2].cell_id == CellId_t("c1") + + +def test_get_cell_errors_extracts_from_output(): + """Test get_cell_errors extracts error details from cell output.""" + from unittest.mock import Mock + + from marimo._messaging.cell_output import CellChannel + from marimo._types.ids import CellId_t, SessionId + + context = ToolContext() + + # Mock cell_op with error + cell_op = Mock() + 
cell_op.output = Mock() + cell_op.output.channel = CellChannel.MARIMO_ERROR + cell_op.output.data = [ + {"type": "ValueError", "msg": "bad value", "traceback": ["line 1"]} + ] + + errors = context.get_cell_errors( + SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op + ) + + assert len(errors) == 1 + assert errors[0].type == "ValueError" + assert errors[0].message == "bad value" + assert errors[0].traceback == ["line 1"] + + +def test_get_cell_console_outputs_separates_stdout_stderr(): + """Test get_cell_console_outputs separates stdout and stderr.""" + from unittest.mock import Mock + + from marimo._messaging.cell_output import CellChannel + + context = ToolContext() + + # Mock cell_op with stdout and stderr + stdout_output = Mock() + stdout_output.channel = CellChannel.STDOUT + stdout_output.data = "hello" + + stderr_output = Mock() + stderr_output.channel = CellChannel.STDERR + stderr_output.data = "warning" + + cell_op = Mock() + cell_op.console = [stdout_output, stderr_output] + + result = context.get_cell_console_outputs(cell_op) + + assert len(result.stdout) == 1 + assert "hello" in result.stdout[0] + assert len(result.stderr) == 1 + assert "warning" in result.stderr[0] diff --git a/tests/_ai/tools/tools/test_cells.py b/tests/_ai/tools/tools/test_cells.py index 638c7f4fec3..02137bc3a6c 100644 --- a/tests/_ai/tools/tools/test_cells.py +++ b/tests/_ai/tools/tools/test_cells.py @@ -7,14 +7,12 @@ from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools.cells import ( - CellErrors, CellRuntimeMetadata, CellVariables, GetCellOutputs, GetCellRuntimeData, GetLightweightCellMap, ) -from marimo._messaging.cell_output import CellChannel from marimo._messaging.ops import VariableValue from marimo._server.sessions import Session from marimo._types.ids import CellId_t, SessionId @@ -76,54 +74,6 @@ def test_is_markdown_cell(): assert tool._is_markdown_cell("print('x')") is False -def test_get_cell_errors_no_cell_op(): - tool = 
GetCellRuntimeData(ToolContext()) - session = MockSession(MockSessionView()) - - result = tool._get_cell_errors(session, CellId_t("missing")) - assert result == CellErrors(has_errors=False, error_details=None) - - -def test_get_cell_errors_with_marimo_error(): - tool = GetCellRuntimeData(ToolContext()) - error = MockError( - "NameError", "name 'x' is not defined", ["line1", "line2"] - ) - output = MockOutput(CellChannel.MARIMO_ERROR, [error]) - cell_op = MockCellOp(output=output) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "NameError" - - -def test_get_cell_errors_with_stderr(): - tool = GetCellRuntimeData(ToolContext()) - console_output = MockConsoleOutput(CellChannel.STDERR, "warn") - cell_op = MockCellOp(console=[console_output]) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "STDERR" - - -def test_get_cell_errors_dict_error(): - tool = GetCellRuntimeData(ToolContext()) - dict_error = {"type": "ValueError", "msg": "invalid", "traceback": ["tb1"]} - output = MockOutput(CellChannel.MARIMO_ERROR, [dict_error]) - cell_op = MockCellOp(output=output) - session = MockSession(MockSessionView(cell_operations={"c1": cell_op})) - - result = tool._get_cell_errors(session, CellId_t("c1")) - assert result.has_errors is True - assert result.error_details is not None - assert result.error_details[0].type == "ValueError" - - def test_get_cell_metadata_basic(): tool = GetCellRuntimeData(ToolContext()) cell_op = MockCellOp(status="idle") @@ -239,25 +189,3 @@ def test_get_visual_output_no_output(): visual_output, mimetype = tool._get_visual_output(cell_op) # type: 
ignore[arg-type] assert visual_output is None assert mimetype is None - - -def test_get_console_outputs_with_stdout_stderr(): - tool = GetCellOutputs(ToolContext()) - console = [ - MockConsoleOutput(CellChannel.STDOUT, "hello"), - MockConsoleOutput(CellChannel.STDERR, "warning"), - ] - cell_op = MockCellOp(console=console) - - stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type] - assert stdout == ["hello"] - assert stderr == ["warning"] - - -def test_get_console_outputs_no_console(): - tool = GetCellOutputs(ToolContext()) - cell_op = MockCellOp(console=None) - - stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type] - assert stdout == [] - assert stderr == [] diff --git a/tests/_ai/tools/tools/test_errors_tool.py b/tests/_ai/tools/tools/test_errors_tool.py index 2a503b2430f..2bb8f6578fe 100644 --- a/tests/_ai/tools/tools/test_errors_tool.py +++ b/tests/_ai/tools/tools/test_errors_tool.py @@ -3,13 +3,15 @@ from dataclasses import dataclass from unittest.mock import Mock +import pytest + from marimo._ai._tools.base import ToolContext from marimo._ai._tools.tools.errors import ( GetNotebookErrors, GetNotebookErrorsArgs, ) -from marimo._messaging.cell_output import CellChannel -from marimo._types.ids import SessionId +from marimo._ai._tools.types import MarimoCellErrors, MarimoErrorDetail +from marimo._types.ids import CellId_t, SessionId @dataclass @@ -30,9 +32,15 @@ class MockConsoleOutput: data: object +@dataclass +class MockUpdateCellIdsRequest: + cell_ids: list[str] + + @dataclass class MockSessionView: cell_operations: dict | None = None + cell_ids: MockUpdateCellIdsRequest | None = None def __post_init__(self) -> None: if self.cell_operations is None: @@ -50,63 +58,122 @@ class MockSession: app_file_manager: DummyAppFileManager -def test_collect_errors_none() -> None: - tool = GetNotebookErrors(ToolContext()) +@pytest.fixture +def tool() -> GetNotebookErrors: + """Create a GetNotebookErrors tool instance.""" + 
return GetNotebookErrors(ToolContext()) - # Empty session view - session = MockSession( - session_view=MockSessionView(), - app_file_manager=DummyAppFileManager(app=Mock()), - ) - summaries = tool._collect_errors(session) # type: ignore[arg-type] - assert summaries == [] +@pytest.fixture +def mock_context() -> Mock: + """Create a mock ToolContext.""" + context = Mock(spec=ToolContext) + return context + +def test_get_notebook_errors_empty_session(mock_context: Mock) -> None: + """Test get_notebook_errors with no errors.""" + mock_context.get_notebook_errors.return_value = [] -def test_collect_errors_marimo_and_stderr() -> None: tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is False + assert result.total_errors == 0 + assert result.total_cells_with_errors == 0 + assert result.cells == [] + assert result.next_steps is not None + assert "No errors detected" in result.next_steps + + +def test_get_notebook_errors_marimo_error_only(mock_context: Mock) -> None: + """Test get_notebook_errors with MARIMO_ERROR only.""" + marimo_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="bad value", + traceback=["line 1"], + ) + ], + ) + ] + mock_context.get_notebook_errors.return_value = marimo_errors - # Cell c1 has MARIMO_ERROR (dict) and STDERR; c2 has STDERR only - err_dict = {"type": "ValueError", "msg": "bad value", "traceback": ["tb"]} - c1 = MockCellOp( - output=MockOutput(CellChannel.MARIMO_ERROR, [err_dict]), - console=[MockConsoleOutput(CellChannel.STDERR, "warn")], - ) - c2 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "oops")]) + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context + + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) + + assert result.has_errors is True + assert result.total_errors == 1 + assert 
result.total_cells_with_errors == 1 + assert len(result.cells) == 1 + assert result.cells[0].cell_id == CellId_t("c1") + assert result.cells[0].errors[0].type == "ValueError" + assert result.next_steps is not None + assert "get_cell_runtime_data" in result.next_steps[0] + + +def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None: + """Test get_notebook_errors with errors in multiple cells.""" + multiple_errors = [ + MarimoCellErrors( + cell_id=CellId_t("c1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="error in c1", + traceback=[], + ) + ], + stderr=[], + ), + MarimoCellErrors( + cell_id=CellId_t("c2"), + errors=[ + MarimoErrorDetail( + type="TypeError", + message="error in c2", + traceback=[], + ), + MarimoErrorDetail( + type="ValueError", + message="error in c2", + traceback=[], + ), + ], + stderr=[], + ), + ] + mock_context.get_notebook_errors.return_value = multiple_errors - session = MockSession( - session_view=MockSessionView(cell_operations={"c1": c1, "c2": c2}), - app_file_manager=DummyAppFileManager(app=Mock()), - ) + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context - summaries = tool._collect_errors(session) # type: ignore[arg-type] + result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - # Sorted by cell_id: c1 then c2 - assert len(summaries) == 2 - assert summaries[0].cell_id == "c1" - assert len(summaries[0].errors) == 2 # one MARIMO_ERROR, one STDERR - assert summaries[0].errors[0].type == "ValueError" - assert summaries[1].cell_id == "c2" - assert len(summaries[1].errors) == 1 - assert summaries[1].errors[0].type == "STDERR" + assert result.has_errors is True + assert result.total_errors == 3 + assert result.total_cells_with_errors == 2 + assert len(result.cells) == 2 -def test_handle_integration_uses_context_get_session() -> None: +def test_get_notebook_errors_respects_session_id(mock_context: Mock) -> None: + """Test that get_notebook_errors passes the correct 
session_id.""" + session_id = SessionId("test-session-123") + mock_context.get_notebook_errors.return_value = [] + tool = GetNotebookErrors(ToolContext()) + tool.context = mock_context - c1 = MockCellOp(console=[MockConsoleOutput(CellChannel.STDERR, "warn")]) - session = MockSession( - session_view=MockSessionView(cell_operations={"c1": c1}), - app_file_manager=DummyAppFileManager(app=Mock()), - ) + tool.handle(GetNotebookErrorsArgs(session_id=session_id)) - # Mock ToolContext.get_session - context = Mock(spec=ToolContext) - context.get_session.return_value = session - tool.context = context # type: ignore[assignment] - - out = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1"))) - assert out.has_errors is True - assert out.total_errors == 1 - assert len(out.cells) == 1 - assert out.cells[0].cell_id == "c1" + # Verify the session_id was passed correctly with include_stderr=True + mock_context.get_notebook_errors.assert_called_once_with( + session_id, include_stderr=True + ) diff --git a/tests/_mcp/server/prompts/test_errors_prompts.py b/tests/_mcp/server/prompts/test_errors_prompts.py new file mode 100644 index 00000000000..622b3580129 --- /dev/null +++ b/tests/_mcp/server/prompts/test_errors_prompts.py @@ -0,0 +1,66 @@ +import pytest + +pytest.importorskip("mcp", reason="MCP requires Python 3.10+") + +from unittest.mock import Mock + +from marimo._ai._tools.types import ( + MarimoCellErrors, + MarimoErrorDetail, + MarimoNotebookInfo, +) +from marimo._mcp.server._prompts.prompts.errors import ErrorsSummary +from marimo._types.ids import CellId_t, SessionId + + +def test_errors_summary_includes_essential_data(): + """Test that essential data (notebook info, cell IDs, error messages) are included.""" + context = Mock() + context.get_active_sessions_internal.return_value = [ + MarimoNotebookInfo( + name="notebook.py", + path="/path/to/notebook.py", + session_id=SessionId("session_1"), + ) + ] + context.get_notebook_errors.return_value = [ + MarimoCellErrors( 
+ cell_id=CellId_t("cell_1"), + errors=[ + MarimoErrorDetail( + type="ValueError", + message="invalid value", + traceback=["line 1"], + ) + ], + ), + MarimoCellErrors( + cell_id=CellId_t("cell_2"), + errors=[ + MarimoErrorDetail( + type="TypeError", + message="wrong type", + traceback=["line 2"], + ) + ], + ), + ] + + prompt = ErrorsSummary(context=context) + messages = prompt.handle() + text = "\n".join( + msg.content.text # type: ignore[attr-defined] + for msg in messages + if hasattr(msg.content, "text") + ) + + # Essential data must be present + assert "notebook.py" in text + assert "session_1" in text + assert "/path/to/notebook.py" in text + assert "cell_1" in text + assert "cell_2" in text + assert "ValueError" in text + assert "invalid value" in text + assert "TypeError" in text + assert "wrong type" in text diff --git a/tests/_mcp/server/prompts/test_notebooks_prompts.py b/tests/_mcp/server/prompts/test_notebooks_prompts.py index 45f044d01de..212265eb98f 100644 --- a/tests/_mcp/server/prompts/test_notebooks_prompts.py +++ b/tests/_mcp/server/prompts/test_notebooks_prompts.py @@ -4,73 +4,37 @@ from unittest.mock import Mock +from marimo._ai._tools.types import MarimoNotebookInfo from marimo._mcp.server._prompts.prompts.notebooks import ActiveNotebooks +from marimo._types.ids import SessionId -def test_active_notebooks_metadata(): - """Test that name and description are properly set.""" - prompt = ActiveNotebooks(context=Mock()) - assert prompt.name == "active_notebooks" - assert ( - prompt.description - == "Get current active notebooks and their session IDs and file paths." 
- ) - - -def test_active_notebooks_no_sessions(): - """Test output when no sessions are active.""" +def test_active_notebooks_includes_session_ids_and_paths(): + """Test that essential data (session IDs and file paths) are included.""" context = Mock() - context.get_active_sessions_internal.return_value = [] - - prompt = ActiveNotebooks(context=context) - messages = prompt.handle() - - assert len(messages) == 1 - assert messages[0].role == "user" - assert messages[0].content.type == "text" - assert ( - "No active marimo notebook sessions found" in messages[0].content.text - ) - - -def test_active_notebooks_with_sessions(): - """Test output with active sessions.""" - context = Mock() - - # Mock active session objects - active_session1 = Mock() - active_session1.session_id = "session_1" - active_session1.path = "/path/to/notebook.py" - - active_session2 = Mock() - active_session2.session_id = "session_2" - active_session2.path = None - context.get_active_sessions_internal.return_value = [ - active_session1, - active_session2, + MarimoNotebookInfo( + name="notebook.py", + path="/path/to/notebook.py", + session_id=SessionId("session_1"), + ), + MarimoNotebookInfo( + name="other.py", + path="/other/path.py", + session_id=SessionId("session_2"), + ), ] prompt = ActiveNotebooks(context=context) messages = prompt.handle() - - assert len(messages) == 3 # 2 sessions + 1 action message - - # Check first message (with file path) - assert messages[0].role == "user" - assert messages[0].content.type == "text" - assert "session_1" in messages[0].content.text - assert "/path/to/notebook.py" in messages[0].content.text - - # Check second message (without file path) - assert messages[1].role == "user" - assert messages[1].content.type == "text" - assert "session_2" in messages[1].content.text - - # Check action message - assert messages[2].role == "user" - assert messages[2].content.type == "text" - assert ( - "Use these session_ids when calling marimo MCP tools" - in 
messages[2].content.text + text = "\n".join( + msg.content.text # type: ignore[attr-defined] + for msg in messages + if hasattr(msg.content, "text") ) + + # Essential data must be present + assert "session_1" in text + assert "session_2" in text + assert "/path/to/notebook.py" in text + assert "/other/path.py" in text