Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions marimo/_save/loaders/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ class Loader(ABC):
"""

def __init__(self, name: str) -> None:
self.name = name
# Remove * since used to prevent shadowing in scoped cases.
self.name = name.strip("*")
self._hits = 0
self._time_saved = 0.0

Expand Down Expand Up @@ -212,7 +213,7 @@ def __init__(
self.store = DEFAULT_STORE()

# Limited character set for path for windows compatibility
self.name = re.sub(r"[^a-zA-Z0-9 _-]", "_", name)
self.name = re.sub(r"[^a-zA-Z0-9 _-]", "_", self.name)
self.suffix = suffix

def build_path(self, key: HashKey) -> Path:
Expand Down
39 changes: 39 additions & 0 deletions tests/_save/test_external_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,45 @@


class TestDecoratorImports:
@staticmethod
def test_import_alias_hash_path(app) -> None:
"""Test that imported cached functions with module aliases have correct hash paths."""
for module in list(sys.modules.keys()):
if module.startswith("tests._save.external_decorators"):
del sys.modules[module]

with app.setup:
import marimo as mo
import tests._save.external_decorators.module_1 as my_module
from tests._save.external_decorators.transitive_imports import (
doesnt_have_namespace as external_func,
)

@app.function
@mo.cache
def doesnt_have_namespace() -> None:
# Just replicating the function in external_func
return my_module.__version__

@app.cell
def check_hash_paths() -> None:
local_func = doesnt_have_namespace
# Both functions access the same module with the same alias
external_result = external_func()
local_result = local_func()

# Results should be the same (both return "1.0.0")
assert external_result == local_result == "1.0.0"

# Hashes should be equal (same code, same module accessed)
external_name = external_func._loader().name
local_name = local_func._loader().name

assert external_name == local_name, (
f"Hashes should be equal for same code and module, "
f"got {external_name} != {local_name}"
)

@staticmethod
def test_has_shared_import(app) -> None:
with app.setup:
Expand Down
Loading