Skip to content

Commit f4c98b4

Browse files
authored
[Misc] Consolidate LRUCache implementations (#15481)
Signed-off-by: Bella kira <[email protected]>
1 parent e1e0fd7 commit f4c98b4

File tree

2 files changed

+104
-56
lines changed

2 files changed

+104
-56
lines changed

vllm/multimodal/processing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
TypeVar, Union, cast)
1313

1414
import torch
15-
from cachetools import LRUCache
1615
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
1716
from typing_extensions import assert_never
1817

@@ -21,7 +20,7 @@
2120
from vllm.logger import init_logger
2221
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
2322
encode_tokens)
24-
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
23+
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
2524

2625
from .hasher import MultiModalHasher
2726
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,

vllm/utils.py

Lines changed: 103 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@
3333
import warnings
3434
import weakref
3535
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
36-
from collections import OrderedDict, UserDict, defaultdict
36+
from collections import UserDict, defaultdict
3737
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
38-
Iterable, Iterator, Mapping)
38+
Iterable, Iterator, KeysView, Mapping)
3939
from dataclasses import dataclass, field
4040
from functools import cache, lru_cache, partial, wraps
41+
from types import MappingProxyType
4142
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
42-
Optional, Type, TypeVar, Union)
43+
Optional, Type, TypeVar, Union, cast, overload)
4344
from uuid import uuid4
4445

46+
import cachetools
4547
import cloudpickle
4648
import numpy as np
4749
import numpy.typing as npt
@@ -173,6 +175,7 @@
173175

174176
_K = TypeVar("_K", bound=Hashable)
175177
_V = TypeVar("_V")
178+
_T = TypeVar("_T")
176179

177180

178181
class _Sentinel:
@@ -206,6 +209,19 @@ def reset(self) -> None:
206209
self.counter = 0
207210

208211

212+
class _MappingOrderCacheView(UserDict[_K, _V]):
213+
214+
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
215+
super().__init__(data)
216+
self.ordered_keys = ordered_keys
217+
218+
def __iter__(self) -> Iterator[_K]:
219+
return iter(self.ordered_keys)
220+
221+
def keys(self) -> KeysView[_K]:
222+
return KeysView(self.ordered_keys)
223+
224+
209225
class CacheInfo(NamedTuple):
210226
hits: int
211227
total: int
@@ -218,45 +234,62 @@ def hit_ratio(self) -> float:
218234
return self.hits / self.total
219235

220236

221-
class LRUCache(Generic[_K, _V]):
222-
"""Note: This class is not thread safe!"""
237+
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
223238

224-
def __init__(self, capacity: int) -> None:
225-
self.cache = OrderedDict[_K, _V]()
239+
def __init__(self,
240+
capacity: float,
241+
getsizeof: Optional[Callable[[_V], float]] = None):
242+
super().__init__(capacity, getsizeof)
226243
self.pinned_items = set[_K]()
227244
self.capacity = capacity
228245

229246
self._hits = 0
230247
self._total = 0
231248

232-
def __contains__(self, key: _K) -> bool:
233-
return key in self.cache
234-
235-
def __len__(self) -> int:
236-
return len(self.cache)
237-
238-
def __getitem__(self, key: _K) -> _V:
239-
value = self.cache[key] # Raise KeyError if not exists
240-
self.cache.move_to_end(key)
241-
return value
249+
def __delitem__(self, key: _K) -> None:
250+
run_on_remove = key in self
251+
value = self.__getitem__(key)
252+
super().__delitem__(key)
253+
if key in self.pinned_items:
254+
# TODO: add a warning to inform the caller that a pinned item is being deleted
255+
self._unpin(key)
256+
if run_on_remove:
257+
self._on_remove(key, value)
242258

243-
def __setitem__(self, key: _K, value: _V) -> None:
244-
self.put(key, value)
259+
@property
260+
def cache(self) -> Mapping[_K, _V]:
261+
"""Return the internal cache dictionary in order (read-only)."""
262+
return _MappingOrderCacheView(
263+
self._Cache__data, # type: ignore
264+
self.order)
245265

246-
def __delitem__(self, key: _K) -> None:
247-
self.pop(key)
266+
@property
267+
def order(self) -> Mapping[_K, None]:
268+
"""Return the internal order dictionary (read-only)."""
269+
return MappingProxyType(self._LRUCache__order) # type: ignore
248270

