Skip to content
Merged
157 changes: 88 additions & 69 deletions marimo/_save/cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import importlib
import inspect
import re
import textwrap
from collections import namedtuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, get_args

from marimo._plugins.ui._core.ui_element import S, T, UIElement
from marimo._plugins.ui._core.ui_element import UIElement
from marimo._runtime.context import ContextNotInitializedError, get_context
from marimo._runtime.state import SetFunctor
from marimo._save.stubs import (
CUSTOM_STUBS,
CustomStub,
FunctionStub,
ModuleStub,
UIElementStub,
maybe_register_stub,
)

if TYPE_CHECKING:
from marimo._ast.visitor import Name
Expand Down Expand Up @@ -42,45 +48,6 @@
MetaKey = Literal["return", "version"]


class ModuleStub:
def __init__(self, module: Any) -> None:
self.name = module.__name__

def load(self) -> Any:
return importlib.import_module(self.name)


class FunctionStub:
def __init__(self, function: Any) -> None:
self.code = textwrap.dedent(inspect.getsource(function))

def load(self, glbls: dict[str, Any]) -> Any:
# TODO: Fix line cache and associate with the correct module.
code_obj = compile(self.code, "<string>", "exec")
lcls: dict[str, Any] = {}
exec(code_obj, glbls, lcls)
# Update the global scope with the function.
for value in lcls.values():
return value


class UIElementStub:
def __init__(self, element: UIElement[S, T]) -> None:
self.args = element._args
self.cls = element.__class__
# Ideally only hashable attributes are stored on the subclass level.
defaults = set(self.cls.__new__(self.cls).__dict__.keys())
defaults |= {"_ctx"}
self.data = {
k: v
for k, v in element.__dict__.items()
if hasattr(v, "__hash__") and k not in defaults
}

def load(self) -> UIElement[S, T]:
return self.cls.from_args(self.data, self.args) # type: ignore


