Skip to content

Commit 36d6811

Browse files
committed
update tests to reflect changes
1 parent fabc20f commit 36d6811

File tree

3 files changed

+99
-92
lines changed

3 files changed

+99
-92
lines changed

tests/_ai/tools/test_base.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from marimo._ai._tools.base import ToolBase, ToolContext
1010
from marimo._ai._tools.utils.exceptions import ToolExecutionError
11+
from marimo._messaging import msgspec_encoder
1112

1213

1314
@dataclass
@@ -125,3 +126,93 @@ def test_as_backend_tool() -> None:
125126
is_valid, msg = validator({"invalid": "field"})
126127
assert is_valid is False
127128
assert "Invalid arguments" in msg
129+
130+
# test ToolContext methods
131+
132+
def test_get_notebook_errors_orders_by_cell_manager():
133+
"""Test errors follow cell_manager order, not alphabetical."""
134+
from unittest.mock import Mock
135+
from marimo._messaging.cell_output import CellChannel
136+
from marimo._types.ids import CellId_t, SessionId
137+
138+
context = ToolContext()
139+
140+
# Mock error cell_op
141+
error_op = Mock()
142+
error_op.output = Mock()
143+
error_op.output.channel = CellChannel.MARIMO_ERROR
144+
error_op.output.data = [{"type": "Error", "msg": "test", "traceback": []}]
145+
error_op.console = None
146+
147+
# Mock session with cells c1, c2, c3
148+
session = Mock()
149+
session.session_view.cell_operations = {
150+
CellId_t("c1"): error_op,
151+
CellId_t("c2"): error_op,
152+
CellId_t("c3"): error_op,
153+
}
154+
155+
# Cell manager returns in order: c3, c2, c1 (not alphabetical)
156+
cell_data = [Mock(cell_id=CellId_t("c3")), Mock(cell_id=CellId_t("c2")), Mock(cell_id=CellId_t("c1"))]
157+
session.app_file_manager.app.cell_manager.cell_data.return_value = cell_data
158+
159+
context.get_session = Mock(return_value=session)
160+
161+
errors = context.get_notebook_errors(SessionId("test"), include_stderr=False)
162+
163+
# Should be c3, c2, c1 (not c1, c2, c3)
164+
assert errors[0].cell_id == CellId_t("c3")
165+
assert errors[1].cell_id == CellId_t("c2")
166+
assert errors[2].cell_id == CellId_t("c1")
167+
168+
169+
def test_get_cell_errors_extracts_from_output():
170+
"""Test get_cell_errors extracts error details from cell output."""
171+
from unittest.mock import Mock
172+
from marimo._messaging.cell_output import CellChannel
173+
from marimo._types.ids import CellId_t, SessionId
174+
175+
context = ToolContext()
176+
177+
# Mock cell_op with error
178+
cell_op = Mock()
179+
cell_op.output = Mock()
180+
cell_op.output.channel = CellChannel.MARIMO_ERROR
181+
cell_op.output.data = [
182+
{"type": "ValueError", "msg": "bad value", "traceback": ["line 1"]}
183+
]
184+
185+
errors = context.get_cell_errors(SessionId("test"), CellId_t("c1"), maybe_cell_op=cell_op)
186+
187+
assert len(errors) == 1
188+
assert errors[0].type == "ValueError"
189+
assert errors[0].message == "bad value"
190+
assert errors[0].traceback == ["line 1"]
191+
192+
193+
def test_get_cell_console_outputs_separates_stdout_stderr():
194+
"""Test get_cell_console_outputs separates stdout and stderr."""
195+
from unittest.mock import Mock
196+
from marimo._messaging.cell_output import CellChannel
197+
198+
context = ToolContext()
199+
200+
# Mock cell_op with stdout and stderr
201+
stdout_output = Mock()
202+
stdout_output.channel = CellChannel.STDOUT
203+
stdout_output.data = "hello"
204+
205+
stderr_output = Mock()
206+
stderr_output.channel = CellChannel.STDERR
207+
stderr_output.data = "warning"
208+
209+
cell_op = Mock()
210+
cell_op.console = [stdout_output, stderr_output]
211+
212+
result = context.get_cell_console_outputs(cell_op)
213+
214+
assert len(result.stdout) == 1
215+
assert "hello" in result.stdout[0]
216+
assert len(result.stderr) == 1
217+
assert "warning" in result.stderr[0]
218+