249271
def stat(self) -> CacheInfo:
250272
return CacheInfo(hits=self._hits, total=self._total)
251273

252274
def touch(self, key: _K) -> None:
253-
self.cache.move_to_end(key)
275+
self._LRUCache__update(key) # type: ignore
276+
277+
@overload
278+
def get(self, key: _K, /) -> Optional[_V]:
279+
...
280+
281+
@overload
282+
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
283+
...
254284

255-
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
256-
value: Optional[_V]
257-
if key in self.cache:
258-
value = self.cache[key]
259-
self.cache.move_to_end(key)
285+
def get(self,
286+
key: _K,
287+
/,
288+
default: Optional[Union[_V,
289+
_T]] = None) -> Optional[Union[_V, _T]]:
290+
value: Optional[Union[_V, _T]]
291+
if key in self:
292+
value = self.__getitem__(key)
260293

261294
self._hits += 1
262295
else:
@@ -265,60 +298,76 @@ def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
265298
self._total += 1
266299
return value
267300

301+
@overload
302+
def pop(self, key: _K) -> _V:
303+
...
304+
305+
@overload
306+
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
307+
...
308+
309+
def pop(self,
310+
key: _K,
311+
default: Optional[Union[_V,
312+
_T]] = None) -> Optional[Union[_V, _T]]:
313+
value: Optional[Union[_V, _T]]
314+
if key not in self:
315+
return default
316+
317+
value = self[key]
318+
del self[key]
319+
return value
320+
268321
def put(self, key: _K, value: _V) -> None:
269-
self.cache[key] = value
270-
self.cache.move_to_end(key)
271-
self._remove_old_if_needed()
322+
self.__setitem__(key, value)
272323

273324
def pin(self, key: _K) -> None:
274325
"""
275326
Pins a key in the cache preventing it from being
276327
evicted in the LRU order.
277328
"""
278-
if key not in self.cache:
329+
if key not in self:
279330
raise ValueError(f"Cannot pin key: {key} not in cache.")
280331
self.pinned_items.add(key)
281332

282333
def _unpin(self, key: _K) -> None:
334+
"""
335+
Unpins a key in the cache allowing it to be
336+
evicted in the LRU order.
337+
"""
283338
self.pinned_items.remove(key)
284339

285340
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
286341
pass
287342

288343
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
289-
if not self.cache:
344+
if len(self) == 0:
290345
return
291346

347+
self.popitem(remove_pinned=remove_pinned)
348+
349+
def _remove_old_if_needed(self) -> None:
350+
while self.currsize > self.capacity:
351+
self.remove_oldest()
352+
353+
def clear(self) -> None:
354+
while len(self) > 0:
355+
self.remove_oldest(remove_pinned=True)
356+
357+
def popitem(self, remove_pinned: bool = False):
358+
"""Remove and return the `(key, value)` pair least recently used."""
292359
if not remove_pinned:
293360
# pop the oldest item in the cache that is not pinned
294361
lru_key = next(
295-
(key for key in self.cache if key not in self.pinned_items),
362+
(key for key in self.order if key not in self.pinned_items),
296363
ALL_PINNED_SENTINEL)
297364
if lru_key is ALL_PINNED_SENTINEL:
298365
raise RuntimeError("All items are pinned, "
299366
"cannot remove oldest from the cache.")
300367
else:
301-
lru_key = next(iter(self.cache))
302-
self.pop(lru_key) # type: ignore
303-
304-
def _remove_old_if_needed(self) -> None:
305-
while len(self.cache) > self.capacity:
306-
self.remove_oldest()
307-
308-
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
309-
run_on_remove = key in self.cache
310-
value = self.cache.pop(key, default)
311-
# remove from pinned items
312-
if key in self.pinned_items:
313-
self._unpin(key)
314-
if run_on_remove:
315-
self._on_remove(key, value)
316-
return value
317-
318-
def clear(self) -> None:
319-
while len(self.cache) > 0:
320-
self.remove_oldest(remove_pinned=True)
321-
self.cache.clear()
368+
lru_key = next(iter(self.order))
369+
value = self.pop(cast(_K, lru_key))
370+
return (lru_key, value)
322371

323372

324373
class PyObjectCache:

0 commit comments

Comments
 (0)