diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index c17c6f6c89b0..fba577239682 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -899,6 +899,7 @@ def test_kv_connector_basic():
     scheduler = create_scheduler(
         enable_prefix_caching=True,
         use_kv_connector=True,
+        disable_hybrid_kv_cache_manager=True,
     )
     NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
     BLOCK_SIZE = scheduler.cache_config.block_size
@@ -1024,6 +1025,7 @@ def test_external_prefix_cache_metrics():
     scheduler = create_scheduler(
         enable_prefix_caching=False,
         use_kv_connector=True,
+        disable_hybrid_kv_cache_manager=True,
     )
 
     # Mock connector to simulate a partial external cache hit
@@ -1088,6 +1090,7 @@ def test_kv_connector_unable_to_allocate():
         use_kv_connector=True,
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
+        disable_hybrid_kv_cache_manager=True,
     )
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
     scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
@@ -1171,6 +1174,7 @@ def test_kv_connector_handles_preemption():
         use_kv_connector=True,
         block_size=BLOCK_SIZE,
         num_blocks=NUM_BLOCKS,
+        disable_hybrid_kv_cache_manager=True,
     )
     NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
 
@@ -1387,6 +1391,7 @@ def create_scheduler_with_priority(
     block_size: int = 16,
     max_model_len: int | None = None,
     num_speculative_tokens: int | None = None,
+    disable_hybrid_kv_cache_manager: bool = False,
 ) -> Scheduler:
     """Create scheduler with priority policy enabled.
 
@@ -1411,6 +1416,7 @@ def create_scheduler_with_priority(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
         policy="priority",  # Enable priority scheduling
+        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
     )
     model_config = ModelConfig(
         model=model,
@@ -2018,6 +2024,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
         num_blocks=5,  # Can hold 64 tokens (first block is null)
         block_size=16,  # Standard block size
         use_kv_connector=True,
+        disable_hybrid_kv_cache_manager=True,
     )
 
     # Create a request and schedule it
diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py
index 6e739d6b0e77..3f5e1b9eeaf7 100644
--- a/tests/v1/core/utils.py
+++ b/tests/v1/core/utils.py
@@ -46,6 +47,7 @@ def create_scheduler(
     num_speculative_tokens: int | None = None,
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
+    disable_hybrid_kv_cache_manager: bool = False,
 ) -> Scheduler | AsyncScheduler:
     """Create scheduler under test.
 
@@ -70,6 +71,7 @@ def create_scheduler(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
         async_scheduling=async_scheduling,
+        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
     )
     model_config = ModelConfig(
         model=model,
diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index a9817313cf02..a756858e2cc5 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -136,6 +136,7 @@ run_tests_for_model() {
     vllm serve $model_name \
     --port $PORT \
     --enforce-eager \
+    --disable-hybrid-kv-cache-manager \
     --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
     --tensor-parallel-size $PREFILLER_TP_SIZE \
     --kv-transfer-config '$KV_CONFIG'"
@@ -178,6 +179,7 @@ run_tests_for_model() {
     --port $PORT \
     --enforce-eager \
     --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
+    --disable-hybrid-kv-cache-manager \
     --kv-transfer-config '$KV_CONFIG'"
 
     # DP-EP attention mode
diff --git a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
index c48b452e24cd..a3eeedb2e514 100755
--- a/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
@@ -85,6 +85,7 @@ run_tests_for_model() {
     --port $PREFILL_PORT \
     --enforce-eager \
     --gpu-memory-utilization 0.2 \
+    --disable-hybrid-kv-cache-manager \
     --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
@@ -103,6 +104,7 @@ run_tests_for_model() {
     --port $DECODE_PORT \
     --enforce-eager \
     --gpu-memory-utilization 0.2 \
+    --disable-hybrid-kv-cache-manager \
     --kv-transfer-config '$KV_CONFIG'"
 
     if [ -n "$model_args" ]; then
diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index 1c1ac915c758..6748532afd97 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -114,6 +114,7 @@ def test_multi_shared_storage_connector_consistency():
         enforce_eager=True,
         gpu_memory_utilization=0.5,
         kv_transfer_config=kv_transfer_config,
+        disable_hybrid_kv_cache_manager=True,
     )
     # Run generation - this should trigger saving KV cache
     _ = llm.generate(PROMPTS, SAMPLING_PARAMS)
diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index e073321c637b..445d115010cd 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -932,6 +932,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
         "gpu_memory_utilization": 0.5,
         "kv_transfer_config": kv_transfer_config,
         "distributed_executor_backend": distributed_executor_backend,
+        "disable_hybrid_kv_cache_manager": True,
     }
 
     timeout = 6
diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py
index e7013a794a8c..6040ed5a6806 100644
--- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py
+++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py
@@ -132,6 +132,7 @@ def test_shared_storage_connector_hashes(tmp_path):
         enforce_eager=True,
         kv_transfer_config=kv_transfer_config,
         limit_mm_per_prompt={"image": 2},
+        disable_hybrid_kv_cache_manager=True,
     )
 
     # don't put this import at the top level
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index e3f30bd7698f..46ea46e53084 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -91,6 +91,9 @@ def create_vllm_config(
         max_num_batched_tokens=max_num_batched_tokens,
         max_model_len=max_model_len,
         enable_chunked_prefill=enable_chunked_prefill,
+        # Disable hybrid KV cache manager for testing
+        # Should be removed after we support hybrid KV cache manager-based testing.
+        disable_hybrid_kv_cache_manager=True,
     )
     model_config = ModelConfig(
         model=model,
diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
index 0d90cc715fd4..e9c255b1ee99 100644
--- a/tests/v1/kv_offload/test_cpu_offloading.py
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -27,6 +27,7 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
         model="meta-llama/Llama-3.2-1B-Instruct",
         gpu_memory_utilization=0.5,
         kv_transfer_config=kv_transfer_config,
+        disable_hybrid_kv_cache_manager=True,
     )
 
     prompts = ["Hi " * 100]
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 472d6ed2c1df..916f258d6586 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -40,11 +40,14 @@
     from transformers import PretrainedConfig
 
     from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+    from vllm.v1.kv_cache_interface import KVCacheConfig
 else:
     PretrainedConfig = Any
     QuantizationConfig = Any
+    KVCacheConfig = Any
+
 
 logger = init_logger(__name__)
 
@@ -568,9 +571,6 @@ def __post_init__(self):
             if not current_platform.support_hybrid_kv_cache():
                 # Hybrid KV cache manager is not supported on non-GPU platforms.
                 self.scheduler_config.disable_hybrid_kv_cache_manager = True
-            if self.kv_transfer_config is not None:
-                # Hybrid KV cache manager is not compatible with KV transfer.
-                self.scheduler_config.disable_hybrid_kv_cache_manager = True
             if self.kv_events_config is not None:
                 # Hybrid KV cache manager is not compatible with KV events.
                 self.scheduler_config.disable_hybrid_kv_cache_manager = True
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 46a9ce77f8c4..c64996f13cd5 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -6,15 +6,18 @@
 from typing import TYPE_CHECKING, cast
 
 import vllm.envs as envs
+from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.base import (
     KVConnectorBase,
     KVConnectorBaseType,
 )
-from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
+from vllm.distributed.kv_transfer.kv_connector.v1 import (
+    KVConnectorRole,
+    supports_hma,
+)
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
     from vllm.config.kv_transfer import KVTransferConfig
 
 logger = init_logger(__name__)
@@ -38,7 +41,7 @@ def loader() -> type[KVConnectorBase]:
     @classmethod
     def create_connector(
         cls,
-        config: "VllmConfig",
+        config: VllmConfig,
         role: KVConnectorRole,
     ) -> KVConnectorBase:
         if not envs.VLLM_USE_V1:
@@ -51,6 +54,15 @@ def create_connector(
         if kv_transfer_config is None:
             raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
+
+        # Check whether the connector supports HMA.
+        hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
+        if hma_enabled and not supports_hma(connector_cls):
+            raise ValueError(
+                f"Connector {connector_cls.__name__} does not support HMA but "
+                f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
+            )
+
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
             connector_cls.__name__,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
index bb558c956029..0e16bc5cc685 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
@@ -3,9 +3,17 @@
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1,
     KVConnectorRole,
+    SupportsHMA,
+    supports_hma,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (  # noqa E:501
     DecodeBenchConnector,
 )
 
-__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "DecodeBenchConnector"]
+__all__ = [
+    "KVConnectorRole",
+    "KVConnectorBase_V1",
+    "supports_hma",
+    "SupportsHMA",
+    "DecodeBenchConnector",
+]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 989e2f664bee..2562eb9ce70e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -70,6 +70,45 @@
 logger = init_logger(__name__)
 
 
+class SupportsHMA(ABC):
+    """
+    Mixin class indicating that the connector supports the hybrid memory
+    allocator (HMA).
+    Inheriting from this class is required to use the connector together with
+    the hybrid memory allocator.
+    """
+
+    @abstractmethod
+    def request_finished_all_groups(
+        self,
+        request: "Request",
+        block_ids: tuple[list[int], ...],
+    ) -> tuple[bool, dict[str, Any] | None]:
+        """
+        Called exactly once when a request has finished, before its blocks are
+        freed for each kv cache group.
+
+        NOTE(Kuntai): This function is only supported by connectors that support HMA.
+
+        The connector may assume responsibility for freeing the blocks
+        asynchronously by returning True.
+
+        Returns:
+            True if the request is being saved/sent asynchronously and blocks
+            should not be freed until the request_id is returned from
+            get_finished().
+            Optional KVTransferParams to be included in the request outputs
+            returned by the engine.
+        """
+        raise NotImplementedError
+
+
+def supports_hma(connector: Any) -> bool:
+    if isinstance(connector, type):
+        return issubclass(connector, SupportsHMA)
+    else:
+        return isinstance(connector, SupportsHMA)
+
+
 class KVConnectorRole(enum.Enum):
     # Connector running in the scheduler process
     SCHEDULER = 0
@@ -370,7 +409,7 @@ def request_finished(
         Called exactly once when a request has finished, before its blocks
         are freed.
 
-        The connector may assumes responsibility for freeing the the blocks
+        The connector may assume responsibility for freeing the blocks
         asynchronously by returning True.
 
         Returns:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index ff237b28a2c9..7afee15a2da6 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+import copy
 import itertools
 import time
 from collections import defaultdict
@@ -13,6 +13,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
+    supports_hma,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 from vllm.logger import init_logger
@@ -86,15 +87,14 @@ def __init__(
         self.connector = None
         self.connector_prefix_cache_stats: PrefixCacheStats | None = None
         if self.vllm_config.kv_transfer_config is not None:
-            assert len(self.kv_cache_config.kv_cache_groups) == 1, (
-                "Multiple KV cache groups are not currently supported "
-                "with KV connectors"
-            )
             assert not self.is_encoder_decoder, (
                 "Encoder-decoder models are not currently supported with KV connectors"
             )
+
+            connector_vllm_config = copy.copy(self.vllm_config)
+            connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
             self.connector = KVConnectorFactory.create_connector(
-                config=self.vllm_config, role=KVConnectorRole.SCHEDULER
+                config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
             )
             if self.log_stats:
                 self.connector_prefix_cache_stats = PrefixCacheStats()
@@ -1324,8 +1324,17 @@ def _connector_finished(
         if self.connector is None:
             return False, None
 
-        (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
-        return self.connector.request_finished(request, block_ids)
+        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+
+        if not supports_hma(self.connector):
+            # NOTE(Kuntai): We should deprecate this code path after we enforce
+            # all connectors to support HMA.
+            # The hybrid memory allocator should already be turned off for this
+            # code path, but let's double-check here.
+            assert len(self.kv_cache_config.kv_cache_groups) == 1
+            return self.connector.request_finished(request, block_ids[0])
+        else:
+            return self.connector.request_finished_all_groups(request, block_ids)
 
     def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 668f5b2307b7..0e2437e358ef 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -348,6 +348,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        # Init the kv cache connector here, because it requires
+        # `kv_cache_config`.
+        # NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
+        # because `initialize_kv_cache` will inject kv cache groups not
+        # related to the kv cache connector (e.g. kv cache sharing layers).
+        connector_vllm_config = copy.copy(self.vllm_config)
+        connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
+        ensure_kv_transfer_initialized(connector_vllm_config)
+
         if self.vllm_config.model_config.enable_sleep_mode:
             from vllm.device_allocator.cumem import CuMemAllocator
 
@@ -800,5 +809,3 @@ def init_worker_distributed_environment(
         parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size,
     )
-
-    ensure_kv_transfer_initialized(vllm_config)
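
For reference, a minimal sketch of what opting into the hybrid memory allocator looks like from a connector author's point of view under the interface added above. `MyHMAConnector` and its trivial return value are hypothetical and omit the rest of the `KVConnectorBase_V1` interface; only `SupportsHMA`, `supports_hma`, and `request_finished_all_groups` come from this change.

# Illustrative sketch only: `MyHMAConnector` is hypothetical; the remaining
# KVConnectorBase_V1 abstract methods are omitted for brevity.
from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorBase_V1,
    SupportsHMA,
    supports_hma,
)


class MyHMAConnector(KVConnectorBase_V1, SupportsHMA):
    """Hypothetical connector that opts into HMA by mixing in SupportsHMA."""

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        # `block_ids` carries one list of block ids per KV cache group, unlike
        # request_finished(), which receives the block ids of a single group.
        # Returning (False, None) tells the scheduler the blocks can be freed
        # immediately and no extra kv_transfer_params are emitted.
        return False, None


# The factory check added in this diff reduces to a class-level test:
assert supports_hma(MyHMAConnector)

Because `supports_hma()` accepts either a class or an instance, the factory can reject unsupported connectors before constructing them, while the scheduler can dispatch on the live connector object in `_connector_finished`.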