Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions tests/utils_/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported):
assert supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported

Expand Down Expand Up @@ -948,6 +948,36 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555"


def test_json_count_leaves():
    """Check json_count_leaves on scalars, empty, flat and nested inputs."""
    from vllm.utils.jsontree import json_count_leaves

    # Each entry pairs an input tree with the number of leaves it contains.
    cases = [
        # A bare scalar is a single leaf.
        (42, 1),
        ("hello", 1),
        (None, 1),
        # Empty containers contribute no leaves.
        ([], 0),
        ({}, 0),
        ((), 0),
        # One level of nesting: every element is a leaf.
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
        # Deeper nesting: only the innermost values are counted.
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]

    for tree, expected in cases:
        assert json_count_leaves(tree) == expected


def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")
Expand Down
32 changes: 27 additions & 5 deletions vllm/multimodal/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
json_reduce_leaves)

from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
Expand Down Expand Up @@ -127,11 +128,32 @@ def get_item_size(
)

if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
size / GiB_bytes)
leaf_count = json_count_leaves(value)
logger.debug(
"Calculated size of %s to be %.2f GiB (%d leaves)",
type(value),
size / GiB_bytes,
leaf_count,
)

return size

@classmethod
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
    """
    Report how many leaf elements a multi-modal cache value contains.

    The count is obtained by delegating to ``json_count_leaves`` and serves
    as a rough indicator of how structurally complex the value is, e.g.
    when debugging cache behavior.

    Args:
        value: The multi-modal cache value to inspect.

    Returns:
        The leaf count reported by ``json_count_leaves`` for ``value``.
    """
    leaf_total = json_count_leaves(value)
    return leaf_total

@classmethod
def get_lru_cache(
cls,
Expand Down Expand Up @@ -184,7 +206,7 @@ def get_and_update_item(
"""
Possibly update a multi-modal item based on whether it is
in the underlying cache.

This update is done out-of-place and updates the cache eviction order.

Args:
Expand Down Expand Up @@ -262,7 +284,7 @@ def is_cached(self, mm_hashes: list[str]) -> list[bool]:
in the underlying cache.

This **DOES NOT** update the cache eviction order.

Args:
mm_hashes: The hash of each item to check.

Expand Down
14 changes: 12 additions & 2 deletions vllm/utils/jsontree.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures."""

from collections.abc import Iterable
from functools import reduce
from typing import Callable, TypeVar, Union, overload

_T = TypeVar("_T")
_U = TypeVar("_U")

JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...], _T]
JSONTree = Union[
dict[str, "JSONTree[_T]"],
list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...],
_T,
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


Expand Down Expand Up @@ -78,3 +83,8 @@ def json_reduce_leaves(
json_iter_leaves(value),
initial,
)


def json_count_leaves(value: JSONTree[_T]) -> int:
    """Count the number of leaves in a nested JSON structure.

    Dicts (via their values), lists and tuples are treated as containers;
    any other object counts as exactly one leaf. Empty containers therefore
    contribute zero leaves.

    The traversal is iterative, so arbitrarily deep nesting cannot trigger
    a ``RecursionError``.

    Args:
        value: The nested structure (or a bare leaf) to inspect.

    Returns:
        The total number of leaves found in ``value``.
    """
    # NOTE(review): the container types below are meant to mirror those
    # handled by json_iter_leaves -- keep the two in sync.
    count = 0
    # Explicit work stack instead of recursion: visiting order is irrelevant
    # for counting, and depth is no longer bounded by the recursion limit.
    pending = [value]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, (list, tuple)):
            pending.extend(node)
        else:
            count += 1
    return count
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses json_iter_leaves, which is recursive. This can lead to a RecursionError for deeply nested JSON structures (the default limit in Python is often 1000). To make this utility function more robust, an iterative approach using an explicit stack is recommended. This avoids the recursion limit and is generally safer for functions processing arbitrary tree-like data structures.

Suggested change
return sum(1 for _ in json_iter_leaves(value))
count = 0
stack = [value]
while stack:
node = stack.pop()
if isinstance(node, dict):
stack.extend(node.values())
elif isinstance(node, (list, tuple)):
stack.extend(node)
else:
count += 1
return count