From cf47fb047cba098b329bdafead7cda6c9eaf6f88 Mon Sep 17 00:00:00 2001
From: Jan Scheffler
Date: Fri, 10 Oct 2025 19:25:47 +0200
Subject: [PATCH 1/3] refactor: extract cache execution helpers for reuse

Extract common execution logic into _prepare_call_execution() and
_finalize_cache_update() helper methods. This reduces duplication and
prepares the codebase for async cache support.

- Add _prepare_call_execution() to build execution context
- Add _finalize_cache_update() to save cache results
- Refactor __call__() to use new helpers
---
 marimo/_save/save.py | 138 ++++++++++++++++++++++++------------------
 1 file changed, 78 insertions(+), 60 deletions(-)

diff --git a/marimo/_save/save.py b/marimo/_save/save.py
index b09af7e84c7..daf6379bb98 100644
--- a/marimo/_save/save.py
+++ b/marimo/_save/save.py
@@ -249,6 +249,81 @@ def _build_base_block(
             external=self._external,
         )
 
+    def _prepare_call_execution(
+        self, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> tuple[dict[str, Any], Any, Any]:
+        """Prepare execution context and create cache attempt.
+
+        Returns a tuple of (scope, ctx, attempt) needed for cache execution.
+        """
+        # Build base block if needed (for external/late binding)
+        if self.base_block is None:
+            assert self._external, UNEXPECTED_FAILURE_BOILERPLATE
+            assert self.__wrapped__ is not None, UNEXPECTED_FAILURE_BOILERPLATE
+            graph = graph_from_scope(self.scope)
+            cell_id = get_cell_id_from_scope(self.__wrapped__, self.scope)
+            self.base_block = self._build_base_block(
+                self.__wrapped__, graph, cell_id
+            )
+
+        # Rewrite scoped args to prevent shadowed variables
+        arg_dict = {f"{ARG_PREFIX}{k}": v for (k, v) in zip(self._args, args)}
+        kwargs_copy = {f"{ARG_PREFIX}{k}": v for (k, v) in kwargs.items()}
+        # If the function has varargs, we need to capture them as well.
+        if self._var_arg is not None:
+            arg_dict[f"{ARG_PREFIX}{self._var_arg}"] = args[len(self._args) :]
+        if self._var_kwarg is not None:
+            # NB: kwargs are always a dict, so we can just copy them.
+            arg_dict[f"{ARG_PREFIX}{self._var_kwarg}"] = kwargs.copy()
+
+        # Capture the call case
+        ctx = safe_get_context()
+        glbls: dict[str, Any] = {}
+        if ctx is not None:
+            glbls = ctx.globals
+        # Typically, scope is overridden by globals (scope is just a snapshot
+        # of the current frame, which may have changed); however, in an
+        # external context, scope is the only source of glbls (the definition
+        # should be unaware of working memory).
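+        # Merge precedence (later entries win): definition scope, then
+        # runtime globals (skipped for external definitions), then the
+        # prefixed args/kwargs, then any bound values.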
+        scope = {
+            **self.scope,
+        }
+        if not self._external:
+            scope = {
+                **scope,
+                **glbls,
+            }
+        scope = {
+            **scope,
+            **arg_dict,
+            **kwargs_copy,
+            **(self._bound or {}),
+        }
+        assert self._loader is not None, UNEXPECTED_FAILURE_BOILERPLATE
+        attempt = content_cache_attempt_from_base(
+            self.base_block,
+            scope,
+            self.loader,
+            scoped_refs=self.scoped_refs,
+            required_refs=set([f"{ARG_PREFIX}{k}" for k in self._args]),
+            as_fn=True,
+        )
+
+        return scope, ctx, attempt
+
+    def _finalize_cache_update(
+        self,
+        attempt: Any,
+        response: Any,
+        runtime: float,
+        scope: dict[str, Any],
+    ) -> None:
+        """Update and save cache with execution results."""
+        # stateful variables may be global
+        scope = {k: v for k, v in scope.items() if k in attempt.stateful_refs}
+        attempt.update(scope, meta={"return": response, "runtime": runtime})
+        self.loader.save_cache(attempt)
+
     @property
     def misses(self) -> int:
         if self._loader is None:
@@ -320,58 +395,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
             self._set_context(args[0])
             return self
 
-        if self.base_block is None:
-            assert self._external, UNEXPECTED_FAILURE_BOILERPLATE
-            # We only build the graph on invocation because toplevel functions
-            # can be defined out of order.
-            graph = graph_from_scope(self.scope)
-            cell_id = get_cell_id_from_scope(self.__wrapped__, self.scope)
-            self.base_block = self._build_base_block(
-                self.__wrapped__, graph, cell_id
-            )
-
-        # Rewrite scoped args to prevent shadowed variables
-        arg_dict = {f"{ARG_PREFIX}{k}": v for (k, v) in zip(self._args, args)}
-        kwargs_copy = {f"{ARG_PREFIX}{k}": v for (k, v) in kwargs.items()}
-        # If the function has varargs, we need to capture them as well.
-        if self._var_arg is not None:
-            arg_dict[f"{ARG_PREFIX}{self._var_arg}"] = args[len(self._args) :]
-        if self._var_kwarg is not None:
-            # NB: kwargs are always a dict, so we can just copy them.
-            arg_dict[f"{ARG_PREFIX}{self._var_kwarg}"] = kwargs.copy()
-
-        # Capture the call case
-        ctx = safe_get_context()
-        glbls: dict[str, Any] = {}
-        if ctx is not None:
-            glbls = ctx.globals
-        # Typically, scope is overridden by globals (scope is just a snapshot of
-        # the current frame, which may have changed)- however in an external
-        # context, scope is the only source of glbls (the definition should be
-        # unaware of working memory).
-        scope = {
-            **self.scope,
-        }
-        if not self._external:
-            scope = {
-                **scope,
-                **glbls,
-            }
-        scope = {
-            **scope,
-            **arg_dict,
-            **kwargs_copy,
-            **(self._bound or {}),
-        }
-        assert self._loader is not None, UNEXPECTED_FAILURE_BOILERPLATE
-        attempt = content_cache_attempt_from_base(
-            self.base_block,
-            scope,
-            self.loader,
-            scoped_refs=self.scoped_refs,
-            required_refs=set([f"{ARG_PREFIX}{k}" for k in self._args]),
-            as_fn=True,
-        )
+        # Prepare execution context
+        scope, ctx, attempt = self._prepare_call_execution(args, kwargs)
 
         failed = False
         self._last_hash = attempt.hash
@@ -384,14 +409,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
                 response = self.__wrapped__(*args, **kwargs)
                 runtime = time.time() - start_time
 
-                # stateful variables may be global
-                scope = {
-                    k: v for k, v in scope.items() if k in attempt.stateful_refs
-                }
-                attempt.update(
-                    scope, meta={"return": response, "runtime": runtime}
-                )
-                self.loader.save_cache(attempt)
+                self._finalize_cache_update(attempt, response, runtime, scope)
         except Exception as e:
             failed = True
             raise e

From e250061e695413f0908d8cb1cfbc59e330cf0d31 Mon Sep 17 00:00:00 2001
From: Jan Scheffler
Date: Fri, 10 Oct 2025 19:34:02 +0200
Subject: [PATCH 2/3] feat: add async cache support with task deduplication
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for async/await functions with @cache, @lru_cache, and
@persistent_cache decorators. Implements task deduplication to prevent
race conditions when multiple concurrent calls are made with the same
arguments.

Implementation:
- Use type(self) instead of _cache_call() in __get__() for proper
  subclass dispatch
- Detect async functions and dispatch to _cache_call_async variant
- Implement task deduplication using asyncio.Task caching with
  WeakKeyDictionary
- Prevent concurrent duplicate executions via _pending_executions dict
- Release lock before awaiting tasks to avoid deadlocks

Testing:
- Add 15 comprehensive async cache tests
- Test concurrent deduplication (5 concurrent calls → 1 execution)
- All 115 tests passing (100 sync + 15 async)
---
 marimo/_save/save.py      | 138 ++++++++-
 tests/_save/test_cache.py | 621 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 756 insertions(+), 3 deletions(-)

diff --git a/marimo/_save/save.py b/marimo/_save/save.py
index daf6379bb98..bd6e9795e8f 100644
--- a/marimo/_save/save.py
+++ b/marimo/_save/save.py
@@ -2,12 +2,15 @@
 from __future__ import annotations
 
 import ast
+import asyncio
 import functools
 import inspect
 import io
 import sys
+import threading
 import time
 import traceback
+import weakref
 from collections import abc
 
 # NB: maxsize follows functools.cache, but renamed max_size outside of drop-in
@@ -360,7 +363,7 @@ def __get__(
                 "(have you wrapped a function?)"
             )
         # Bind to the instance
-        copy = _cache_call(
+        copy = type(self)(
            None,
            self._loader_partial,
            pin_modules=self.pin_modules,
@@ -389,6 +392,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
                 raise TypeError(
                     "cache() takes at most 1 argument (expecting function)"
                 )
+            # Check if the function is async - if so, create async variant
+            if inspect.iscoroutinefunction(args[0]):
+                async_copy = _cache_call_async(
+                    None,
+                    self._loader_partial,
+                    pin_modules=self.pin_modules,
+                    hash_type=self.hash_type,
+                )
+                async_copy._frame_offset = self._frame_offset
+                async_copy._frame_offset -= 4
+                async_copy._set_context(args[0])
+                return async_copy
             # Remove the additional frames from singledispatch, because invoking
             # the function directly.
             self._frame_offset -= 4
             self._set_context(args[0])
             return self
@@ -421,6 +436,116 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return response
 
 
+class _cache_call_async(_cache_call):
+    """Async variant of _cache_call for async/await functions.
+
+    Inherits all caching logic from _cache_call but provides an async
+    __call__ method that properly awaits coroutines. Used automatically
+    when @cache decorates an async function.
+
+    Implements task deduplication: concurrent calls with the same arguments
+    will share the same execution, preventing duplicate work.
+    """
+
+    # Track pending executions per cache instance to prevent race conditions
+    # WeakKeyDictionary ensures instances are cleaned up when garbage collected
+    # Key: cache instance, Value: dict of {cache_key: Task}
+    _pending_executions: weakref.WeakKeyDictionary[
+        _cache_call_async, dict[str, asyncio.Task[Any]]
+    ] = weakref.WeakKeyDictionary()
+    _pending_lock = threading.Lock()
+
+    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        # Capture the deferred call case
+        if self.__wrapped__ is None:
+            if len(args) != 1:
+                raise TypeError(
+                    "cache() takes at most 1 argument (expecting function)"
+                )
+            # Remove the additional frames from singledispatch, because invoking
+            # the function directly.
+            self._frame_offset -= 4
+            self._set_context(args[0])
+            return self
+
+        # Prepare execution context to get cache key
+        scope, ctx, attempt = self._prepare_call_execution(args, kwargs)
+        cache_key = attempt.hash
+
+        # Check for pending execution (task deduplication)
+        existing_task = None
+        with self._pending_lock:
+            if self not in self._pending_executions:
+                self._pending_executions[self] = {}
+            pending = self._pending_executions[self]
+
+            if cache_key in pending:
+                # Another coroutine is already executing this - save the task
+                existing_task = pending[cache_key]
+
+        # Await the existing task AFTER releasing the lock to avoid deadlock
+        if existing_task is not None:
+            return await existing_task
+
+        # No pending execution - create a new task
+        task = asyncio.create_task(
+            self._execute_cached(scope, ctx, attempt, args, kwargs)
+        )
+
+        with self._pending_lock:
+            pending[cache_key] = task
+
+        try:
+            result = await task
+        finally:
+            # Clean up completed task
+            with self._pending_lock:
+                if cache_key in pending:
+                    del pending[cache_key]
+                # Clean up empty instance dict (WeakKeyDictionary handles
+                # instance cleanup)
+                if not pending and self in self._pending_executions:
+                    del self._pending_executions[self]
+
+        return result
+
+    async def _execute_cached(
+        self,
+        scope: dict[str, Any],
+        ctx: Any,
+        attempt: Any,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> Any:
+        """Execute the cached function and update cache.
+
+        This is called by a single task even when multiple concurrent
+        callers request the same computation.
+        """
+        assert self.__wrapped__ is not None, UNEXPECTED_FAILURE_BOILERPLATE
+        failed = False
+        self._last_hash = attempt.hash
+        try:
+            if attempt.hit:
+                attempt.restore(scope)
+                return attempt.meta["return"]
+
+            start_time = time.time()
+            # Await the coroutine to get the actual result
+            response = await self.__wrapped__(*args, **kwargs)
+            runtime = time.time() - start_time
+
+            self._finalize_cache_update(attempt, response, runtime, scope)
+        except Exception as e:
+            failed = True
+            raise e
+        finally:
+            # NB. Exceptions raise their own side effects.
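+            # When a runtime context exists, register the cache access as a
+            # side effect for both hits and misses; a failed call registers
+            # nothing.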
+            if ctx and not failed:
+                ctx.cell_lifecycle_registry.add(SideEffect(attempt.hash))
+        self._misses += 1
+        return response
+
+
 class _cache_context(SkipContext, CacheContext):
     def __init__(
         self,
@@ -617,7 +742,7 @@ def _invoke_call(
     *args: Any,
     frame_offset: int = 1,
     **kwargs: Any,
-) -> _cache_call:
+) -> Union[_cache_call, _cache_call_async]:
     if isinstance(loader, Loader):
         raise TypeError(
             "A loader instance cannot be passed to cache directly. "
             "Invalid loader type. "
             f"Expected a loader partial, got {type(loader)}."
         )
+
+    # Check if the function is async
+    if _fn is not None and inspect.iscoroutinefunction(_fn):
+        return _cache_call_async(
+            _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
+        )
+
     return _cache_call(
         _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
     )
@@ -663,7 +795,7 @@ def _invoke_call_fn(
     *args: Any,
     frame_offset: int = 1,
     **kwargs: Any,
-) -> _cache_call:
+) -> Union[_cache_call, _cache_call_async]:
     return _invoke_call(
         _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
     )
diff --git a/tests/_save/test_cache.py b/tests/_save/test_cache.py
index dc48299fde4..23062fcc1bd 100644
--- a/tests/_save/test_cache.py
+++ b/tests/_save/test_cache.py
@@ -2736,3 +2736,624 @@ def slow_func(x):
         info3 = k.globals["info3"]
         assert info3.time_saved > first_saving
         assert info3.hits == 2
+
+
+class TestAsyncCacheDecorator:
+    """Tests for async function caching support."""
+
+    async def test_basic_async_cache(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test basic async function caching with @cache decorator."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    from marimo._save.save import cache
+
+                    @cache
+                    async def async_fib(n):
+                        if n <= 1:
+                            return n
+                        a = await async_fib(n - 1)
+                        b = await async_fib(n - 2)
+                        return a + b
+
+                    a = await async_fib(5)
+                    b = await async_fib(10)
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages, k.stderr
+        assert k.globals["a"] == 5
+        assert k.globals["b"] == 55
+        # Should have cache hits like the sync version
+        assert k.globals["async_fib"].hits == 9
+
+    async def test_async_lru_cache(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async function caching with @lru_cache decorator."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    from marimo._save.save import lru_cache
+
+                    @lru_cache(maxsize=2)
+                    async def async_fib(n):
+                        if n <= 1:
+                            return n
+                        a = await async_fib(n - 1)
+                        b = await async_fib(n - 2)
+                        return a + b
+
+                    a = await async_fib(5)
+                    b = await async_fib(10)
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages, k.stderr
+        assert k.globals["a"] == 5
+        assert k.globals["b"] == 55
+        # Should have more hits with smaller cache
+        assert k.globals["async_fib"].hits == 14
+
+    async def test_async_persistent_cache(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async function with @persistent_cache decorator."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    import asyncio
+                    from marimo._save.save import persistent_cache
+                    from marimo._save.loaders import MemoryLoader
+
+                    @persistent_cache(_loader=MemoryLoader)
+                    async def async_compute(x):
+                        await asyncio.sleep(0.001)  # Simulate async work
+                        return x * 2
+
+                    result1 = await async_compute(5)
+                    result2 = await async_compute(5)  # Should hit cache
+                    result3 = await async_compute(10)  # Should miss
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages, k.stderr
+        assert k.globals["result1"] == 10
+        assert k.globals["result2"] == 10
+        assert k.globals["result3"] == 20
k.globals["async_compute"].hits == 1 + + async def test_async_cache_with_external_deps( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + """Test async cached function with external dependencies.""" + await k.run( + [ + exec_req.get( + """ + import asyncio + from marimo._save.save import cache + + external_value = 10 + + @cache + async def async_add(x): + await asyncio.sleep(0.001) + return x + external_value + + result1 = await async_add(5) + result2 = await async_add(5) # Should hit cache + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + assert k.globals["result1"] == 15 + assert k.globals["result2"] == 15 + assert k.globals["async_add"].hits == 1 + + async def test_async_cache_method( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + """Test async method caching.""" + await k.run( + [ + exec_req.get( + """ + import asyncio + from marimo._save.save import cache + + class AsyncCalculator: + def __init__(self, base): + self.base = base + + @cache + async def calculate(self, x): + await asyncio.sleep(0.001) + return self.base + x + + calc = AsyncCalculator(10) + result1 = await calc.calculate(5) + result2 = await calc.calculate(5) # Should hit cache + result3 = await calc.calculate(7) # Should miss + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + assert k.globals["result1"] == 15 + assert k.globals["result2"] == 15 + assert k.globals["result3"] == 17 + + async def test_async_cache_static_method( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + """Test async static method caching.""" + await k.run( + [ + exec_req.get( + """ + import asyncio + from marimo._save.save import cache + + class AsyncMath: + @staticmethod + @cache + async def multiply(x, y): + await asyncio.sleep(0.001) + return x * y + + result1 = await AsyncMath.multiply(3, 4) + result2 = await AsyncMath.multiply(3, 4) # Should hit cache + result3 = await AsyncMath.multiply(5, 6) # Should miss + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + assert k.globals["result1"] == 12 + assert k.globals["result2"] == 12 + assert k.globals["result3"] == 30 + + async def test_async_cache_with_await_in_notebook( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + """Test async function that can be awaited directly in notebook context.""" + await k.run( + [ + exec_req.get( + """ + import asyncio + from marimo._save.save import cache + + @cache + async def fetch_data(n): + await asyncio.sleep(0.001) + return n * 100 + + # Use direct await since marimo supports top-level await + result = await fetch_data(5) + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + assert k.globals["result"] == 500 + + async def test_async_cache_info_and_clear( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + """Verify cache_info() and cache_clear() work correctly with async functions.""" + await k.run( + [ + exec_req.get( + """ + from marimo._save.save import cache, lru_cache + + @cache + async def async_func(x): + return x * 2 + + @lru_cache(maxsize=2) + async def async_lru_func(x): + return x * 3 + + # Test basic cache_info + info0 = async_func.cache_info() + await async_func(1) + await async_func(1) # hit + await async_func(2) # miss + info1 = async_func.cache_info() + + # Test lru_cache maxsize + lru_info = async_lru_func.cache_info() + + # Test cache_clear + async_func.cache_clear() + info2 = async_func.cache_info() + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + + # Initial state + info0 = k.globals["info0"] + assert info0.hits == 0 + assert 
+        assert info0.misses == 0
+        assert info0.maxsize is None
+        assert info0.currsize == 0
+        assert info0.time_saved == 0.0
+
+        # After calls
+        info1 = k.globals["info1"]
+        assert info1.hits == 1
+        assert info1.misses == 2
+        assert info1.currsize == 2
+
+        # LRU maxsize
+        lru_info = k.globals["lru_info"]
+        assert lru_info.maxsize == 2
+
+        # After clear
+        info2 = k.globals["info2"]
+        assert info2.currsize == 0
+
+    async def test_async_cache_time_tracking(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Verify time_saved is tracked correctly for async functions."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    import asyncio
+                    from marimo._save.save import cache
+
+                    @cache
+                    async def async_slow_func(x):
+                        await asyncio.sleep(0.01)  # Simulate slow async operation
+                        return x * 2
+
+                    # Initial state
+                    info0 = async_slow_func.cache_info()
+
+                    # First call - miss (should record runtime)
+                    r1 = await async_slow_func(5)
+                    info1 = async_slow_func.cache_info()
+
+                    # Second call - hit (should add to time_saved)
+                    r2 = await async_slow_func(5)
+                    info2 = async_slow_func.cache_info()
+
+                    # Third call - another hit (should accumulate time_saved)
+                    r3 = await async_slow_func(5)
+                    info3 = async_slow_func.cache_info()
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages, k.stderr
+
+        # Initial state: no time saved yet
+        info0 = k.globals["info0"]
+        assert info0.time_saved == 0.0
+
+        # After first call (miss): still no time saved
+        info1 = k.globals["info1"]
+        assert info1.time_saved == 0.0
+
+        # After first hit: should have some time saved
+        info2 = k.globals["info2"]
+        assert info2.time_saved > 0.0
+        first_saving = info2.time_saved
+
+        # After second hit: time_saved should accumulate
+        info3 = k.globals["info3"]
+        assert info3.time_saved > first_saving
+        assert info3.hits == 2
+
+    async def test_async_cache_class_method(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async class method caching."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    import asyncio
+                    from marimo._save.save import cache
+
+                    class AsyncMath:
+                        @classmethod
+                        @cache
+                        async def compute(cls, x, y):
+                            await asyncio.sleep(0.001)
+                            return x + y
+
+                    case_a = AsyncMath().compute
+                    case_b = AsyncMath().compute
+                    case_c = AsyncMath.compute
+                    result1 = await case_a(1, 2)
+                    hash1 = case_a._last_hash
+                    result2 = await case_b(2, 1)
+                    hash2 = case_b._last_hash
+                    result3 = await case_c(1, 2)
+                    hash3 = case_c._last_hash
+                    base_hash = AsyncMath.compute._last_hash
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stdout.messages, k.stdout.messages
+        assert not k.stderr.messages, k.stderr.messages
+
+        # Verify results
+        assert k.globals["result1"] == 3
+        assert k.globals["result2"] == 3
+        assert k.globals["result3"] == 3
+        assert k.globals["hash1"] != k.globals["hash2"]
+        assert k.globals["hash1"] == k.globals["hash3"]
+        assert k.globals["case_c"]._last_hash is not None
+
+        # NB. base_hash has different behavior than the others on python 3.13+
+        # 3.13 has base_hash == hash1, while <3.13 has base_hash is None
+        import sys
+
+        if sys.version_info >= (3, 13):
+            assert k.globals["base_hash"] == k.globals["hash1"]
+        else:
+            assert k.globals["base_hash"] is None
+
+    async def test_async_lru_cache_default(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async lru_cache with default maxsize (256)."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    from marimo._save.save import lru_cache
+
+                    @lru_cache
+                    async def async_fib(n):
+                        if n <= 1:
+                            return n
+                        a = await async_fib(n - 1)
+                        b = await async_fib(n - 2)
+                        return a + b
+
+                    a = await async_fib(260)
+                    b = await async_fib(10)
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages
+        # More hits with a smaller cache, because it needs to check the cache
+        # more. Has 256 entries by default, normal cache hits just 259 times.
+        assert k.globals["async_fib"].hits == 266
+
+        # A little ridiculous, but still low compute.
+        assert (
+            k.globals["a"]
+            == 971183874599339129547649988289594072811608739584170445
+        )
+        assert k.globals["b"] == 55
+
+    async def test_async_cross_cell_cache(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async function caching across multiple notebook cells."""
+        await k.run(
+            [
+                exec_req.get("""from marimo._save.save import cache"""),
+                exec_req.get(
+                    """
+                    @cache
+                    async def async_fib(n):
+                        if n <= 1:
+                            return n
+                        a = await async_fib(n - 1)
+                        b = await async_fib(n - 2)
+                        return a + b
+                    """
+                ),
+                exec_req.get("""a = await async_fib(5)"""),
+                exec_req.get("""b = await async_fib(10); a"""),
+            ]
+        )
+
+        assert not k.stderr.messages
+        assert k.globals["async_fib"].hits == 9
+
+        assert k.globals["a"] == 5
+        assert k.globals["b"] == 55
+
+    async def test_async_cache_with_external_state(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test async cached function with mo.state() dependency."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    from marimo._save.save import cache
+                    from marimo._runtime.state import state
+                    """
+                ),
+                exec_req.get("""external, setter = state(0)"""),
+                exec_req.get(
+                    """
+                    @cache
+                    async def async_fib(n):
+                        if n <= 1:
+                            return n + external()
+                        a = await async_fib(n - 1)
+                        b = await async_fib(n - 2)
+                        return a + b
+                    """
+                ),
+                exec_req.get("""impure = []"""),
+                exec_req.get("""a = await async_fib(5)"""),
+                exec_req.get("""b = await async_fib(10); a"""),
+                exec_req.get(
+                    """
+                    c = a + b
+                    if len(impure) == 0:
+                        setter(1)
+                    elif len(impure) == 1:
+                        setter(0)
+                    impure.append(c)
+                    """
+                ),
+            ]
+        )
+
+        assert not k.stderr.messages
+
+        assert k.globals["a"] == 5
+        assert k.globals["b"] == 55
+        assert k.globals["impure"] == [60, 157, 60]
+        # Cache hit value may be flaky depending on when state is evicted from
+        # the registry. The actual cache hit is less important than caching
+        # occurring in the first place.
+        # NB. 20 = 2 * 9 + 2
+        if k.globals["async_fib"].hits in (9, 18):
+            import warnings
+
+            warnings.warn(
+                "Known flaky edge case for async cache with state dep.",
+                stacklevel=1,
+            )
+        else:
+            assert k.globals["async_fib"].hits == 20
+
+    async def test_async_cache_decorator_with_kwargs(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test that kwargs hashing works identically for async functions."""
+        await k.run(
+            [
+                exec_req.get(
+                    """
+                    from marimo._save.save import cache
+
+                    @cache
+                    async def async_cached_func(*args, **kwargs):
+                        return sum(args) + sum(kwargs.values())
+
+                    # First call with specific kwargs
+                    result1 = await async_cached_func(1, 2, some_kw_arg=3)
+                    hash1 = async_cached_func._last_hash
+                    """
+                ),
+                exec_req.get(
+                    """
+                    # Second call with different kwargs - should be cache miss
+                    result2 = await async_cached_func(1, 2, some_kw_arg=4)
+                    hash2 = async_cached_func._last_hash
+                    """
+                ),
+                exec_req.get(
+                    """
+                    # Third call with same kwargs as first - should be cache hit
+                    result3 = await async_cached_func(1, 2, some_kw_arg=3)
+                    hash3 = async_cached_func._last_hash
+                    """
+                ),
+            ]
+        )
+
+        # Verify results
+        assert k.globals["result1"] == 6  # 1 + 2 + 3
+        assert k.globals["result2"] == 7  # 1 + 2 + 4
+        assert k.globals["result3"] == 6  # 1 + 2 + 3
+
+        # Verify cache keys
+        hash1 = k.globals["hash1"]
+        hash2 = k.globals["hash2"]
+        hash3 = k.globals["hash3"]
+
+        assert hash1 != hash2, "Cache key should change when kwargs change"
+        assert hash1 == hash3, (
+            "Cache key should be same for identical args/kwargs"
+        )
+
+        # Verify cache hits
+        assert k.globals["async_cached_func"].hits == 1
+
+    async def test_async_cache_concurrent_deduplication(
+        self, k: Kernel, exec_req: ExecReqProvider
+    ) -> None:
+        """Test that concurrent calls to the same async cached function are
+        deduplicated.
+
+        When multiple async calls are made concurrently with the same
+        arguments, only one execution should occur - the rest should await
+        the same task. This prevents race conditions and duplicate work.
+ """ + await k.run( + [ + exec_req.get( + """ + import asyncio + from marimo._save.save import cache + + call_count = 0 + + @cache + async def expensive_async_compute(x): + global call_count + call_count += 1 + await asyncio.sleep(0.1) # Simulate expensive async work + return x * 2 + + # Launch 5 concurrent calls with the same argument + # Only one should actually execute, the rest should await that task + results = await asyncio.gather( + expensive_async_compute(42), + expensive_async_compute(42), + expensive_async_compute(42), + expensive_async_compute(42), + expensive_async_compute(42), + ) + """ + ), + ] + ) + + assert not k.stderr.messages, k.stderr + + # All results should be the same + results = k.globals["results"] + assert all(r == 84 for r in results), "All results should be 84" + + # The function should only have been called once (deduplication worked) + assert k.globals["call_count"] == 1, ( + f"Expected 1 execution due to deduplication, got {k.globals['call_count']}" + ) + + # Cache hit should be 0 (first execution is a miss) + # Note: The first call misses, subsequent concurrent calls await the same task + assert k.globals["expensive_async_compute"].hits == 0, ( + "First execution should be a miss, deduplication doesn't count as hits" + ) From 52b6e2c695c0d604ba0c8035168cde09b5b632ee Mon Sep 17 00:00:00 2001 From: Jan Scheffler Date: Fri, 10 Oct 2025 20:13:54 +0200 Subject: [PATCH 3/3] docs: document async cache support Add documentation for async/await support in cache decorators: - Add async examples for @cache and @persistent_cache decorators - Document task deduplication behavior for concurrent async calls - Update comparison table to show async support advantage over functools.cache --- docs/api/caching.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/docs/api/caching.md b/docs/api/caching.md index a8d4387cea3..4e7720e313d 100644 --- a/docs/api/caching.md +++ b/docs/api/caching.md @@ -36,6 +36,32 @@ def compute_embedding(data: str, embedding_dimension: int, model: str) -> np.nda /// +/// tab | `mo.cache` (async) + +```python +import marimo as mo + +@mo.cache +async def fetch_data(url: str, params: dict) -> dict: + response = await http_client.get(url, params=params) + return response.json() +``` + +/// + +/// tab | `mo.persistent_cache` (async) + +```python +import marimo as mo + +@mo.persistent_cache +async def compute_embedding(data: str, embedding_dimension: int, model: str) -> np.ndarray: + response = await llm_client.get_embeddings(data, model) + return response.embeddings +``` + +/// + Roughly speaking, the first time a cached function is called with a particular sequence of arguments, the function will run and its return value will be cached. The next time it is called with the same sequence of arguments (on @@ -49,6 +75,13 @@ letting you pick up where you left off. (For an in-memory cache of bounded size, use [`mo.lru_cache`][marimo.lru_cache].) +!!! note "Async functions are fully supported" + All cache decorators (`mo.cache`, `mo.lru_cache`, `mo.persistent_cache`) work + seamlessly with both synchronous and asynchronous functions. When multiple + concurrent calls are made to a cached async function with the same arguments, + only one execution occurs—the rest await the result. This prevents race conditions + and duplicate work. + !!! tip "Where persistent caches are stored" By default, persistent caches are stored in `__marimo__/cache/`, in the directory of the current notebook. 
     For projects versioned with `git`, consider adding
@@ -202,6 +235,7 @@ Here is a table comparing marimo's cache with `functools.cache`:
 | Tracks closed-over variables | ✅ | ❌ |
 | Allows unhashable arguments? | ✅ | ❌ |
 | Allows Array-like arguments? | ✅ | ❌ |
+| Supports async functions? | ✅ | ❌ |
 | Suitable for lightweight functions (microseconds)? | ❌ | ✅ |
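+
+As a quick illustration of the async support and deduplication noted above
+(a minimal sketch: it assumes a notebook cell, where marimo supports
+top-level `await`, and `fetch` is a hypothetical cached async function):
+
+```python
+import asyncio
+
+import marimo as mo
+
+
+@mo.cache
+async def fetch(x: int) -> int:
+    await asyncio.sleep(0.1)  # simulate slow async work
+    return x * 2
+
+
+# Five concurrent awaits of the same call: the function body runs only
+# once, and every caller receives 84.
+results = await asyncio.gather(*[fetch(42) for _ in range(5)])
+```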