diff --git a/marimo/_ast/cell.py b/marimo/_ast/cell.py index c7c3cdd24c0..0ba2fb5f153 100644 --- a/marimo/_ast/cell.py +++ b/marimo/_ast/cell.py @@ -122,6 +122,7 @@ class RuntimeState: @dataclasses.dataclass class RunResultStatus: state: Optional[RunResultStatusType] = None + exception: Optional[Exception] = None @dataclasses.dataclass @@ -354,9 +355,12 @@ def set_runtime_state( ) def set_run_result_status( - self, run_result_status: RunResultStatusType + self, + run_result_status: RunResultStatusType, + exception: Exception | None = None, ) -> None: self._run_result_status.state = run_result_status + self._run_result_status.exception = exception def set_stale( self, stale: bool, stream: Stream | None = None, broadcast: bool = True @@ -376,6 +380,10 @@ def set_output(self, output: Any) -> None: def output(self) -> Any: return self._output.output + @property + def exception(self) -> Exception | None: + return self._run_result_status.exception + @dataclasses.dataclass class Cell: diff --git a/marimo/_runtime/runner/cell_runner.py b/marimo/_runtime/runner/cell_runner.py index 34eff08b7d4..341a00175e8 100644 --- a/marimo/_runtime/runner/cell_runner.py +++ b/marimo/_runtime/runner/cell_runner.py @@ -65,7 +65,7 @@ def cell_filename(cell_id: CellId_t) -> str: return f"" -ErrorObjects = Union[BaseException, Error] +ExceptionOrError = Union[BaseException, Error] @dataclass @@ -73,7 +73,11 @@ class RunResult: # Raw output of cell: last expression output: Any # Exception raised by cell, if any - exception: Optional[ErrorObjects] + # + # TODO(akshayka): Exceptions and "Errors" (most of which are at parse time + # and can't be encountered by the runner) shouldn't be packed into a single + # field. + exception: Optional[ExceptionOrError] # Accumulated output: via imperative mo.output.append() accumulated_output: Any = None @@ -82,6 +86,26 @@ def success(self) -> bool: return self.exception is None +def should_show_traceback( + exception: Optional[ExceptionOrError], +) -> bool: + if exception is None: + return True + + # Stop "errors" aren't actually errors but rather a control + # flow mechanism used by mo.stop() to stop execution; as such + # a traceback should not be shown for them. + if isinstance(exception, MarimoStopError): + return False + + # SQL parsing errors happen in SQL cells so showing a + # python traceback is not useful. + if isinstance(exception, MarimoSQLError): + return False + + return True + + class Runner: """Runner for a collection of cells.""" @@ -147,7 +171,7 @@ def __init__( # whether the runner has been interrupted self.interrupted = False # mapping from cell_id to exception it raised - self.exceptions: dict[CellId_t, ErrorObjects] = {} + self.exceptions: dict[CellId_t, ExceptionOrError] = {} # each cell's position in the run queue self._run_position = { @@ -334,7 +358,7 @@ def _run_result_from_exception( unwrapped_exception: Optional[BaseException], cell_id: CellId_t, ) -> tuple[RunResult, Optional[BaseException]]: - exception: Optional[ErrorObjects] = unwrapped_exception + exception: Optional[ExceptionOrError] = unwrapped_exception if isinstance(exception, MarimoMissingRefError): ref, blamed_cell = self._get_blamed_cell(exception) # All MarimoMissingRefErrors should be caused caused by @@ -514,25 +538,6 @@ async def run(self, cell_id: CellId_t) -> RunResult: # this call as well, so this should be lifted out of `run`. self.cancel(cell_id) - def should_show_traceback( - exception: Optional[ErrorObjects], - ) -> bool: - if exception is None: - return True - - # Stop "errors" aren't actually errors but rather a control - # flow mechanism used by mo.stop() to stop execution; as such - # a traceback should not be shown for them. - if isinstance(exception, MarimoStopError): - return False - - # SQL parsing errors happen in SQL cells so showing a - # python traceback is not useful. - if isinstance(exception, MarimoSQLError): - return False - - return True - if should_show_traceback(run_result.exception): tmpio = io.StringIO() # The executors explicitly raise cell exceptions from base diff --git a/marimo/_runtime/runner/hooks_post_execution.py b/marimo/_runtime/runner/hooks_post_execution.py index 9c31c0f0805..24c52139bbc 100644 --- a/marimo/_runtime/runner/hooks_post_execution.py +++ b/marimo/_runtime/runner/hooks_post_execution.py @@ -93,7 +93,17 @@ def _set_run_result_status( elif runner.cancelled(cell.cell_id): cell.set_run_result_status("cancelled") elif run_result.exception is not None: - cell.set_run_result_status("exception") + cell.set_run_result_status( + "exception", + ( + # TODO(akshayka): "run_result.exception" can unfortunately + # hold things that are not exceptions; remove this check + # if/when that is ever cleaned up. + run_result.exception + if isinstance(run_result.exception, Exception) + else None + ), + ) else: cell.set_run_result_status("success") diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 5e47a180df3..96c75c2b46a 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -2880,12 +2880,25 @@ def log_callback(log_line: str) -> None: if self.should_update_script_metadata(): self.update_script_metadata(installed_modules) + # All cells that depend on successfully installed modules are re-run. + # + # This consists of cells that either statically reference the installed + # module, or that previously failed with a ModuleNotFoundError matching + # an installed module. cells_to_run = set( cid for module in installed_modules if (cid := self._kernel.module_registry.defining_cell(module)) is not None ) + + for cid, cell in self._kernel.graph.cells.items(): + if ( + isinstance(cell.exception, ModuleNotFoundError) + and cell.exception.name in installed_modules + ): + cells_to_run.add(cid) + if cells_to_run: await self._kernel._if_autorun_then_run_cells(cells_to_run) diff --git a/tests/_runtime/test_runtime.py b/tests/_runtime/test_runtime.py index 96d18cfba50..f69d004144f 100644 --- a/tests/_runtime/test_runtime.py +++ b/tests/_runtime/test_runtime.py @@ -1879,6 +1879,22 @@ async def test_sync_graph_empty_sync(self, any_kernel: Kernel) -> None: assert k.globals["y"] == 2 assert not k.errors + async def test_missing_module_detected(self, any_kernel: Kernel) -> None: + k = any_kernel + await k.run( + [er := ExecutionRequest(cell_id="0", code="import foobar")] + ) + cell = k.graph.cells[er.cell_id] + assert cell.exception is not None + assert isinstance(cell.exception, ModuleNotFoundError) + assert cell.exception.name == "foobar" + + await k.run( + [er := ExecutionRequest(cell_id="0", code="import marimo")] + ) + cell = k.graph.cells[er.cell_id] + assert cell.exception is None + class TestStrictExecution: @staticmethod