From 497be791ed850bc667dae1c3ce5254c9e32bb040 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 20 Oct 2025 17:50:03 +0300 Subject: [PATCH 1/3] fix: Use LoRA name for consistent KV-cache block hashing Signed-off-by: Sage Ahrac --- vllm/v1/core/kv_cache_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6870e7ebde37..3b2ce8f54229 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -373,7 +373,7 @@ def need_extra_keys(request: Request) -> bool: """ # Multimodal requests need to include the MM hash. - # LoRA requests need to include the LoRA ID. + # LoRA requests need to include the LoRA name. # Request with provided cache salt need to include the salt. return ( bool(request.mm_features) @@ -446,26 +446,26 @@ def _gen_mm_extra_hash_keys( return extra_keys, curr_mm_idx -def _gen_lora_extra_hash_keys(request: Request) -> list[int]: +def _gen_lora_extra_hash_keys(request: Request) -> list[str]: """Generate extra keys related to LoRA for block hash computation. Args: request: The request object. Returns: - Return LoRA id of the request if it is a LoRA request. Return empty + Return LoRA name of the request if it is a LoRA request. Return empty list otherwise. """ if not request.lora_request: return [] - return [request.lora_request.lora_int_id] + return [request.lora_request.lora_name] def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int ) -> tuple[tuple[Any, ...] | None, int]: """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). + the multi-modal inputs and request specific metadata (e.g., LoRA name). Args: request: The request object. @@ -480,7 +480,7 @@ def generate_block_hash_extra_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( request, start_token_idx, end_token_idx, start_mm_idx ) - lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) + lora_extra_keys: list[str] = _gen_lora_extra_hash_keys(request) cache_salt_keys: list[str] = ( [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] ) From 7fb047eee94edcc7e8bf0e68a33d10a740b4db41 Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 20 Oct 2025 18:41:55 +0300 Subject: [PATCH 2/3] added test Signed-off-by: Sage Ahrac --- tests/v1/core/test_kv_cache_utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 6558267c13a3..5a5bec84d8ab 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -44,6 +44,7 @@ ) from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request +from vllm.lora.request import LoRARequest pytestmark = pytest.mark.cpu_test @@ -449,6 +450,26 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 +def test_generate_block_hash_extra_keys_lora(): + request = make_request( + request_id="0", + prompt_token_ids=[_ for _ in range(6)], + ) + + request.lora_request = LoRARequest( + lora_name="test_lora_adapter", + lora_int_id=1, + lora_path="/path/to/lora" + ) + + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) + assert extra_keys == ("test_lora_adapter",) + + request.lora_request = None + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) + assert extra_keys is None + + @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): parent_block_hash = BlockHash(b"123") From a2065c2270d8f84ff86f113b867ed194643e74dc Mon Sep 17 00:00:00 2001 From: Sage Ahrac Date: Mon, 20 Oct 2025 18:47:53 +0300 Subject: [PATCH 3/3] lint Signed-off-by: Sage Ahrac --- tests/v1/core/test_kv_cache_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 5a5bec84d8ab..d192c58a8c15 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -8,6 +8,7 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig +from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalKwargsItem, @@ -44,7 +45,6 @@ ) from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request -from vllm.lora.request import LoRARequest pytestmark = pytest.mark.cpu_test @@ -455,16 +455,14 @@ def test_generate_block_hash_extra_keys_lora(): request_id="0", prompt_token_ids=[_ for _ in range(6)], ) - + request.lora_request = LoRARequest( - lora_name="test_lora_adapter", - lora_int_id=1, - lora_path="/path/to/lora" + lora_name="test_lora_adapter", lora_int_id=1, lora_path="/path/to/lora" ) - + extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) assert extra_keys == ("test_lora_adapter",) - + request.lora_request = None extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0) assert extra_keys is None