diff --git a/marimo/_save/cache.py b/marimo/_save/cache.py index bc1d3fac84e..e52f5969d51 100644 --- a/marimo/_save/cache.py +++ b/marimo/_save/cache.py @@ -1,17 +1,23 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -import importlib import inspect import re -import textwrap from collections import namedtuple from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional, get_args -from marimo._plugins.ui._core.ui_element import S, T, UIElement +from marimo._plugins.ui._core.ui_element import UIElement from marimo._runtime.context import ContextNotInitializedError, get_context from marimo._runtime.state import SetFunctor +from marimo._save.stubs import ( + CUSTOM_STUBS, + CustomStub, + FunctionStub, + ModuleStub, + UIElementStub, + maybe_register_stub, +) if TYPE_CHECKING: from marimo._ast.visitor import Name @@ -42,45 +48,6 @@ MetaKey = Literal["return", "version"] -class ModuleStub: - def __init__(self, module: Any) -> None: - self.name = module.__name__ - - def load(self) -> Any: - return importlib.import_module(self.name) - - -class FunctionStub: - def __init__(self, function: Any) -> None: - self.code = textwrap.dedent(inspect.getsource(function)) - - def load(self, glbls: dict[str, Any]) -> Any: - # TODO: Fix line cache and associate with the correct module. - code_obj = compile(self.code, "", "exec") - lcls: dict[str, Any] = {} - exec(code_obj, glbls, lcls) - # Update the global scope with the function. - for value in lcls.values(): - return value - - -class UIElementStub: - def __init__(self, element: UIElement[S, T]) -> None: - self.args = element._args - self.cls = element.__class__ - # Ideally only hashable attributes are stored on the subclass level. - defaults = set(self.cls.__new__(self.cls).__dict__.keys()) - defaults |= {"_ctx"} - self.data = { - k: v - for k, v in element.__dict__.items() - if hasattr(v, "__hash__") and k not in defaults - } - - def load(self) -> UIElement[S, T]: - return self.cls.from_args(self.data, self.args) # type: ignore - - # BaseException because "raise _ as e" is utilized. class CacheException(BaseException): pass @@ -188,6 +155,10 @@ def _restore_from_stub_if_needed( value.clear() value.update(result) result = value + elif isinstance(value, CustomStub): + # CustomStub is a placeholder for a custom type, which cannot be + # restored directly. + result = value.load(scope) else: result = value @@ -246,9 +217,19 @@ def update( self.meta[key] = self._convert_to_stub_if_needed(value, memo) def _convert_to_stub_if_needed( - self, value: Any, memo: dict[int, Any] | None = None + self, + value: Any, + memo: dict[int, Any] | None = None, + preserve_pointers: bool = True, ) -> Any: - """Convert objects to stubs if needed, recursively handling collections.""" + """Convert objects to stubs if needed, recursively handling collections. + + Args: + value: The value to convert + memo: Memoization dict to handle cycles + preserve_pointers: If True, modifies containers in-place to preserve + object identity. If False, creates new containers. + """ if memo is None: memo = {} @@ -269,38 +250,76 @@ def _convert_to_stub_if_needed( # tuples are immutable and cannot be recursive, but we still want to # iteratively convert the internal items. result = tuple( - self._convert_to_stub_if_needed(item, memo) for item in value + self._convert_to_stub_if_needed(item, memo, preserve_pointers) + for item in value ) elif isinstance(value, set): - # sets cannot be recursive (require hasable items), but we still - # maintain the original set reference. - result = set( - self._convert_to_stub_if_needed(item, memo) for item in value + # sets cannot be recursive (require hashable items) + converted = set( + self._convert_to_stub_if_needed(item, memo, preserve_pointers) + for item in value ) - value.clear() - value.update(result) - result = value + if preserve_pointers: + value.clear() + value.update(converted) + result = value + else: + result = converted elif isinstance(value, list): - # Store placeholder to handle cycles - memo[obj_id] = value - result = [ - self._convert_to_stub_if_needed(item, memo) for item in value - ] - value.clear() - value.extend(result) - result = value + if preserve_pointers: + # Preserve original list reference + memo[obj_id] = value + converted_list = [ + self._convert_to_stub_if_needed( + item, memo, preserve_pointers + ) + for item in value + ] + value.clear() + value.extend(converted_list) + result = value + else: + # Create new list + result = [] + memo[obj_id] = result + result.extend( + [ + self._convert_to_stub_if_needed( + item, memo, preserve_pointers + ) + for item in value + ] + ) elif isinstance(value, dict): - # Recursively convert dictionary values - memo[obj_id] = value - result = { - k: self._convert_to_stub_if_needed(v, memo) - for k, v in value.items() - } - value.clear() - value.update(result) - result = value + if preserve_pointers: + # Preserve original dict reference + memo[obj_id] = value + converted_dict = { + k: self._convert_to_stub_if_needed( + v, memo, preserve_pointers + ) + for k, v in value.items() + } + value.clear() + value.update(converted_dict) + result = value + else: + # Create new dict + result = {} + memo[obj_id] = result + result.update( + { + k: self._convert_to_stub_if_needed( + v, memo, preserve_pointers + ) + for k, v in value.items() + } + ) + elif type(value) in CUSTOM_STUBS or maybe_register_stub(value): + result = CUSTOM_STUBS[type(value)](value) else: result = value + memo[obj_id] = result return result diff --git a/marimo/_save/hash.py b/marimo/_save/hash.py index 9a9135a6f4e..65dcc8f46bd 100644 --- a/marimo/_save/hash.py +++ b/marimo/_save/hash.py @@ -35,6 +35,7 @@ from marimo._runtime.state import SetFunctor, State from marimo._runtime.watch._path import PathState from marimo._save.cache import Cache, CacheType +from marimo._save.stubs import maybe_get_custom_stub from marimo._types.ids import CellId_t if TYPE_CHECKING: @@ -779,6 +780,9 @@ def serialize_and_dequeue_content_refs( - primitive (bytes, str, numbers.Number, type(None)) - data primitive (e.g. numpy array, torch tensor) - external module definitions (imported anything) + - pure functions (no state, no external dependencies) + - pure containers of the above (list, dict, set, tuple) + - custom types defined in CUSTOM_STUBS Args: refs: A set of reference names unaccounted for. @@ -839,6 +843,8 @@ def serialize_and_dequeue_content_refs( # pinning being the mechanism for invalidation. elif getattr(value, "__module__", "__main__") == "__main__": continue + elif stub := maybe_get_custom_stub(value): + serial_value = stub.to_bytes() # External module that is not a class or function, may be some # container we don't know how to hash. # Note, function cases care caught by is_pure_function diff --git a/marimo/_save/stubs/__init__.py b/marimo/_save/stubs/__init__.py new file mode 100644 index 00000000000..f88b74c9b52 --- /dev/null +++ b/marimo/_save/stubs/__init__.py @@ -0,0 +1,93 @@ +# Copyright 2025 Marimo. All rights reserved. +"""Stub system for cache serialization.""" + +from __future__ import annotations + +from typing import Any, Callable + +from marimo._save.stubs.function_stub import FunctionStub +from marimo._save.stubs.module_stub import ModuleStub +from marimo._save.stubs.pydantic_stub import PydanticStub +from marimo._save.stubs.stubs import ( + CUSTOM_STUBS, + CustomStub, + register_stub, +) +from marimo._save.stubs.ui_element_stub import UIElementStub + +# Track which class names we've already attempted to register +_REGISTERED_NAMES: set[str] = set() + +# Dictionary mapping fully qualified class names to registration functions +STUB_REGISTRATIONS: dict[str, Callable[[Any], None]] = { + "pydantic.main.BaseModel": PydanticStub.register, +} + + +def maybe_register_stub(value: Any) -> bool: + """Lazily register a stub for a value's type if not already registered. + + This allows us to avoid importing third-party packages until they're + actually used in the cache. Walks the MRO to check if any parent class + matches a registered stub type. + + Returns: + True if the value's type is in CUSTOM_STUBS (either already registered + or newly registered), False otherwise. + """ + value_type = type(value) + + # Already registered in CUSTOM_STUBS + if value_type in CUSTOM_STUBS: + return True + + # Walk MRO to find matching base class + try: + mro_list = value_type.mro() + except BaseException: + # Some exotic metaclasses or broken types may raise when calling mro + mro_list = [value_type] + + for cls in mro_list: + if not (hasattr(cls, "__module__") and hasattr(cls, "__name__")): + continue + + cls_name = f"{cls.__module__}.{cls.__name__}" + + if cls_name in STUB_REGISTRATIONS: + if cls_name not in _REGISTERED_NAMES: + _REGISTERED_NAMES.add(cls_name) + STUB_REGISTRATIONS[cls_name](value) + # After registration attempt, check if now in CUSTOM_STUBS + return value_type in CUSTOM_STUBS + + return False + + +def maybe_get_custom_stub(value: Any) -> CustomStub | None: + """Get the registered stub for a value's type, if any. + + Args: + value: The value to get the stub for + + Returns: + A stub instance if registered, None otherwise + """ + # Fallback to custom cases + if maybe_register_stub(value): + value_type = type(value) + if value_type in CUSTOM_STUBS: + return CUSTOM_STUBS[value_type](value) + return None + + +__all__ = [ + "CUSTOM_STUBS", + "CustomStub", + "FunctionStub", + "ModuleStub", + "UIElementStub", + "maybe_register_stub", + "maybe_get_custom_stub", + "register_stub", +] diff --git a/marimo/_save/stubs/function_stub.py b/marimo/_save/stubs/function_stub.py new file mode 100644 index 00000000000..6fbc0c47e2b --- /dev/null +++ b/marimo/_save/stubs/function_stub.py @@ -0,0 +1,25 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +import inspect +import textwrap +from typing import Any + +__all__ = ["FunctionStub"] + + +class FunctionStub: + """Stub for function objects, storing the source code.""" + + def __init__(self, function: Any) -> None: + self.code = textwrap.dedent(inspect.getsource(function)) + + def load(self, glbls: dict[str, Any]) -> Any: + """Reconstruct the function by executing its source code.""" + # TODO: Fix line cache and associate with the correct module. + code_obj = compile(self.code, "", "exec") + lcls: dict[str, Any] = {} + exec(code_obj, glbls, lcls) + # Update the global scope with the function. + for value in lcls.values(): + return value diff --git a/marimo/_save/stubs/module_stub.py b/marimo/_save/stubs/module_stub.py new file mode 100644 index 00000000000..8c4c71ccc5d --- /dev/null +++ b/marimo/_save/stubs/module_stub.py @@ -0,0 +1,18 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +import importlib +from typing import Any + +__all__ = ["ModuleStub"] + + +class ModuleStub: + """Stub for module objects, storing only the module name.""" + + def __init__(self, module: Any) -> None: + self.name = module.__name__ + + def load(self) -> Any: + """Reload the module by name.""" + return importlib.import_module(self.name) diff --git a/marimo/_save/stubs/pydantic_stub.py b/marimo/_save/stubs/pydantic_stub.py new file mode 100644 index 00000000000..e4ccbaa8409 --- /dev/null +++ b/marimo/_save/stubs/pydantic_stub.py @@ -0,0 +1,111 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +from typing import Any + +from marimo._save.stubs.stubs import CustomStub + +__all__ = ["PydanticStub"] + + +class PydanticStub(CustomStub): + """Stub for pydantic BaseModel instances. + + Pydantic models have non-deterministic pickling due to __pydantic_fields_set__ + being a set. This stub ensures deterministic serialization by sorting fields + and preserves complete pydantic state including private and extra fields. + """ + + __slots__ = ( + "model_class", + "pydantic_dict", + "pydantic_extra", + "pydantic_fields_set", + "pydantic_private", + ) + + def __init__(self, model: Any) -> None: + """Initialize stub with pydantic model data. + + Args: + model: A pydantic BaseModel instance + """ + from pydantic_core import PydanticUndefined + + self.model_class = model.__class__ + + # Store pydantic state as individual attributes + self.pydantic_dict = model.__dict__ + self.pydantic_extra = getattr(model, "__pydantic_extra__", None) + + # Sort fields_set for deterministic serialization + self.pydantic_fields_set = sorted( + getattr(model, "__pydantic_fields_set__", set()) + ) + + # Capture private fields, filtering out undefined values + private = getattr(model, "__pydantic_private__", None) + if private: + private = { + k: v for k, v in private.items() if v is not PydanticUndefined + } + self.pydantic_private = private + + def load(self, glbls: dict[str, Any]) -> Any: + """Reconstruct the pydantic model. + + Args: + glbls: Global namespace (unused for pydantic models) + + Returns: + Reconstructed pydantic model instance + """ + del glbls # Unused for pydantic models + # Use model_construct to bypass validation (matches pickle behavior) + instance = self.model_class.model_construct() + + # Reconstruct the state dict for __setstate__ + state = { + "__dict__": self.pydantic_dict, + "__pydantic_extra__": self.pydantic_extra, + "__pydantic_fields_set__": set(self.pydantic_fields_set), + "__pydantic_private__": self.pydantic_private, + } + + # Restore state using pydantic's __setstate__ + if hasattr(instance, "__setstate__"): + instance.__setstate__(state) + else: + # Fallback: manually restore each piece of state + instance.__dict__.update(state["__dict__"]) + if state.get("__pydantic_extra__"): + instance.__pydantic_extra__ = state["__pydantic_extra__"] + instance.__pydantic_fields_set__ = state["__pydantic_fields_set__"] + if state.get("__pydantic_private__"): + instance.__pydantic_private__ = state["__pydantic_private__"] + return instance + + def to_bytes(self) -> bytes: + """Serialize the stub to bytes. + + Returns: + Serialized bytes of the stub + """ + import pickle + + return pickle.dumps( + ( + self.model_class, + self.pydantic_dict, + self.pydantic_extra, + self.pydantic_fields_set, + self.pydantic_private, + ) + ) + + @staticmethod + def get_type() -> type: + """Get the pydantic BaseModel type.""" + from pydantic import BaseModel + + return BaseModel diff --git a/marimo/_save/stubs/stubs.py b/marimo/_save/stubs/stubs.py new file mode 100644 index 00000000000..83fd40943d3 --- /dev/null +++ b/marimo/_save/stubs/stubs.py @@ -0,0 +1,74 @@ +# Copyright 2025 Marimo. All rights reserved. +"""Lazy stub registration for cache serialization.""" + +from __future__ import annotations + +import abc +from typing import Any + +__all__ = ["CustomStub", "CUSTOM_STUBS", "register_stub"] + + +class CustomStub(abc.ABC): + """Base class for custom stubs that can be registered in the cache.""" + + __slots__ = () + + @abc.abstractmethod + def __init__(self, _obj: Any) -> None: + """Initializes the stub with the object to be stubbed.""" + + @abc.abstractmethod + def load(self, glbls: dict[str, Any]) -> Any: + """Loads the stub, restoring the original object.""" + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def get_type() -> type: + """Get the type this stub handles. + + May raise ImportError if the required package is not available. + """ + raise NotImplementedError + + @abc.abstractmethod + def to_bytes(self) -> bytes: + """Serialize the stub to bytes.""" + raise NotImplementedError + + @classmethod + def register(cls, value: Any) -> None: + """Register this stub for its target type. + + Handles the common registration pattern: get type, check isinstance, + and register the stub. Catches ImportError if the target type's + package is not available. + + Registers both the base type and the specific value's type to handle + subclasses correctly. + """ + try: + target_type = cls.get_type() + if isinstance(value, target_type): + register_stub(target_type, cls) + # Also register the specific subclass type + value_type = type(value) + if value_type != target_type: + register_stub(value_type, cls) + except ImportError: + pass + + +CUSTOM_STUBS: dict[type, type[CustomStub]] = {} + + +def register_stub(cls: type | None, stub: type[CustomStub]) -> None: + """Register a custom stub for a given class type. + + Args: + cls: The class type to register a stub for + stub: The stub class to use for serialization + """ + if cls is not None: + CUSTOM_STUBS[cls] = stub diff --git a/marimo/_save/stubs/ui_element_stub.py b/marimo/_save/stubs/ui_element_stub.py new file mode 100644 index 00000000000..8c935b072f4 --- /dev/null +++ b/marimo/_save/stubs/ui_element_stub.py @@ -0,0 +1,32 @@ +# Copyright 2025 Marimo. All rights reserved. +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + from marimo._plugins.ui._core.ui_element import UIElement + +__all__ = ["UIElementStub"] + +S = TypeVar("S") +T = TypeVar("T") + + +class UIElementStub(Generic[S, T]): + """Stub for UIElement objects, storing args and hashable data.""" + + def __init__(self, element: UIElement[S, T]) -> None: + self.args = element._args + self.cls = element.__class__ + # Ideally only hashable attributes are stored on the subclass level. + defaults = set(self.cls.__new__(self.cls).__dict__.keys()) + defaults |= {"_ctx"} + self.data = { + k: v + for k, v in element.__dict__.items() + if hasattr(v, "__hash__") and k not in defaults + } + + def load(self) -> UIElement[S, T]: + """Reconstruct the UIElement from stored data.""" + return self.cls.from_args(self.data, self.args) # type: ignore diff --git a/tests/_save/stubs/__init__.py b/tests/_save/stubs/__init__.py new file mode 100644 index 00000000000..82fab87cfdd --- /dev/null +++ b/tests/_save/stubs/__init__.py @@ -0,0 +1 @@ +# Copyright 2025 Marimo. All rights reserved. diff --git a/tests/_save/stubs/test_pydantic_stub.py b/tests/_save/stubs/test_pydantic_stub.py new file mode 100644 index 00000000000..5fa86d3ddf1 --- /dev/null +++ b/tests/_save/stubs/test_pydantic_stub.py @@ -0,0 +1,200 @@ +# Copyright 2025 Marimo. All rights reserved. + +from __future__ import annotations + +import pytest + +pytest.importorskip("pydantic_core") + + +from marimo._save.stubs.pydantic_stub import PydanticStub + + +class TestPydanticStub: + """Tests for PydanticStub serialization and deserialization.""" + + @staticmethod + def test_basic_model() -> None: + """Test stub with basic pydantic model.""" + from pydantic import BaseModel + + class BasicModel(BaseModel): + name: str + value: int + + model = BasicModel(name="test", value=42) + + # Create stub + stub = PydanticStub(model) + + # Verify stub attributes + assert stub.model_class == BasicModel + assert stub.pydantic_dict == {"name": "test", "value": 42} + assert stub.pydantic_fields_set == ["name", "value"] + assert stub.pydantic_extra is None + assert stub.pydantic_private is None + + # Restore and verify + restored = stub.load({}) + assert isinstance(restored, BasicModel) + assert restored.name == model.name + assert restored.value == model.value + assert restored.model_fields_set == model.model_fields_set + + @staticmethod + def test_model_with_private_fields() -> None: + """Test stub with model containing private fields.""" + from pydantic import BaseModel, PrivateAttr + + class ModelWithPrivate(BaseModel): + name: str + _private: int = PrivateAttr(default=0) + _secret: str = PrivateAttr(default="secret") + + model = ModelWithPrivate(name="test") + model._private = 99 + model._secret = "my_secret" + + # Create stub + stub = PydanticStub(model) + + # Verify private fields captured + assert stub.pydantic_private is not None + assert "_private" in stub.pydantic_private + assert "_secret" in stub.pydantic_private + assert stub.pydantic_private["_private"] == 99 + assert stub.pydantic_private["_secret"] == "my_secret" + + # Restore and verify private fields + restored = stub.load({}) + assert restored._private == model._private + assert restored._secret == model._secret + + @staticmethod + def test_model_with_extra_fields() -> None: + """Test stub with model allowing extra fields.""" + from pydantic import BaseModel, ConfigDict + + class ModelWithExtra(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + + model = ModelWithExtra(name="test", extra_field="bonus", another=123) + + # Create stub + stub = PydanticStub(model) + + # Verify extra fields captured + assert stub.pydantic_extra is not None + assert "extra_field" in stub.pydantic_extra + assert "another" in stub.pydantic_extra + assert stub.pydantic_extra["extra_field"] == "bonus" + assert stub.pydantic_extra["another"] == 123 + + # Restore and verify extra fields + restored = stub.load({}) + assert restored.__pydantic_extra__ == model.__pydantic_extra__ + # Access extra fields via __pydantic_extra__ + assert restored.__pydantic_extra__["extra_field"] == "bonus" + assert restored.__pydantic_extra__["another"] == 123 + + @staticmethod + def test_complex_model() -> None: + """Test stub with model having all features.""" + from pydantic import BaseModel, ConfigDict, PrivateAttr + + class ComplexModel(BaseModel): + model_config = ConfigDict(extra="allow") + name: str + value: int + _private: str = PrivateAttr(default="secret") + + model = ComplexModel(name="test", value=42, extra="bonus") + model._private = "my_secret" + + # Create stub + stub = PydanticStub(model) + + # Verify all state captured + assert stub.pydantic_dict == {"name": "test", "value": 42} + assert "extra" in stub.pydantic_extra + assert stub.pydantic_private["_private"] == "my_secret" + assert "extra" in stub.pydantic_fields_set + assert "name" in stub.pydantic_fields_set + assert "value" in stub.pydantic_fields_set + + # Restore and verify everything + restored = stub.load({}) + assert restored.name == model.name + assert restored.value == model.value + assert restored._private == model._private + assert restored.__pydantic_extra__ == model.__pydantic_extra__ + assert restored.model_fields_set == model.model_fields_set + + @staticmethod + def test_deterministic_fields_set() -> None: + """Test that fields_set is sorted for deterministic serialization.""" + from pydantic import BaseModel + + class Model(BaseModel): + a: int + z: int + m: int + + # Create multiple instances with different field order + model1 = Model(z=1, a=2, m=3) + model2 = Model(a=2, m=3, z=1) + + stub1 = PydanticStub(model1) + stub2 = PydanticStub(model2) + + # fields_set should be sorted and identical + assert stub1.pydantic_fields_set == stub2.pydantic_fields_set + assert stub1.pydantic_fields_set == ["a", "m", "z"] + + @staticmethod + def test_nested_models() -> None: + """Test stub with nested pydantic models.""" + from pydantic import BaseModel + + class InnerModel(BaseModel): + inner_value: int + + class OuterModel(BaseModel): + name: str + inner: InnerModel + + inner = InnerModel(inner_value=99) + outer = OuterModel(name="test", inner=inner) + + # Create stub + stub = PydanticStub(outer) + + # Restore and verify nested structure + restored = stub.load({}) + assert restored.name == outer.name + assert isinstance(restored.inner, InnerModel) + assert restored.inner.inner_value == outer.inner.inner_value + + @staticmethod + def test_partial_fields_set() -> None: + """Test model where not all fields are set.""" + from pydantic import BaseModel + + class Model(BaseModel): + required: str + optional: int = 42 + + # Only set required field + model = Model(required="test") + + stub = PydanticStub(model) + + # Only required field should be in fields_set + assert stub.pydantic_fields_set == ["required"] + + # Restore and verify + restored = stub.load({}) + assert restored.required == model.required + assert restored.optional == model.optional + assert restored.model_fields_set == {"required"} diff --git a/tests/_save/stubs/test_stubs.py b/tests/_save/stubs/test_stubs.py new file mode 100644 index 00000000000..e948c48669a --- /dev/null +++ b/tests/_save/stubs/test_stubs.py @@ -0,0 +1,295 @@ +# Copyright 2025 Marimo. All rights reserved. + +from __future__ import annotations + +from typing import Any + +import pytest + +from marimo._save.stubs import ( + CUSTOM_STUBS, + STUB_REGISTRATIONS, + CustomStub, + maybe_register_stub, + register_stub, +) + + +class TestStubRegistration: + """Tests for stub registration mechanism.""" + + @staticmethod + def test_stub_registrations_dict() -> None: + """Test that STUB_REGISTRATIONS contains expected entries.""" + # Should have pydantic.main.BaseModel + assert "pydantic.main.BaseModel" in STUB_REGISTRATIONS + + @staticmethod + @pytest.mark.skipif( + not pytest.importorskip( + "pydantic_core", reason="pydantic not installed" + ), + reason="pydantic required", + ) + def test_maybe_register_stub_pydantic() -> None: + """Test registering a pydantic model.""" + from pydantic import BaseModel + + from marimo._save.stubs import _REGISTERED_NAMES + from marimo._save.stubs.pydantic_stub import PydanticStub + + class TestModel(BaseModel): + value: int + + # Clear any existing registration + if BaseModel in CUSTOM_STUBS: + del CUSTOM_STUBS[BaseModel] + if TestModel in CUSTOM_STUBS: + del CUSTOM_STUBS[TestModel] + # Also clear registered names + _REGISTERED_NAMES.discard("pydantic.main.BaseModel") + + model = TestModel(value=42) + + # Register the stub + result = maybe_register_stub(model) + + # Should return True (registered) + assert result is True + + # BaseModel should now be in CUSTOM_STUBS + assert BaseModel in CUSTOM_STUBS + assert CUSTOM_STUBS[BaseModel] is PydanticStub + + # Subclass should also be registered + assert TestModel in CUSTOM_STUBS + assert CUSTOM_STUBS[TestModel] is PydanticStub + + @staticmethod + @pytest.mark.skipif( + not pytest.importorskip( + "pydantic_core", reason="pydantic not installed" + ), + reason="pydantic required", + ) + def test_maybe_register_stub_already_registered() -> None: + """Test that already registered stubs return True immediately.""" + from pydantic import BaseModel + + from marimo._save.stubs import _REGISTERED_NAMES + + class TestModel(BaseModel): + value: int + + # Ensure clean state + if BaseModel in CUSTOM_STUBS: + del CUSTOM_STUBS[BaseModel] + if TestModel in CUSTOM_STUBS: + del CUSTOM_STUBS[TestModel] + _REGISTERED_NAMES.discard("pydantic.main.BaseModel") + + model = TestModel(value=42) + + # First registration + result1 = maybe_register_stub(model) + assert result1 is True + + # Verify it's registered + assert BaseModel in CUSTOM_STUBS + assert TestModel in CUSTOM_STUBS + + # Second call should return True immediately (already in CUSTOM_STUBS) + result2 = maybe_register_stub(model) + assert result2 is True + + @staticmethod + def test_maybe_register_stub_no_match() -> None: + """Test that non-matching types return False.""" + + class PlainClass: + pass + + obj = PlainClass() + + # Should return False (no registration) + result = maybe_register_stub(obj) + assert result is False + + # Should not be in CUSTOM_STUBS + assert PlainClass not in CUSTOM_STUBS + + @staticmethod + @pytest.mark.skipif( + not pytest.importorskip("pydantic", reason="pydantic not installed"), + reason="pydantic required", + ) + def test_mro_traversal() -> None: + """Test that MRO traversal finds base class registration.""" + from pydantic import BaseModel + + from marimo._save.stubs import _REGISTERED_NAMES + + # Clear registrations + if BaseModel in CUSTOM_STUBS: + del CUSTOM_STUBS[BaseModel] + _REGISTERED_NAMES.discard("pydantic.main.BaseModel") + + class Parent(BaseModel): + x: int + + class Child(Parent): + y: int + + if Child in CUSTOM_STUBS: + del CUSTOM_STUBS[Child] + + child = Child(x=1, y=2) + + # Should register via MRO (finds BaseModel in parent chain) + result = maybe_register_stub(child) + assert result is True + + # Both BaseModel and Child should be registered + assert BaseModel in CUSTOM_STUBS + assert Child in CUSTOM_STUBS + + +class TestCustomStubBase: + """Tests for CustomStub base class.""" + + @staticmethod + def test_abstract_methods() -> None: + """Test that CustomStub has required abstract methods.""" + # Should not be able to instantiate directly + with pytest.raises(TypeError): + CustomStub() # type: ignore + + @staticmethod + def test_register_classmethod() -> None: + """Test that register is a classmethod.""" + assert hasattr(CustomStub, "register") + assert callable(CustomStub.register) + + @staticmethod + def test_get_type_staticmethod() -> None: + """Test that get_type is a static method.""" + assert hasattr(CustomStub, "get_type") + + @staticmethod + def test_slots() -> None: + """Test that CustomStub has __slots__ defined.""" + assert hasattr(CustomStub, "__slots__") + assert CustomStub.__slots__ == () + + +class TestRegisterStub: + """Tests for register_stub function.""" + + @staticmethod + def test_register_stub_basic() -> None: + """Test basic stub registration.""" + + class DummyType: + pass + + class DummyStub(CustomStub): + __slots__ = ("obj",) + + def __init__(self, obj: Any) -> None: + self.obj = obj + + def load(self, glbls: dict[str, Any]) -> Any: + del glbls # Unused + return self.obj + + @staticmethod + def get_type() -> type: + return DummyType + + # Register + register_stub(DummyType, DummyStub) + + # Should be in CUSTOM_STUBS + assert DummyType in CUSTOM_STUBS + assert CUSTOM_STUBS[DummyType] is DummyStub + + # Clean up + del CUSTOM_STUBS[DummyType] + + @staticmethod + def test_register_stub_none() -> None: + """Test registering with None type does nothing.""" + + class DummyStub(CustomStub): + __slots__ = () + + def __init__(self, obj: Any) -> None: + pass + + def load(self, glbls: dict[str, Any]) -> Any: + del glbls # Unused + return None + + @staticmethod + def get_type() -> type: + return object + + # Register with None + register_stub(None, DummyStub) + + # Should not add None to CUSTOM_STUBS + assert None not in CUSTOM_STUBS + + +class TestStubIntegration: + """Integration tests for stub system.""" + + @staticmethod + @pytest.mark.skipif( + not pytest.importorskip( + "pydantic_core", reason="pydantic not installed" + ), + reason="pydantic required", + ) + def test_cache_integration() -> None: + """Test stub integration with cache system.""" + from pydantic import BaseModel + + from marimo._save.cache import Cache + from marimo._save.stubs import _REGISTERED_NAMES + from marimo._save.stubs.pydantic_stub import PydanticStub + + class TestModel(BaseModel): + name: str + value: int + + # Clear any existing registration to ensure clean test + if BaseModel in CUSTOM_STUBS: + del CUSTOM_STUBS[BaseModel] + if TestModel in CUSTOM_STUBS: + del CUSTOM_STUBS[TestModel] + _REGISTERED_NAMES.discard("pydantic.main.BaseModel") + + model = TestModel(name="test", value=42) + + # Create cache + cache = Cache.empty( + key=type("HashKey", (), {"hash": "test", "cache_type": "Pure"})(), + defs={"x"}, + stateful_refs=set(), + ) + + # Convert to stub (should trigger registration and conversion) + converted = cache._convert_to_stub_if_needed(model, {}) + + # Should be a PydanticStub + assert isinstance(converted, PydanticStub) + + # Restore from stub + restored = cache._restore_from_stub_if_needed(converted, {}, {}) + + # Should match original + assert isinstance(restored, TestModel) + assert restored.name == model.name + assert restored.value == model.value + assert restored.model_fields_set == model.model_fields_set diff --git a/tests/_save/test_hash.py b/tests/_save/test_hash.py index 1b18d7747ab..5fc822c19b6 100644 --- a/tests/_save/test_hash.py +++ b/tests/_save/test_hash.py @@ -985,7 +985,54 @@ def two(MockLoader, persistent_cache, expected_hash, pl) -> tuple[int]: return (two,) -# Skip for now, as the local branch is cache busting +class TestCustomHash: + @staticmethod + @pytest.mark.skipif( + not DependencyManager.has("pydantic"), + reason="optional dependencies not installed", + ) + @pytest.mark.skipif( + "sys.version_info < (3, 12) or sys.version_info >= (3, 13)" + ) + async def test_pydantic_model_hash(app: App) -> None: + with app.setup: + import pydantic + + import marimo as mo + + @app.class_definition + class Model(pydantic.BaseModel): + a: int + b: str + + @app.function + @mo.cache + def use_model(model: Model) -> tuple[int, str]: + return model.a, model.b + + @app.cell + def _check_deterministic() -> None: + assert use_model.hits == 0 + model = Model(a=1, b="test") + a, b = use_model(model) + initial_hash = use_model._last_hash + a, b = use_model(model) # Cache hit + assert use_model.hits == 1 + model_copy = Model(a=1, b="test") + A, B = use_model(model_copy) # Cache hit, different instance + assert use_model.hits == 2 + assert (a, b) == (A, B) == (1, "test") + assert use_model._last_hash == initial_hash + + @app.cell + def _check_different(a: int, b: str, initial_hash: str) -> None: + diff_model = Model(a=2, b="test") + c, d = use_model(diff_model) + assert use_model.hits == 2 + assert (c, d) != (a, b) + assert initial_hash != use_model._last_hash + + class TestDynamicHash: @staticmethod async def test_transitive_state_hash(