Commit e250061
feat: add async cache support with task deduplication
Add support for async/await functions with the @cache, @lru_cache, and
@persistent_cache decorators. Implements task deduplication to prevent
race conditions when multiple concurrent calls are made with the same
arguments.

Implementation:
- Use type(self) instead of _cache_call() in __get__() for proper subclass dispatch
- Detect async functions and dispatch to the _cache_call_async variant
- Implement task deduplication by caching asyncio.Task objects in a WeakKeyDictionary
- Prevent concurrent duplicate executions via the _pending_executions dict
- Release the lock before awaiting tasks to avoid deadlocks

Testing:
- Add 15 comprehensive async cache tests
- Test concurrent deduplication (5 concurrent calls → 1 execution)
- All 115 tests passing (100 sync + 15 async)
1 parent cf47fb0 commit e250061
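For context, a minimal sketch of the behavior this commit enables, assuming marimo's @mo.cache decorator (named in the commit title) and an environment where the cache decorators are usable with top-level await, such as a marimo notebook cell; the fetch function and the sleep are illustrative, not part of the diff:

    import asyncio
    import marimo as mo

    @mo.cache  # async functions are now accepted, per this commit
    async def fetch(x: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for slow I/O
        return x * 2

    # Five concurrent calls with the same argument: task deduplication
    # runs the body once and all five callers share the result.
    results = await asyncio.gather(*(fetch(21) for _ in range(5)))
    assert results == [42] * 5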

2 files changed: +756 −3 lines changed

marimo/_save/save.py

Lines changed: 135 additions & 3 deletions
@@ -2,12 +2,15 @@
 from __future__ import annotations
 
 import ast
+import asyncio
 import functools
 import inspect
 import io
 import sys
+import threading
 import time
 import traceback
+import weakref
 from collections import abc
 
 # NB: maxsize follows functools.cache, but renamed max_size outside of drop-in
@@ -360,7 +363,7 @@ def __get__(
                 "(have you wrapped a function?)"
             )
         # Bind to the instance
-        copy = _cache_call(
+        copy = type(self)(
             None,
             self._loader_partial,
             pin_modules=self.pin_modules,
@@ -389,6 +392,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
                 raise TypeError(
                     "cache() takes at most 1 argument (expecting function)"
                 )
+            # Check if the function is async - if so, create async variant
+            if inspect.iscoroutinefunction(args[0]):
+                async_copy = _cache_call_async(
+                    None,
+                    self._loader_partial,
+                    pin_modules=self.pin_modules,
+                    hash_type=self.hash_type,
+                )
+                async_copy._frame_offset = self._frame_offset
+                async_copy._frame_offset -= 4
+                async_copy._set_context(args[0])
+                return async_copy
             # Remove the additional frames from singledispatch, because invoking
             # the function directly.
             self._frame_offset -= 4
@@ -421,6 +436,116 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         return response
 
 
+class _cache_call_async(_cache_call):
+    """Async variant of _cache_call for async/await functions.
+
+    Inherits all caching logic from _cache_call but provides an async
+    __call__ method that properly awaits coroutines. Used automatically
+    when @cache decorates an async function.
+
+    Implements task deduplication: concurrent calls with the same arguments
+    will share the same execution, preventing duplicate work.
+    """
+
+    # Track pending executions per cache instance to prevent race conditions.
+    # WeakKeyDictionary ensures instances are cleaned up when garbage collected.
+    # Key: cache instance, Value: dict of {cache_key: Task}
+    _pending_executions: weakref.WeakKeyDictionary[
+        _cache_call_async, dict[str, asyncio.Task[Any]]
+    ] = weakref.WeakKeyDictionary()
+    _pending_lock = threading.Lock()
+
+    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        # Capture the deferred call case
+        if self.__wrapped__ is None:
+            if len(args) != 1:
+                raise TypeError(
+                    "cache() takes at most 1 argument (expecting function)"
+                )
+            # Remove the additional frames from singledispatch, because invoking
+            # the function directly.
+            self._frame_offset -= 4
+            self._set_context(args[0])
+            return self
+
+        # Prepare execution context to get cache key
+        scope, ctx, attempt = self._prepare_call_execution(args, kwargs)
+        cache_key = attempt.hash
+
+        # Check for pending execution (task deduplication)
+        existing_task = None
+        with self._pending_lock:
+            if self not in self._pending_executions:
+                self._pending_executions[self] = {}
+            pending = self._pending_executions[self]
+
+            if cache_key in pending:
+                # Another coroutine is already executing this - save the task
+                existing_task = pending[cache_key]
+
+        # Await the existing task AFTER releasing the lock to avoid deadlock
+        if existing_task is not None:
+            return await existing_task
+
+        # No pending execution - create a new task
+        task = asyncio.create_task(
+            self._execute_cached(scope, ctx, attempt, args, kwargs)
+        )
+
+        with self._pending_lock:
+            pending[cache_key] = task
+
+        try:
+            result = await task
+        finally:
+            # Clean up completed task
+            with self._pending_lock:
+                if cache_key in pending:
+                    del pending[cache_key]
+                # Clean up empty instance dict (WeakKeyDictionary handles instance cleanup)
+                if not pending and self in self._pending_executions:
+                    del self._pending_executions[self]
+
+        return result
+
+    async def _execute_cached(
+        self,
+        scope: dict[str, Any],
+        ctx: Any,
+        attempt: Any,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> Any:
+        """Execute the cached function and update the cache.
+
+        This is called by a single task even when multiple concurrent
+        callers request the same computation.
+        """
+        assert self.__wrapped__ is not None, UNEXPECTED_FAILURE_BOILERPLATE
+        failed = False
+        self._last_hash = attempt.hash
+        try:
+            if attempt.hit:
+                attempt.restore(scope)
+                return attempt.meta["return"]
+
+            start_time = time.time()
+            # Await the coroutine to get the actual result
+            response = await self.__wrapped__(*args, **kwargs)
+            runtime = time.time() - start_time
+
+            self._finalize_cache_update(attempt, response, runtime, scope)
+        except Exception as e:
+            failed = True
+            raise e
+        finally:
+            # NB. Exceptions raise their own side effects.
+            if ctx and not failed:
+                ctx.cell_lifecycle_registry.add(SideEffect(attempt.hash))
+        self._misses += 1
+        return response
+
+
 class _cache_context(SkipContext, CacheContext):
     def __init__(
         self,
@@ -617,7 +742,7 @@ def _invoke_call(
     *args: Any,
     frame_offset: int = 1,
     **kwargs: Any,
-) -> _cache_call:
+) -> Union[_cache_call, _cache_call_async]:
    if isinstance(loader, Loader):
        raise TypeError(
            "A loader instance cannot be passed to cache directly. "
@@ -637,6 +762,13 @@ def _invoke_call(
             "Invalid loader type. "
             f"Expected a loader partial, got {type(loader)}."
         )
+
+    # Check if the function is async
+    if _fn is not None and inspect.iscoroutinefunction(_fn):
+        return _cache_call_async(
+            _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
+        )
+
     return _cache_call(
         _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
     )
@@ -663,7 +795,7 @@ def _invoke_call_fn(
     *args: Any,
     frame_offset: int = 1,
     **kwargs: Any,
-) -> _cache_call:
+) -> Union[_cache_call, _cache_call_async]:
     return _invoke_call(
         _fn, loader, *args, frame_offset=frame_offset + 1, **kwargs
     )
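The deduplication logic in _cache_call_async.__call__ distills to a small pattern. A standalone sketch follows; the decorator name deduplicate, the argument-tuple key, and the free-function form are illustrative, whereas the diff keys on the cache attempt's content hash and stores tasks per instance in a WeakKeyDictionary:

    import asyncio
    import threading
    from typing import Any, Awaitable, Callable

    def deduplicate(
        fn: Callable[..., Awaitable[Any]],
    ) -> Callable[..., Awaitable[Any]]:
        # One in-flight asyncio.Task per argument tuple, guarded by a lock
        # that is never held across an await (the deadlock-avoidance rule
        # the commit message calls out).
        pending: dict[tuple[Any, ...], asyncio.Task[Any]] = {}
        lock = threading.Lock()

        async def wrapper(*args: Any) -> Any:
            key = args  # stand-in for the content hash the diff uses
            with lock:
                existing = pending.get(key)
            if existing is not None:
                return await existing  # awaited after the lock is released
            task = asyncio.create_task(fn(*args))
            with lock:
                pending[key] = task
            try:
                return await task
            finally:
                with lock:
                    pending.pop(key, None)

        return wrapper

Note that within a single event loop no await occurs between the pending-table lookup and the registration of the new task, so two coroutines cannot both miss the table; the lock only matters when multiple threads touch the same cache, which is why the diff likewise releases it before any await.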
