Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion marimo/_ast/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
is_non_marimo_python_script,
parse_notebook,
)
from marimo._schemas.serialization import NotebookSerialization, UnparsableCell
from marimo._schemas.serialization import (
CellDef,
NotebookSerialization,
UnparsableCell,
)

LOGGER = _loggers.marimo_logger()

Expand Down Expand Up @@ -89,6 +93,18 @@ def _static_load(filepath: Path) -> Optional[App]:
return load_notebook_ir(notebook, filepath=str(filepath))


def find_cell(filename, lineno) -> CellDef | None:
load_result = get_notebook_status(filename)
if load_result.notebook is None:
raise OSError("Could not resolve notebook.")
previous = None
for cell in load_result.notebook.cells:
if cell.lineno > lineno:
break
previous = cell
return previous


def load_notebook_ir(
notebook: NotebookSerialization, filepath: Optional[str] = None
) -> App:
Expand Down
25 changes: 23 additions & 2 deletions marimo/_save/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)

from marimo._ast.cell_id import is_external_cell_id
from marimo._ast.load import find_cell
from marimo._ast.transformers import (
ARG_PREFIX,
CacheExtractWithBlock,
Expand Down Expand Up @@ -596,7 +597,7 @@ def trace(self, with_frame: FrameType) -> None:
# causing this function to terminate before reaching this block.
self._frame = with_frame
for i, frame in enumerate(stack[::-1]):
_filename, lineno, function_name, _code = frame
filename, lineno, function_name, _code = frame
if function_name == "<module>":
ctx = get_context()
if ctx.execution_context is None:
Expand All @@ -608,10 +609,30 @@ def trace(self, with_frame: FrameType) -> None:
)
graph = ctx.graph
cell_id = ctx.cell_id or ctx.execution_context.cell_id

# We are calling from script mode, so our line number is
# absolute.
if "__marimo__" not in filename:
cell = find_cell(filename, lineno)
if cell is None:
raise CacheException(
"Could not resolve cell for cache."
f"{UNEXPECTED_FAILURE_BOILERPLATE}"
)
lineno -= cell.lineno
code = cell.code
elif cell_id in graph.cells:
code = graph.cells[cell_id].code
else:
raise CacheException(
"Could not resolve cell for cache."
f"{UNEXPECTED_FAILURE_BOILERPLATE}"
)

pre_module, save_module = CacheExtractWithBlock(
lineno - 1
).visit(
ast.parse(graph.cells[cell_id].code).body # type: ignore[arg-type]
ast.parse(code).body # type: ignore[arg-type]
)

self._cache = cache_attempt_from_hash(
Expand Down
11 changes: 10 additions & 1 deletion tests/_save/external_decorators/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@app.cell
def _():
def decorator_wrap():
@mo.cache
def cache(x):
return x + 1
Expand All @@ -17,5 +17,14 @@ def cache(x):
return (bar, cache)


@app.cell
def block_wrap(mo):
with mo.cache("random"):
x = []

a = "need a final line to trigger invalid block capture"
return (x,)


if __name__ == "__main__":
app.run()
4 changes: 4 additions & 0 deletions tests/_save/test_external_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ def _():
_, defs = ex_app.run()
assert defs["bar"] == 2
assert defs["cache"](1) == 2
assert len(defs["x"]) == 0
defs["x"].append(1)
_, defs = ex_app.run()
assert len(defs["x"]) == 1
return

@staticmethod
Expand Down
Loading