diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py
index 04195ea0cf92..66124dd854ee 100644
--- a/tests/utils_/test_utils.py
+++ b/tests/utils_/test_utils.py
@@ -379,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
 def test_supports_kw(callable,kw_name,requires_kw_only,
                      allow_var_kwargs,is_supported):
     assert supports_kw(
-        callable=callable, 
-        kw_name=kw_name, 
-        requires_kw_only=requires_kw_only, 
+        callable=callable,
+        kw_name=kw_name,
+        requires_kw_only=requires_kw_only,
         allow_var_kwargs=allow_var_kwargs
     ) == is_supported
 
@@ -948,6 +948,36 @@ def test_join_host_port():
     assert join_host_port("::1", 5555) == "[::1]:5555"
 
 
+def test_json_count_leaves():
+    """Test json_count_leaves function from jsontree utility."""
+    from vllm.utils.jsontree import json_count_leaves
+
+    # Single leaf values
+    assert json_count_leaves(42) == 1
+    assert json_count_leaves("hello") == 1
+    assert json_count_leaves(None) == 1
+
+    # Empty containers
+    assert json_count_leaves([]) == 0
+    assert json_count_leaves({}) == 0
+    assert json_count_leaves(()) == 0
+
+    # Flat structures
+    assert json_count_leaves([1, 2, 3]) == 3
+    assert json_count_leaves({"a": 1, "b": 2}) == 2
+    assert json_count_leaves((1, 2, 3)) == 3
+
+    # Nested structures
+    nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
+    assert json_count_leaves(nested_dict) == 3
+
+    nested_list = [1, [2, 3], 4]
+    assert json_count_leaves(nested_list) == 4
+
+    mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
+    assert json_count_leaves(mixed_nested) == 4
+
+
 def test_convert_ids_list_to_tokens():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
     token_ids = tokenizer.encode("Hello, world!")
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 0e81cb6d4d19..3c5cdc87835a 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -10,7 +10,8 @@
 
 from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, LRUCache
-from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
+from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
+                                 json_reduce_leaves)
 
 from .inputs import (MultiModalFieldElem, MultiModalKwargs,
                      MultiModalKwargsItem, MultiModalKwargsItems,
@@ -127,11 +128,32 @@ def get_item_size(
         )
 
         if debug:
-            logger.debug("Calculated size of %s to be %.2f GiB", type(value),
-                         size / GiB_bytes)
+            leaf_count = json_count_leaves(value)
+            logger.debug(
+                "Calculated size of %s to be %.2f GiB (%d leaves)",
+                type(value),
+                size / GiB_bytes,
+                leaf_count,
+            )
 
         return size
 
+    @classmethod
+    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
+        """
+        Get the number of leaf elements in a multi-modal cache value.
+
+        This provides a measure of structural complexity that can be useful
+        for debugging cache performance and understanding data patterns.
+
+        Args:
+            value: The multi-modal cache value to analyze.
+
+        Returns:
+            The number of leaf elements in the nested structure.
+        """
+        return json_count_leaves(value)
+
     @classmethod
     def get_lru_cache(
         cls,
@@ -184,7 +206,7 @@ def get_and_update_item(
         """
         Possibly update a multi-modal item based on whether it is
         in the underlying cache.
-        
+
         This update is done out-of-place and updates the cache eviction order.
 
         Args:
@@ -262,7 +284,7 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]:
         in the underlying cache.
 
         This **DOES NOT** update the cache eviction order.
-        
+
         Args:
             mm_hashes: The hash of each item to check.
 
diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py
index 4cbe0f76e006..457afb7e2c6f 100644
--- a/vllm/utils/jsontree.py
+++ b/vllm/utils/jsontree.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Helper functions to work with nested JSON structures."""
+
 from collections.abc import Iterable
 from functools import reduce
 from typing import Callable, TypeVar, Union, overload
@@ -8,8 +9,12 @@
 _T = TypeVar("_T")
 _U = TypeVar("_U")
 
-JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
-                 tuple["JSONTree[_T]", ...], _T]
+JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    _T,
+]
 """A nested JSON structure where the leaves need not be JSON-serializable."""
 
 
@@ -78,3 +83,8 @@ def json_reduce_leaves(
         json_iter_leaves(value),
         initial,
     )
+
+
+def json_count_leaves(value: JSONTree[_T]) -> int:
+    """Count the number of leaves in a nested JSON structure."""
+    return sum(1 for _ in json_iter_leaves(value))