tests/_ai/tools/tools/test_cells.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -190,25 +190,3 @@ def test_get_visual_output_no_output():
190190
visual_output, mimetype = tool._get_visual_output(cell_op) # type: ignore[arg-type]
191191
assert visual_output is None
192192
assert mimetype is None
193-
194-
195-
def test_get_console_outputs_with_stdout_stderr():
196-
tool = GetCellOutputs(ToolContext())
197-
console = [
198-
MockConsoleOutput(CellChannel.STDOUT, "hello"),
199-
MockConsoleOutput(CellChannel.STDERR, "warning"),
200-
]
201-
cell_op = MockCellOp(console=console)
202-
203-
stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type]
204-
assert stdout == ["hello"]
205-
assert stderr == ["warning"]
206-
207-
208-
def test_get_console_outputs_no_console():
209-
tool = GetCellOutputs(ToolContext())
210-
cell_op = MockCellOp(console=None)
211-
212-
stdout, stderr = tool._get_console_outputs(cell_op) # type: ignore[arg-type]
213-
assert stdout == []
214-
assert stderr == []

tests/_ai/tools/tools/test_errors_tool.py

Lines changed: 8 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -119,65 +119,6 @@ def test_get_notebook_errors_marimo_error_only(mock_context: Mock) -> None:
119119
assert "get_cell_runtime_data" in result.next_steps[0]
120120

121121

122-
def test_get_notebook_errors_stderr_only(mock_context: Mock) -> None:
123-
"""Test get_notebook_errors with STDERR only."""
124-
stderr_errors = [
125-
MarimoCellErrors(
126-
cell_id=CellId_t("c2"),
127-
errors=[
128-
MarimoErrorDetail(
129-
type="STDERR",
130-
message="warning message",
131-
traceback=[],
132-
)
133-
],
134-
)
135-
]
136-
mock_context.get_notebook_errors.return_value = stderr_errors
137-
138-
tool = GetNotebookErrors(ToolContext())
139-
tool.context = mock_context
140-
141-
result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1")))
142-
143-
assert result.has_errors is False
144-
assert result.total_errors == 0
145-
assert result.total_cells_with_errors == 1
146-
assert result.cells[0].errors[0].type == "STDERR"
147-
148-
149-
def test_get_notebook_errors_mixed_errors(mock_context: Mock) -> None:
150-
"""Test get_notebook_errors with both MARIMO_ERROR and STDERR."""
151-
mixed_errors = [
152-
MarimoCellErrors(
153-
cell_id=CellId_t("c1"),
154-
errors=[
155-
MarimoErrorDetail(
156-
type="ValueError",
157-
message="bad value",
158-
traceback=["line 1"],
159-
),
160-
MarimoErrorDetail(
161-
type="STDERR",
162-
message="warn",
163-
traceback=[],
164-
),
165-
],
166-
)
167-
]
168-
mock_context.get_notebook_errors.return_value = mixed_errors
169-
170-
tool = GetNotebookErrors(ToolContext())
171-
tool.context = mock_context
172-
173-
result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1")))
174-
175-
assert result.has_errors is True
176-
assert result.total_errors == 1
177-
assert result.total_cells_with_errors == 1
178-
assert len(result.cells[0].errors) == 2
179-
180-
181122
def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None:
182123
"""Test get_notebook_errors with errors in multiple cells."""
183124
multiple_errors = [
@@ -190,6 +131,7 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None:
190131
traceback=[],
191132
)
192133
],
134+
stderr=[],
193135
),
194136
MarimoCellErrors(
195137
cell_id=CellId_t("c2"),
@@ -198,18 +140,14 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None:
198140
type="TypeError",
199141
message="error in c2",
200142
traceback=[],
201-
)
202-
],
203-
),
204-
MarimoCellErrors(
205-
cell_id=CellId_t("c3"),
206-
errors=[
143+
),
207144
MarimoErrorDetail(
208-
type="STDERR",
209-
message="stderr in c3",
145+
type="ValueError",
146+
message="error in c2",
210147
traceback=[],
211148
)
212149
],
150+
stderr=[],
213151
),
214152
]
215153
mock_context.get_notebook_errors.return_value = multiple_errors
@@ -220,9 +158,9 @@ def test_get_notebook_errors_multiple_cells(mock_context: Mock) -> None:
220158
result = tool.handle(GetNotebookErrorsArgs(session_id=SessionId("s1")))
221159

222160
assert result.has_errors is True
223-
assert result.total_errors == 2
224-
assert result.total_cells_with_errors == 3
225-
assert len(result.cells) == 3
161+
assert result.total_errors == 3
162+
assert result.total_cells_with_errors == 2
163+
assert len(result.cells) == 2
226164

227165

228166
def test_get_notebook_errors_respects_session_id(mock_context: Mock) -> None:

0 commit comments

Comments
 (0)