22from __future__ import annotations
33
44import ast
5+ import asyncio
56import functools
67import inspect
78import io
89import sys
10+ import threading
911import time
1012import traceback
13+ import weakref
1114from collections import abc
1215
1316# NB: maxsize follows functools.cache, but renamed max_size outside of drop-in
@@ -360,7 +363,7 @@ def __get__(
360363 "(have you wrapped a function?)"
361364 )
362365 # Bind to the instance
363- copy = _cache_call (
366+ copy = type ( self ) (
364367 None ,
365368 self ._loader_partial ,
366369 pin_modules = self .pin_modules ,
@@ -389,6 +392,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
389392 raise TypeError (
390393 "cache() takes at most 1 argument (expecting function)"
391394 )
395+ # Check if the function is async - if so, create async variant
396+ if inspect .iscoroutinefunction (args [0 ]):
397+ async_copy = _cache_call_async (
398+ None ,
399+ self ._loader_partial ,
400+ pin_modules = self .pin_modules ,
401+ hash_type = self .hash_type ,
402+ )
403+ async_copy ._frame_offset = self ._frame_offset
404+ async_copy ._frame_offset -= 4
405+ async_copy ._set_context (args [0 ])
406+ return async_copy
392407 # Remove the additional frames from singledispatch, because invoking
393408 # the function directly.
394409 self ._frame_offset -= 4
@@ -421,6 +436,116 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
421436 return response
422437
423438
439+ class _cache_call_async (_cache_call ):
440+ """Async variant of _cache_call for async/await functions.
441+
442+ Inherits all caching logic from _cache_call but provides an async
443+ __call__ method that properly awaits coroutines. Used automatically
444+ when @cache decorates an async function.
445+
446+ Implements task deduplication: concurrent calls with the same arguments
447+ will share the same execution, preventing duplicate work.
448+ """
449+
450+ # Track pending executions per cache instance to prevent race conditions
451+ # WeakKeyDictionary ensures instances are cleaned up when garbage collected
452+ # Key: cache instance, Value: dict of {cache_key: Task}
453+ _pending_executions : weakref .WeakKeyDictionary [
454+ _cache_call_async , dict [str , asyncio .Task [Any ]]
455+ ] = weakref .WeakKeyDictionary ()
456+ _pending_lock = threading .Lock ()
457+
458+ async def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
459+ # Capture the deferred call case
460+ if self .__wrapped__ is None :
461+ if len (args ) != 1 :
462+ raise TypeError (
463+ "cache() takes at most 1 argument (expecting function)"
464+ )
465+ # Remove the additional frames from singledispatch, because invoking
466+ # the function directly.
467+ self ._frame_offset -= 4
468+ self ._set_context (args [0 ])
469+ return self
470+
471+ # Prepare execution context to get cache key
472+ scope , ctx , attempt = self ._prepare_call_execution (args , kwargs )
473+ cache_key = attempt .hash
474+
475+ # Check for pending execution (task deduplication)
476+ existing_task = None
477+ with self ._pending_lock :
478+ if self not in self ._pending_executions :
479+ self ._pending_executions [self ] = {}
480+ pending = self ._pending_executions [self ]
481+
482+ if cache_key in pending :
483+ # Another coroutine is already executing this - save the task
484+ existing_task = pending [cache_key ]
485+
486+ # Await the existing task AFTER releasing the lock to avoid deadlock
487+ if existing_task is not None :
488+ return await existing_task
489+
490+ # No pending execution - create a new task
491+ task = asyncio .create_task (
492+ self ._execute_cached (scope , ctx , attempt , args , kwargs )
493+ )
494+
495+ with self ._pending_lock :
496+ pending [cache_key ] = task
497+
498+ try :
499+ result = await task
500+ finally :
501+ # Clean up completed task
502+ with self ._pending_lock :
503+ if cache_key in pending :
504+ del pending [cache_key ]
505+ # Clean up empty instance dict (WeakKeyDictionary handles instance cleanup)
506+ if not pending and self in self ._pending_executions :
507+ del self ._pending_executions [self ]
508+
509+ return result
510+
511+ async def _execute_cached (
512+ self ,
513+ scope : dict [str , Any ],
514+ ctx : Any ,
515+ attempt : Any ,
516+ args : tuple [Any , ...],
517+ kwargs : dict [str , Any ],
518+ ) -> Any :
519+ """Execute the cached function and update cache.
520+
521+ This is called by a single task even when multiple concurrent
522+ callers request the same computation.
523+ """
524+ assert self .__wrapped__ is not None , UNEXPECTED_FAILURE_BOILERPLATE
525+ failed = False
526+ self ._last_hash = attempt .hash
527+ try :
528+ if attempt .hit :
529+ attempt .restore (scope )
530+ return attempt .meta ["return" ]
531+
532+ start_time = time .time ()
533+ # Await the coroutine to get the actual result
534+ response = await self .__wrapped__ (* args , ** kwargs )
535+ runtime = time .time () - start_time
536+
537+ self ._finalize_cache_update (attempt , response , runtime , scope )
538+ except Exception as e :
539+ failed = True
540+ raise e
541+ finally :
542+ # NB. Exceptions raise their own side effects.
543+ if ctx and not failed :
544+ ctx .cell_lifecycle_registry .add (SideEffect (attempt .hash ))
545+ self ._misses += 1
546+ return response
547+
548+
424549class _cache_context (SkipContext , CacheContext ):
425550 def __init__ (
426551 self ,
@@ -617,7 +742,7 @@ def _invoke_call(
617742 * args : Any ,
618743 frame_offset : int = 1 ,
619744 ** kwargs : Any ,
620- ) -> _cache_call :
745+ ) -> Union [ _cache_call , _cache_call_async ] :
621746 if isinstance (loader , Loader ):
622747 raise TypeError (
623748 "A loader instance cannot be passed to cache directly. "
@@ -637,6 +762,13 @@ def _invoke_call(
637762 "Invalid loader type. "
638763 f"Expected a loader partial, got { type (loader )} ."
639764 )
765+
766+ # Check if the function is async
767+ if _fn is not None and inspect .iscoroutinefunction (_fn ):
768+ return _cache_call_async (
769+ _fn , loader , * args , frame_offset = frame_offset + 1 , ** kwargs
770+ )
771+
640772 return _cache_call (
641773 _fn , loader , * args , frame_offset = frame_offset + 1 , ** kwargs
642774 )
@@ -663,7 +795,7 @@ def _invoke_call_fn(
663795 * args : Any ,
664796 frame_offset : int = 1 ,
665797 ** kwargs : Any ,
666- ) -> _cache_call :
798+ ) -> Union [ _cache_call , _cache_call_async ] :
667799 return _invoke_call (
668800 _fn , loader , * args , frame_offset = frame_offset + 1 , ** kwargs
669801 )
0 commit comments