Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 67 additions & 25 deletions tests/v1/kv_connector/unit/test_shared_storage_connector.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from pathlib import Path
from typing import NamedTuple

import pytest
from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
Expand Down Expand Up @@ -108,53 +110,68 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print("-" * 50)


def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that SharedStorageConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
print(f"KV storage path at: {str(tmp_path)}")
def build_llm_instance(eager: bool, shared_storage_path: Path):
    """Create the LLM instance with SharedStorageConnector configuration.

    Args:
        eager: Whether to force eager execution (``enforce_eager``) so the
            same test can cover both eager and compiled code paths.
        shared_storage_path: Directory where the connector persists KV blocks.

    Returns:
        A configured ``LLM`` instance wired to a SharedStorageConnector.
    """
    print(f"KV storage path at: {str(shared_storage_path)}")

    # Configure the SharedStorageConnector to read and write KV ("kv_both")
    # under the per-test storage directory.
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(shared_storage_path)},
    )

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        gpu_memory_utilization=0.4,
        enforce_eager=eager,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # Create the LLM instance from the dataclass fields of EngineArgs.
    engine_args_dict = asdict(engine_args)
    return LLM(**engine_args_dict)


# Create processor to handle the chat prompt
processor = AutoProcessor.from_pretrained(MODEL_NAME)
@pytest.fixture(scope="function")
def shared_storage_path(tmp_path_factory):
    """Create a fresh KV storage directory for each test invocation.

    Function-scoped on purpose: each parametrization of the test gets its
    own empty storage path, so KV state never leaks between runs.
    """
    return tmp_path_factory.mktemp("kv_storage")

# Prepare images for the tests

@pytest.fixture(scope="session")
def test_images():
    """Prepare the two test images, resized to a common resolution.

    Returns:
        A dict mapping ``"image_1"``/``"image_2"`` to two distinct PIL
        images of identical size (1280x720).
    """
    # Resize to the same size so that any hash difference comes from image
    # content rather than image dimensions.
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure that they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    return {"image_1": image_1, "image_2": image_2}


@pytest.fixture(scope="session")
def processor():
    """Build the chat-prompt processor for MODEL_NAME (session-scoped)."""
    # Imported lazily on purpose: a module-level import would trigger
    # torch.cuda.device_count() as a side effect.
    from transformers import AutoProcessor  # noqa: F401

    chat_processor = AutoProcessor.from_pretrained(MODEL_NAME)
    return chat_processor


@pytest.fixture(scope="session")
def input_cases(test_images):
"""Prepare the input cases for testing."""
image_1 = test_images["image_1"]
image_2 = test_images["image_2"]

# Prepare the input cases
input_cases = [
return [
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1],
Expand Down Expand Up @@ -239,9 +256,34 @@ def test_shared_storage_connector_hashes(tmp_path):
),
]

# Run tests

@pytest.mark.parametrize("eager", [False, True])
def test_shared_storage_connector_hashes(
    shared_storage_path,
    processor,
    input_cases,
    eager,
):
    """
    Tests that SharedStorageConnector saves KV to the storage locations
    with proper hashes; that are unique for inputs with identical text but
    different images (same size), or same multiple images but different orders.

    Note: These tests are stateful and must run in order (case_id 0 to 10).
    Each test depends on the cumulative state from previous tests.
    """
    # A fresh LLM per parametrization so eager/compiled runs don't share state.
    llm_instance = build_llm_instance(
        eager=eager, shared_storage_path=shared_storage_path
    )

    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
        run_test(
            shared_storage_path,
            processor,
            llm_instance,
            text,
            img,
            expected_len,
            info,
        )