# BaseException because "raise _ as e" is utilized.
class CacheException(BaseException):
pass
Expand Down Expand Up @@ -188,6 +155,10 @@ def _restore_from_stub_if_needed(
value.clear()
value.update(result)
result = value
elif isinstance(value, CustomStub):
# CustomStub is a placeholder for a custom type, which cannot be
# restored directly.
result = value.load(scope)
else:
result = value

Expand Down Expand Up @@ -246,9 +217,19 @@ def update(
self.meta[key] = self._convert_to_stub_if_needed(value, memo)

def _convert_to_stub_if_needed(
self, value: Any, memo: dict[int, Any] | None = None
self,
value: Any,
memo: dict[int, Any] | None = None,
preserve_pointers: bool = True,
) -> Any:
"""Convert objects to stubs if needed, recursively handling collections."""
"""Convert objects to stubs if needed, recursively handling collections.

Args:
value: The value to convert
memo: Memoization dict to handle cycles
preserve_pointers: If True, modifies containers in-place to preserve
object identity. If False, creates new containers.
"""
if memo is None:
memo = {}

Expand All @@ -269,38 +250,76 @@ def _convert_to_stub_if_needed(
# tuples are immutable and cannot be recursive, but we still want to
# iteratively convert the internal items.
result = tuple(
self._convert_to_stub_if_needed(item, memo) for item in value
self._convert_to_stub_if_needed(item, memo, preserve_pointers)
for item in value
)
elif isinstance(value, set):
# sets cannot be recursive (require hasable items), but we still
# maintain the original set reference.
result = set(
self._convert_to_stub_if_needed(item, memo) for item in value
# sets cannot be recursive (require hashable items)
converted = set(
self._convert_to_stub_if_needed(item, memo, preserve_pointers)
for item in value
)
value.clear()
value.update(result)
result = value
if preserve_pointers:
value.clear()
value.update(converted)
result = value
else:
result = converted
elif isinstance(value, list):
# Store placeholder to handle cycles
memo[obj_id] = value
result = [
self._convert_to_stub_if_needed(item, memo) for item in value
]
value.clear()
value.extend(result)
result = value
if preserve_pointers:
# Preserve original list reference
memo[obj_id] = value
converted_list = [
self._convert_to_stub_if_needed(
item, memo, preserve_pointers
)
for item in value
]
value.clear()
value.extend(converted_list)
result = value
else:
# Create new list
result = []
memo[obj_id] = result
result.extend(
[
self._convert_to_stub_if_needed(
item, memo, preserve_pointers
)
for item in value
]
)
elif isinstance(value, dict):
# Recursively convert dictionary values
memo[obj_id] = value
result = {
k: self._convert_to_stub_if_needed(v, memo)
for k, v in value.items()
}
value.clear()
value.update(result)
result = value
if preserve_pointers:
# Preserve original dict reference
memo[obj_id] = value
converted_dict = {
k: self._convert_to_stub_if_needed(
v, memo, preserve_pointers
)
for k, v in value.items()
}
value.clear()
value.update(converted_dict)
result = value
else:
# Create new dict
result = {}
memo[obj_id] = result
result.update(
{
k: self._convert_to_stub_if_needed(
v, memo, preserve_pointers
)
for k, v in value.items()
}
)
elif type(value) in CUSTOM_STUBS or maybe_register_stub(value):
result = CUSTOM_STUBS[type(value)](value)
else:
result = value

memo[obj_id] = result

return result
Expand Down
6 changes: 6 additions & 0 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from marimo._runtime.state import SetFunctor, State
from marimo._runtime.watch._path import PathState
from marimo._save.cache import Cache, CacheType
from marimo._save.stubs import maybe_get_custom_stub
from marimo._types.ids import CellId_t

if TYPE_CHECKING:
Expand Down Expand Up @@ -779,6 +780,9 @@ def serialize_and_dequeue_content_refs(
- primitive (bytes, str, numbers.Number, type(None))
- data primitive (e.g. numpy array, torch tensor)
- external module definitions (imported anything)
- pure functions (no state, no external dependencies)
- pure containers of the above (list, dict, set, tuple)
- custom types defined in CUSTOM_STUBS

Args:
refs: A set of reference names unaccounted for.
Expand Down Expand Up @@ -839,6 +843,8 @@ def serialize_and_dequeue_content_refs(
# pinning being the mechanism for invalidation.
elif getattr(value, "__module__", "__main__") == "__main__":
continue
elif stub := maybe_get_custom_stub(value):
serial_value = stub.to_bytes()
# External module that is not a class or function, may be some
# container we don't know how to hash.
# Note, function cases care caught by is_pure_function
Expand Down
93 changes: 93 additions & 0 deletions marimo/_save/stubs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2025 Marimo. All rights reserved.
"""Stub system for cache serialization."""

from __future__ import annotations

from typing import Any, Callable

from marimo._save.stubs.function_stub import FunctionStub
from marimo._save.stubs.module_stub import ModuleStub
from marimo._save.stubs.pydantic_stub import PydanticStub
from marimo._save.stubs.stubs import (
CUSTOM_STUBS,
CustomStub,
register_stub,
)
from marimo._save.stubs.ui_element_stub import UIElementStub

# Track which class names we've already attempted to register
_REGISTERED_NAMES: set[str] = set()

# Dictionary mapping fully qualified class names to registration functions
STUB_REGISTRATIONS: dict[str, Callable[[Any], None]] = {
"pydantic.main.BaseModel": PydanticStub.register,
}


def maybe_register_stub(value: Any) -> bool:
"""Lazily register a stub for a value's type if not already registered.

This allows us to avoid importing third-party packages until they're
actually used in the cache. Walks the MRO to check if any parent class
matches a registered stub type.

Returns:
True if the value's type is in CUSTOM_STUBS (either already registered
or newly registered), False otherwise.
"""
value_type = type(value)

# Already registered in CUSTOM_STUBS
if value_type in CUSTOM_STUBS:
return True

# Walk MRO to find matching base class
try:
mro_list = value_type.mro()
except BaseException:
# Some exotic metaclasses or broken types may raise when calling mro
mro_list = [value_type]

for cls in mro_list:
if not (hasattr(cls, "__module__") and hasattr(cls, "__name__")):
continue

cls_name = f"{cls.__module__}.{cls.__name__}"

if cls_name in STUB_REGISTRATIONS:
if cls_name not in _REGISTERED_NAMES:
_REGISTERED_NAMES.add(cls_name)
STUB_REGISTRATIONS[cls_name](value)
# After registration attempt, check if now in CUSTOM_STUBS
return value_type in CUSTOM_STUBS

return False


def maybe_get_custom_stub(value: Any) -> CustomStub | None:
"""Get the registered stub for a value's type, if any.

Args:
value: The value to get the stub for

Returns:
A stub instance if registered, None otherwise
"""
# Fallback to custom cases
if maybe_register_stub(value):
value_type = type(value)
if value_type in CUSTOM_STUBS:
return CUSTOM_STUBS[value_type](value)
return None


__all__ = [
"CUSTOM_STUBS",
"CustomStub",
"FunctionStub",
"ModuleStub",
"UIElementStub",
"maybe_register_stub",
"maybe_get_custom_stub",
"register_stub",
]
25 changes: 25 additions & 0 deletions marimo/_save/stubs/function_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import inspect
import textwrap
from typing import Any

__all__ = ["FunctionStub"]


class FunctionStub:
"""Stub for function objects, storing the source code."""

def __init__(self, function: Any) -> None:
self.code = textwrap.dedent(inspect.getsource(function))

def load(self, glbls: dict[str, Any]) -> Any:
"""Reconstruct the function by executing its source code."""
# TODO: Fix line cache and associate with the correct module.
code_obj = compile(self.code, "<string>", "exec")
lcls: dict[str, Any] = {}
exec(code_obj, glbls, lcls)
# Update the global scope with the function.
for value in lcls.values():
return value
18 changes: 18 additions & 0 deletions marimo/_save/stubs/module_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2025 Marimo. All rights reserved.
from __future__ import annotations

import importlib
from typing import Any

__all__ = ["ModuleStub"]


class ModuleStub:
"""Stub for module objects, storing only the module name."""

def __init__(self, module: Any) -> None:
self.name = module.__name__

def load(self) -> Any:
"""Reload the module by name."""
return importlib.import_module(self.name)
Loading
Loading