3333import warnings
3434import weakref
3535from asyncio import FIRST_COMPLETED , AbstractEventLoop , Task
36- from collections import OrderedDict , UserDict , defaultdict
36+ from collections import UserDict , defaultdict
3737from collections .abc import (AsyncGenerator , Awaitable , Generator , Hashable ,
38- Iterable , Iterator , Mapping )
38+ Iterable , Iterator , KeysView , Mapping )
3939from dataclasses import dataclass , field
4040from functools import cache , lru_cache , partial , wraps
41+ from types import MappingProxyType
4142from typing import (TYPE_CHECKING , Any , Callable , Generic , Literal , NamedTuple ,
42- Optional , Type , TypeVar , Union )
43+ Optional , Type , TypeVar , Union , cast , overload )
4344from uuid import uuid4
4445
46+ import cachetools
4547import cloudpickle
4648import numpy as np
4749import numpy .typing as npt
173175
174176_K = TypeVar ("_K" , bound = Hashable )
175177_V = TypeVar ("_V" )
178+ _T = TypeVar ("_T" )
176179
177180
178181class _Sentinel :
@@ -206,6 +209,19 @@ def reset(self) -> None:
206209 self .counter = 0
207210
208211
212+ class _MappingOrderCacheView (UserDict [_K , _V ]):
213+
214+ def __init__ (self , data : Mapping [_K , _V ], ordered_keys : Mapping [_K , None ]):
215+ super ().__init__ (data )
216+ self .ordered_keys = ordered_keys
217+
218+ def __iter__ (self ) -> Iterator [_K ]:
219+ return iter (self .ordered_keys )
220+
221+ def keys (self ) -> KeysView [_K ]:
222+ return KeysView (self .ordered_keys )
223+
224+
209225class CacheInfo (NamedTuple ):
210226 hits : int
211227 total : int
@@ -218,45 +234,62 @@ def hit_ratio(self) -> float:
218234 return self .hits / self .total
219235
220236
221- class LRUCache (Generic [_K , _V ]):
222- """Note: This class is not thread safe!"""
237+ class LRUCache (cachetools .LRUCache [_K , _V ], Generic [_K , _V ]):
223238
224- def __init__ (self , capacity : int ) -> None :
225- self .cache = OrderedDict [_K , _V ]()
239+ def __init__ (self ,
240+ capacity : float ,
241+ getsizeof : Optional [Callable [[_V ], float ]] = None ):
242+ super ().__init__ (capacity , getsizeof )
226243 self .pinned_items = set [_K ]()
227244 self .capacity = capacity
228245
229246 self ._hits = 0
230247 self ._total = 0
231248
232- def __contains__ (self , key : _K ) -> bool :
233- return key in self .cache
234-
235- def __len__ (self ) -> int :
236- return len (self .cache )
237-
238- def __getitem__ (self , key : _K ) -> _V :
239- value = self .cache [key ] # Raise KeyError if not exists
240- self .cache .move_to_end (key )
241- return value
249+ def __delitem__ (self , key : _K ) -> None :
250+ run_on_remove = key in self
251+ value = self .__getitem__ (key )
252+ super ().__delitem__ (key )
253+ if key in self .pinned_items :
254+ # Todo: add warning to inform that del pinned item
255+ self ._unpin (key )
256+ if run_on_remove :
257+ self ._on_remove (key , value )
242258
243- def __setitem__ (self , key : _K , value : _V ) -> None :
244- self .put (key , value )
259+ @property
260+ def cache (self ) -> Mapping [_K , _V ]:
261+ """Return the internal cache dictionary in order (read-only)."""
262+ return _MappingOrderCacheView (
263+ self ._Cache__data , # type: ignore
264+ self .order )
245265
246- def __delitem__ (self , key : _K ) -> None :
247- self .pop (key )
266+ @property
267+ def order (self ) -> Mapping [_K , None ]:
268+ """Return the internal order dictionary (read-only)."""
269+ return MappingProxyType (self ._LRUCache__order ) # type: ignore
248270
249271 def stat (self ) -> CacheInfo :
250272 return CacheInfo (hits = self ._hits , total = self ._total )
251273
252274 def touch (self , key : _K ) -> None :
253- self .cache .move_to_end (key )
275+ self ._LRUCache__update (key ) # type: ignore
276+
277+ @overload
278+ def get (self , key : _K , / ) -> Optional [_V ]:
279+ ...
280+
281+ @overload
282+ def get (self , key : _K , / , default : Union [_V , _T ]) -> Union [_V , _T ]:
283+ ...
254284
255- def get (self , key : _K , default : Optional [_V ] = None ) -> Optional [_V ]:
256- value : Optional [_V ]
257- if key in self .cache :
258- value = self .cache [key ]
259- self .cache .move_to_end (key )
285+ def get (self ,
286+ key : _K ,
287+ / ,
288+ default : Optional [Union [_V ,
289+ _T ]] = None ) -> Optional [Union [_V , _T ]]:
290+ value : Optional [Union [_V , _T ]]
291+ if key in self :
292+ value = self .__getitem__ (key )
260293
261294 self ._hits += 1
262295 else :
@@ -265,60 +298,76 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
265298 self ._total += 1
266299 return value
267300
301+ @overload
302+ def pop (self , key : _K ) -> _V :
303+ ...
304+
305+ @overload
306+ def pop (self , key : _K , default : Union [_V , _T ]) -> Union [_V , _T ]:
307+ ...
308+
309+ def pop (self ,
310+ key : _K ,
311+ default : Optional [Union [_V ,
312+ _T ]] = None ) -> Optional [Union [_V , _T ]]:
313+ value : Optional [Union [_V , _T ]]
314+ if key not in self :
315+ return default
316+
317+ value = self [key ]
318+ del self [key ]
319+ return value
320+
268321 def put (self , key : _K , value : _V ) -> None :
269- self .cache [key ] = value
270- self .cache .move_to_end (key )
271- self ._remove_old_if_needed ()
322+ self .__setitem__ (key , value )
272323
273324 def pin (self , key : _K ) -> None :
274325 """
275326 Pins a key in the cache preventing it from being
276327 evicted in the LRU order.
277328 """
278- if key not in self . cache :
329+ if key not in self :
279330 raise ValueError (f"Cannot pin key: { key } not in cache." )
280331 self .pinned_items .add (key )
281332
282333 def _unpin (self , key : _K ) -> None :
334+ """
335+ Unpins a key in the cache allowing it to be
336+ evicted in the LRU order.
337+ """
283338 self .pinned_items .remove (key )
284339
285340 def _on_remove (self , key : _K , value : Optional [_V ]) -> None :
286341 pass
287342
288343 def remove_oldest (self , * , remove_pinned : bool = False ) -> None :
289- if not self . cache :
344+ if len ( self ) == 0 :
290345 return
291346
347+ self .popitem (remove_pinned = remove_pinned )
348+
349+ def _remove_old_if_needed (self ) -> None :
350+ while self .currsize > self .capacity :
351+ self .remove_oldest ()
352+
353+ def clear (self ) -> None :
354+ while len (self ) > 0 :
355+ self .remove_oldest (remove_pinned = True )
356+
357+ def popitem (self , remove_pinned : bool = False ):
358+ """Remove and return the `(key, value)` pair least recently used."""
292359 if not remove_pinned :
293360 # pop the oldest item in the cache that is not pinned
294361 lru_key = next (
295- (key for key in self .cache if key not in self .pinned_items ),
362+ (key for key in self .order if key not in self .pinned_items ),
296363 ALL_PINNED_SENTINEL )
297364 if lru_key is ALL_PINNED_SENTINEL :
298365 raise RuntimeError ("All items are pinned, "
299366 "cannot remove oldest from the cache." )
300367 else :
301- lru_key = next (iter (self .cache ))
302- self .pop (lru_key ) # type: ignore
303-
304- def _remove_old_if_needed (self ) -> None :
305- while len (self .cache ) > self .capacity :
306- self .remove_oldest ()
307-
308- def pop (self , key : _K , default : Optional [_V ] = None ) -> Optional [_V ]:
309- run_on_remove = key in self .cache
310- value = self .cache .pop (key , default )
311- # remove from pinned items
312- if key in self .pinned_items :
313- self ._unpin (key )
314- if run_on_remove :
315- self ._on_remove (key , value )
316- return value
317-
318- def clear (self ) -> None :
319- while len (self .cache ) > 0 :
320- self .remove_oldest (remove_pinned = True )
321- self .cache .clear ()
368+ lru_key = next (iter (self .order ))
369+ value = self .pop (cast (_K , lru_key ))
370+ return (lru_key , value )
322371
323372
324373class PyObjectCache :
0 commit comments