Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion marimo/_ast/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class RuntimeState:
@dataclasses.dataclass
class RunResultStatus:
state: Optional[RunResultStatusType] = None
exception: Optional[Exception] = None


@dataclasses.dataclass
Expand Down Expand Up @@ -354,9 +355,12 @@ def set_runtime_state(
)

def set_run_result_status(
self, run_result_status: RunResultStatusType
self,
run_result_status: RunResultStatusType,
exception: Exception | None = None,
) -> None:
self._run_result_status.state = run_result_status
self._run_result_status.exception = exception

def set_stale(
self, stale: bool, stream: Stream | None = None, broadcast: bool = True
Expand All @@ -376,6 +380,10 @@ def set_output(self, output: Any) -> None:
def output(self) -> Any:
return self._output.output

@property
def exception(self) -> Exception | None:
return self._run_result_status.exception


@dataclasses.dataclass
class Cell:
Expand Down
51 changes: 28 additions & 23 deletions marimo/_runtime/runner/cell_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ def cell_filename(cell_id: CellId_t) -> str:
return f"<cell-{cell_id}>"


ErrorObjects = Union[BaseException, Error]
ExceptionOrError = Union[BaseException, Error]


@dataclass
class RunResult:
# Raw output of cell: last expression
output: Any
# Exception raised by cell, if any
exception: Optional[ErrorObjects]
#
# TODO(akshayka): Exceptions and "Errors" (most of which are at parse time
# and can't be encountered by the runner) shouldn't be packed into a single
# field.
exception: Optional[ExceptionOrError]
# Accumulated output: via imperative mo.output.append()
accumulated_output: Any = None

Expand All @@ -82,6 +86,26 @@ def success(self) -> bool:
return self.exception is None


def should_show_traceback(
exception: Optional[ExceptionOrError],
) -> bool:
if exception is None:
return True

# Stop "errors" aren't actually errors but rather a control
# flow mechanism used by mo.stop() to stop execution; as such
# a traceback should not be shown for them.
if isinstance(exception, MarimoStopError):
return False

# SQL parsing errors happen in SQL cells so showing a
# python traceback is not useful.
if isinstance(exception, MarimoSQLError):
return False

return True


class Runner:
"""Runner for a collection of cells."""

Expand Down Expand Up @@ -147,7 +171,7 @@ def __init__(
# whether the runner has been interrupted
self.interrupted = False
# mapping from cell_id to exception it raised
self.exceptions: dict[CellId_t, ErrorObjects] = {}
self.exceptions: dict[CellId_t, ExceptionOrError] = {}

# each cell's position in the run queue
self._run_position = {
Expand Down Expand Up @@ -334,7 +358,7 @@ def _run_result_from_exception(
unwrapped_exception: Optional[BaseException],
cell_id: CellId_t,
) -> tuple[RunResult, Optional[BaseException]]:
exception: Optional[ErrorObjects] = unwrapped_exception
exception: Optional[ExceptionOrError] = unwrapped_exception
if isinstance(exception, MarimoMissingRefError):
ref, blamed_cell = self._get_blamed_cell(exception)
# All MarimoMissingRefErrors should be caused caused by
Expand Down Expand Up @@ -514,25 +538,6 @@ async def run(self, cell_id: CellId_t) -> RunResult:
# this call as well, so this should be lifted out of `run`.
self.cancel(cell_id)

def should_show_traceback(
exception: Optional[ErrorObjects],
) -> bool:
if exception is None:
return True

# Stop "errors" aren't actually errors but rather a control
# flow mechanism used by mo.stop() to stop execution; as such
# a traceback should not be shown for them.
if isinstance(exception, MarimoStopError):
return False

# SQL parsing errors happen in SQL cells so showing a
# python traceback is not useful.
if isinstance(exception, MarimoSQLError):
return False

return True

if should_show_traceback(run_result.exception):
tmpio = io.StringIO()
# The executors explicitly raise cell exceptions from base
Expand Down
12 changes: 11 additions & 1 deletion marimo/_runtime/runner/hooks_post_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,17 @@ def _set_run_result_status(
elif runner.cancelled(cell.cell_id):
cell.set_run_result_status("cancelled")
elif run_result.exception is not None:
cell.set_run_result_status("exception")
cell.set_run_result_status(
"exception",
(
# TODO(akshayka): "run_result.exception" can unfortunately
# hold things that are not exceptions; remove this check
# if/when that is ever cleaned up.
run_result.exception
if isinstance(run_result.exception, Exception)
else None
),
)
else:
cell.set_run_result_status("success")

Expand Down
13 changes: 13 additions & 0 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2880,12 +2880,25 @@ def log_callback(log_line: str) -> None:
if self.should_update_script_metadata():
self.update_script_metadata(installed_modules)

# All cells that depend on successfully installed modules are re-run.
#
# This consists of cells that either statically reference the installed
# module, or that previously failed with a ModuleNotFoundError matching
# an installed module.
cells_to_run = set(
cid
for module in installed_modules
if (cid := self._kernel.module_registry.defining_cell(module))
is not None
)

for cid, cell in self._kernel.graph.cells.items():
if (
isinstance(cell.exception, ModuleNotFoundError)
and cell.exception.name in installed_modules
):
cells_to_run.add(cid)

if cells_to_run:
await self._kernel._if_autorun_then_run_cells(cells_to_run)

Expand Down
16 changes: 16 additions & 0 deletions tests/_runtime/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,22 @@ async def test_sync_graph_empty_sync(self, any_kernel: Kernel) -> None:
assert k.globals["y"] == 2
assert not k.errors

async def test_missing_module_detected(self, any_kernel: Kernel) -> None:
k = any_kernel
await k.run(
[er := ExecutionRequest(cell_id="0", code="import foobar")]
)
cell = k.graph.cells[er.cell_id]
assert cell.exception is not None
assert isinstance(cell.exception, ModuleNotFoundError)
assert cell.exception.name == "foobar"

await k.run(
[er := ExecutionRequest(cell_id="0", code="import marimo")]
)
cell = k.graph.cells[er.cell_id]
assert cell.exception is None


class TestStrictExecution:
@staticmethod
Expand Down
Loading