Merged

Commits (35)
1ded8ae
Refactor: make sure the API calls are backward compatible
KuntaiDu Sep 25, 2025
42040ba
align function signature
KuntaiDu Sep 25, 2025
fbaa51a
fix mypy errors
KuntaiDu Sep 26, 2025
fae4c82
adjust the signature of block_ids
KuntaiDu Sep 26, 2025
f858a9d
merge and resolve merge conflict
KuntaiDu Oct 12, 2025
0aa2b01
allow hybrid kv cache manager + connector
KuntaiDu Oct 12, 2025
89a976c
init using ConnectorVllmConfig
KuntaiDu Oct 12, 2025
9cdd2b0
put the change of function signature inside KVConnectorHMAMixin class
KuntaiDu Oct 13, 2025
e0ac23c
remove unnecessary change of function signature
KuntaiDu Oct 13, 2025
312e065
fix merge conflict
KuntaiDu Oct 13, 2025
b29a257
copy kv cache config instead of just sending the pointer
KuntaiDu Oct 13, 2025
866c404
align the way of checking if the connector supports HMA
KuntaiDu Oct 16, 2025
367b7b7
change class name to SupportsHMA
KuntaiDu Oct 17, 2025
e35f118
avoid using ConnectorVllmConfig, use copy instead
KuntaiDu Oct 17, 2025
7e963b9
use deepcopy instead
KuntaiDu Oct 17, 2025
650d666
adjust the comments
KuntaiDu Oct 17, 2025
37a589d
adjust comments
KuntaiDu Oct 17, 2025
6abc1c2
adjust comments
KuntaiDu Oct 17, 2025
8fc7bca
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 20, 2025
27774f3
change deepcopy to shallowcopy --- shallow copy should be enough
KuntaiDu Oct 20, 2025
1d7f75f
Merge branch 'kuntai-enable-hma-connector' of https://github.com/Kunt…
KuntaiDu Oct 20, 2025
ababeec
fix CPU offloading test
KuntaiDu Oct 20, 2025
1974b5f
fix CI errors
KuntaiDu Oct 20, 2025
9198d3e
fix NIXL-connector-related CI errors
KuntaiDu Oct 23, 2025
eee8c11
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 23, 2025
c6e0bc4
fix CI errors
KuntaiDu Oct 23, 2025
919fe9b
remove hma support from LMCache for now
KuntaiDu Oct 23, 2025
0b67b76
add an extra sanity check for request_finished
KuntaiDu Oct 23, 2025
4bfdcf8
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 23, 2025
36e42a1
fix bug
KuntaiDu Oct 24, 2025
6f6347c
Merge branch 'kuntai-enable-hma-connector' of https://github.com/Kunt…
KuntaiDu Oct 24, 2025
0df4f02
fix CI bug
KuntaiDu Oct 24, 2025
5d88c0d
fix CI issues
KuntaiDu Oct 24, 2025
2fac4fb
fix CI errors
KuntaiDu Oct 24, 2025
4c724a6
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 24, 2025
6 changes: 6 additions & 0 deletions tests/v1/core/test_scheduler.py
@@ -899,6 +899,7 @@ def test_kv_connector_basic():
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
disable_hybrid_kv_cache_manager=True,
)
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
BLOCK_SIZE = scheduler.cache_config.block_size
@@ -1028,6 +1029,7 @@ def test_kv_connector_unable_to_allocate():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
disable_hybrid_kv_cache_manager=True,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
@@ -1111,6 +1113,7 @@ def test_kv_connector_handles_preemption():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
disable_hybrid_kv_cache_manager=True,
)

NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
@@ -1327,6 +1330,7 @@ def create_scheduler_with_priority(
block_size: int = 16,
max_model_len: int | None = None,
num_speculative_tokens: int | None = None,
disable_hybrid_kv_cache_manager: bool = False,
) -> Scheduler:
"""Create scheduler with priority policy enabled.

@@ -1351,6 +1355,7 @@
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
policy="priority", # Enable priority scheduling
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
)
model_config = ModelConfig(
model=model,
@@ -1958,6 +1963,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
disable_hybrid_kv_cache_manager=True,
)

# Create a request and schedule it
2 changes: 2 additions & 0 deletions tests/v1/core/utils.py
@@ -46,6 +46,7 @@ def create_scheduler(
num_speculative_tokens: int | None = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
disable_hybrid_kv_cache_manager: bool = False,
) -> Scheduler | AsyncScheduler:
"""Create scheduler under test.

@@ -70,6 +71,7 @@
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
async_scheduling=async_scheduling,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
)
model_config = ModelConfig(
model=model,
1 change: 1 addition & 0 deletions tests/v1/kv_offload/test_cpu_offloading.py
@@ -27,6 +27,7 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
disable_hybrid_kv_cache_manager=True,
)

prompts = ["Hi " * 100]
6 changes: 3 additions & 3 deletions vllm/config/vllm.py
@@ -40,11 +40,14 @@
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
else:
PretrainedConfig = Any

QuantizationConfig = Any

KVCacheConfig = Any

logger = init_logger(__name__)


@@ -565,9 +568,6 @@ def __post_init__(self):
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
18 changes: 15 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -6,15 +6,18 @@
from typing import TYPE_CHECKING, cast

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
supports_hma,
)
from vllm.logger import init_logger

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig

logger = init_logger(__name__)
@@ -38,7 +41,7 @@ def loader() -> type[KVConnectorBase]:
@classmethod
def create_connector(
cls,
config: "VllmConfig",
config: VllmConfig,
role: KVConnectorRole,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
@@ -51,6 +54,15 @@ def create_connector(
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)

# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
if hma_enabled and not supports_hma(connector_cls):
raise ValueError(
f"Connector {connector_cls.__name__} does not support HMA but "
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
)

logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__,
9 changes: 8 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
@@ -3,6 +3,13 @@
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorRole,
SupportsHMA,
supports_hma,
)

__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
__all__ = [
"KVConnectorRole",
"KVConnectorBase_V1",
"supports_hma",
"SupportsHMA",
]
38 changes: 37 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -70,6 +70,42 @@
logger = init_logger(__name__)


class SupportsHMA:
"""
The class that indicates the corresponding connector supports hybrid memory
allocator (HMA).
This is required to use the connector together with hybrid memory allocator.
"""

def request_finished(
self,
request: "Request",
block_ids: tuple[list[int], ...],
Member:
I understand why it's tempting to do this, but I think this sort of overloading can cause unnecessary confusion - how about making this more explicit by calling the method something like request_finished_all_groups()?

Collaborator Author:
I still need to think about this for a while, but it might be a good idea.

Collaborator Author:
@heheda12345 @NickLucche I have no preference personally. Do you have a preference on adding a new function request_finished_all_groups when passing block_ids as a tuple of lists of ints?

) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.

The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.

Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return False, None


def supports_hma(connector: Any) -> bool:
if isinstance(connector, type):
return issubclass(connector, SupportsHMA)
else:
return isinstance(connector, SupportsHMA)


class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
@@ -370,7 +406,7 @@ def request_finished(
Called exactly once when a request has finished, before its blocks are
freed.

The connector may assumes responsibility for freeing the the blocks
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.

Returns:
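To make the new interface concrete, here is a small standalone sketch (plain Python, not vLLM code) mirroring the SupportsHMA marker and the supports_hma() helper added above; the toy connector classes and the transfer params they return are invented for illustration.

# Standalone illustration of the SupportsHMA marker pattern added in this PR.
# The Toy* classes are stand-ins, not real vLLM connector types.
from typing import Any


class SupportsHMA:
    """Marker mixin: request_finished() gets one block-id list per KV cache group."""

    def request_finished(
        self,
        request: Any,
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        # Default: blocks may be freed immediately, no extra transfer params.
        return False, None


def supports_hma(connector: Any) -> bool:
    # Works on both classes and instances, mirroring the helper in base.py.
    if isinstance(connector, type):
        return issubclass(connector, SupportsHMA)
    return isinstance(connector, SupportsHMA)


class ToyHMAConnector(SupportsHMA):
    def request_finished(self, request, block_ids):
        # Pretend to save asynchronously: keep blocks alive until the request id
        # is returned from get_finished(), and attach illustrative params.
        return True, {"num_groups": len(block_ids)}


class ToyLegacyConnector:
    def request_finished(self, request, block_ids):
        return False, None


assert supports_hma(ToyHMAConnector) and supports_hma(ToyHMAConnector())
assert not supports_hma(ToyLegacyConnector)
# The scheduler-side dispatch in _connector_finished() then passes grouped
# block_ids to SupportsHMA connectors and block_ids[0] to legacy ones.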
14 changes: 10 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
@@ -10,6 +10,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
@@ -23,8 +24,8 @@
logger = init_logger(__name__)


class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
class LMCacheConnectorV1(SupportsHMA, KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)

@@ -157,10 +158,10 @@ def build_connector_meta(
"""
return self._lmcache_engine.build_connector_meta(scheduler_output)

def request_finished(
def request_finished( # type: ignore[override]
self,
request: "Request",
block_ids: list[int],
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
@@ -171,5 +172,10 @@ def request_finished(
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.

Note:
This method intentionally uses tuple[list[int], ...] from
SupportsHMA interface instead of list[int] from
KVConnectorBase_V1 to support hybrid memory allocation.
"""
return self._lmcache_engine.request_finished(request, block_ids)
Member:
How do we know whether the new argument format is supported by the lmcache code that's installed? I'd expect a version check or some sort of capability check here.

Member:
OK, I see LMCache/LMCache#1436 now - you need to require this version somehow? The old version will blow up if you pass it the tuple?

Collaborator Author:
LMCache is not relying on request_finished, so changing the signature is OK.

Member:
I really don't follow - will this PR work with older lmcache versions?

Collaborator Author:
Currently request_finished is a placeholder function that does nothing in LMCache, so it is OK to pass block_ids as tuple[list[int], ...] because LMCache won't process it anyway.
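The capability question above is not settled in the diff itself; purely as an illustration, the sketch below shows one way such a check could look, probing the installed request_finished signature instead of pinning a package version. The helper name and the assumption that a grouped-capable implementation annotates block_ids with tuple[...] are invented for this example and are not part of this PR or of LMCache's API.

# Hypothetical capability probe (illustration only, not part of this PR):
# report whether an installed request_finished() accepts grouped block_ids by
# inspecting its signature annotation, defaulting to the legacy list format.
import inspect
import typing


def accepts_grouped_block_ids(impl: object) -> bool:
    fn = getattr(impl, "request_finished", None)
    if not callable(fn):
        return False
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False
    block_ids = params.get("block_ids")
    if block_ids is None or block_ids.annotation is inspect.Parameter.empty:
        return False
    annotation = block_ids.annotation
    # Handle both real annotations (tuple[list[int], ...]) and string ones
    # deferred by `from __future__ import annotations`.
    return typing.get_origin(annotation) is tuple or str(annotation).startswith("tuple")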

22 changes: 14 additions & 8 deletions vllm/v1/core/sched/scheduler.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy
import itertools
import time
from collections import defaultdict
@@ -13,6 +13,7 @@
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
@@ -85,15 +86,14 @@ def __init__(
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors"
Member:
It would be good to move this assertion into the backwards-compat code:

def request_finished(
        connector: KVConnectorBase_V1,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
    if isinstance(connector, SupportsHMA):
        return connector.request_finished_all_groups(request, block_ids)
    else:  # for backwards compatibility
        assert connector.kv_cache_config.kv_cache_groups == 1
        return connector.request_finished(request, block_ids[0])

Collaborator Author:
A similar assertion is already done when initializing the connector. It is better to assert during initialization rather than in request_finished, so the user does not see the vLLM server launch successfully and only then fail the assertion during inference.

)
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"
)

connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
)

self.kv_event_publisher = EventPublisherFactory.create(
@@ -1296,8 +1296,14 @@ def _connector_finished(
if self.connector is None:
return False, None

(block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)

if not supports_hma(self.connector):
# NOTE(Kuntai): We should remove this code path after we enforce
# all connectors to support HMA.
return self.connector.request_finished(request, block_ids[0])
else:
return self.connector.request_finished(request, block_ids)
Member:
I'd prefer to see this logic in the connectors module...

def request_finished(
        connector: KVConnectorBase_V1,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
    if isinstance(connector, SupportsHMA):
        return connector.request_finished_all_groups(request, block_ids)
    else:  # for backwards compatibility
        return connector.request_finished(request, block_ids[0])

Collaborator Author:
_connector_finished is already a small wrapper function that contains fewer than 10 lines of code besides comments. Building one more wrapper on top of it may feel a bit over-abstracted.


def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
11 changes: 9 additions & 2 deletions vllm/v1/worker/gpu_worker.py
@@ -327,6 +327,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""

# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
Member:
We're using a copy of VllmConfig as a holder to send KVCacheConfig down to the connector? And VllmConfig doesn't ordinarily have a kv_cache_config member? That seems extremely brittle.

Why not just make KVCacheConfig a constructor parameter for KVConnector?

Or could we instead supply the connector with the layer ID/name to KV cache group ID mapping? Will all connectors need this mapping to support HMA?

Collaborator Author (@KuntaiDu, Oct 21, 2025):
Two reasons for not initializing with vllm_config and kv_cache_config separately:

• Having kv_cache_config as an extra constructor argument breaks backward compatibility, because an old connector may fail to initialize if we pass vllm_config and kv_cache_config as two separate arguments.
• Putting KVCacheConfig into vllm_config aligns with @heheda12345's future refactoring direction.

Member:
Backwards compat is tricky, but we can't allow ourselves to be trapped in a situation where any new data for KVConnector must be stuffed into VllmConfig.

E.g. we could add an init_kv_cache_config() method and detect whether a connector implements the method.

Collaborator Author:
It should be stuffed into vllm_config because the design goal of vllm_config is to centralize all the configuration in one place. Also, the extra init argument kv_cache_config will be merged into vllm_config soon according to @heheda12345, so for now I would still prefer inserting kv_cache_config directly into vllm_config and then removing the injection once vllm_config is done.

I understand that allowing arbitrary field injection is generally not ideal, but it aligns with the design goal of vllm_config and the refactoring direction of kv_cache_config, so I would still prefer the current way.

Member:
I'm only just seeing this now, but I agree with @markmc that this is hacky and I don't think we should have merged it.

Either include the changes to add that field to VllmConfig, or we could use introspection on the connector for backwards compatibility.
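For context only, a minimal sketch of the introspection approach mentioned in this thread is below; the init_kv_cache_config hook and the helper name are hypothetical and not an existing vLLM API.

# Hypothetical sketch (not in this PR): call an optional init_kv_cache_config()
# hook when a connector defines one, and skip it for older connectors.
def maybe_init_kv_cache_config(connector, kv_cache_config) -> None:
    hook = getattr(connector, "init_kv_cache_config", None)
    if callable(hook):
        # Newer connector: receives the KVCacheConfig explicitly.
        hook(kv_cache_config)
    # Older connector: no hook defined, so it never sees the config and keeps
    # its existing behavior.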

ensure_kv_transfer_initialized(connector_vllm_config)

if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator

@@ -779,5 +788,3 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size,
)

ensure_kv_transfer_initialized(vllm_config)
Member:
Is it possible that other connectors (out of tree, perhaps) might break by initializing earlier?

Collaborator Author:
It should be fine because both locations are still before real model execution and CUDA graph capture, so in terms of the ability to add extra GPU operations before/after attention and before/after the forward pass, the two locations are equivalent.