Merged
Changes from 4 commits
35 commits
1ded8ae
Refactor: make sure the API calls are backward compatible
KuntaiDu Sep 25, 2025
42040ba
align function signature
KuntaiDu Sep 25, 2025
fbaa51a
fix mypy errors
KuntaiDu Sep 26, 2025
fae4c82
adjust the signature of block_ids
KuntaiDu Sep 26, 2025
f858a9d
merge and resolve merge conflict
KuntaiDu Oct 12, 2025
0aa2b01
allow hybrid kv cache manager + connector
KuntaiDu Oct 12, 2025
89a976c
init using ConnectorVllmConfig
KuntaiDu Oct 12, 2025
9cdd2b0
put the change of function signature inside KVConnectorHMAMixin class
KuntaiDu Oct 13, 2025
e0ac23c
remove unnecessary change of function signature
KuntaiDu Oct 13, 2025
312e065
fix merge conflict
KuntaiDu Oct 13, 2025
b29a257
copy kv cache config instead of just sending the pointer
KuntaiDu Oct 13, 2025
866c404
align the way of checking if the connector supports HMA
KuntaiDu Oct 16, 2025
367b7b7
change class name to SupportsHMA
KuntaiDu Oct 17, 2025
e35f118
avoid using ConnectorVllmConfig, use copy instead
KuntaiDu Oct 17, 2025
7e963b9
use deepcopy instead
KuntaiDu Oct 17, 2025
650d666
adjust the comments
KuntaiDu Oct 17, 2025
37a589d
adjust comments
KuntaiDu Oct 17, 2025
6abc1c2
adjust comments
KuntaiDu Oct 17, 2025
8fc7bca
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 20, 2025
27774f3
change deepcopy to shallowcopy --- shallow copy should be enough
KuntaiDu Oct 20, 2025
1d7f75f
Merge branch 'kuntai-enable-hma-connector' of https://github.com/Kunt…
KuntaiDu Oct 20, 2025
ababeec
fix CPU offloading test
KuntaiDu Oct 20, 2025
1974b5f
fix CI errors
KuntaiDu Oct 20, 2025
9198d3e
fix NIXL-connector-related CI errors
KuntaiDu Oct 23, 2025
eee8c11
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 23, 2025
c6e0bc4
fix CI errors
KuntaiDu Oct 23, 2025
919fe9b
remove hma support from LMCache for now
KuntaiDu Oct 23, 2025
0b67b76
add an extra sanity check for request_finished
KuntaiDu Oct 23, 2025
4bfdcf8
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 23, 2025
36e42a1
fix bug
KuntaiDu Oct 24, 2025
6f6347c
Merge branch 'kuntai-enable-hma-connector' of https://github.com/Kunt…
KuntaiDu Oct 24, 2025
0df4f02
fix CI bug
KuntaiDu Oct 24, 2025
5d88c0d
fix CI issues
KuntaiDu Oct 24, 2025
2fac4fb
fix CI errors
KuntaiDu Oct 24, 2025
4c724a6
Merge branch 'main' into kuntai-enable-hma-connector
KuntaiDu Oct 24, 2025
3 changes: 0 additions & 3 deletions vllm/config/__init__.py
@@ -520,9 +520,6 @@ def __post_init__(self):
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
4 changes: 2 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/__init__.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorRole)
KVConnectorBase_V1, KVConnectorRole, supports_hma)

__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "supports_hma"]
17 changes: 15 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -37,7 +37,7 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

import torch

@@ -64,6 +64,19 @@
logger = init_logger(__name__)


class SupportsHMA:
"""
Inherit from this interface if the connector supports the hybrid memory
allocator (HMA). Inheriting it is required to use the connector together
with the hybrid memory allocator.
"""
pass


def supports_hma(cls: type) -> bool:
return isinstance(cls, SupportsHMA)


class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
@@ -323,7 +336,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: Union[list[int], tuple[list[int], ...]],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
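Taken together, the new marker interface and helper can be exercised as in the following minimal, self-contained sketch (the stub connector classes are illustrative and not vLLM code). Note that the diff annotates the helper's parameter as `cls: type`, while the scheduler passes a connector instance, which is what the `isinstance` check expects.

# Minimal sketch (not vLLM code): mirrors the SupportsHMA marker interface
# and the supports_hma() helper added in this diff, using stub connectors.

class SupportsHMA:
    """Inherit from this to declare hybrid memory allocator (HMA) support."""


def supports_hma(connector: object) -> bool:
    # The scheduler calls this with a connector *instance*.
    return isinstance(connector, SupportsHMA)


class LegacyConnector:
    """A connector that has not opted into HMA."""


class HMAConnector(SupportsHMA):
    """A connector that declares HMA support by inheriting the marker."""


assert not supports_hma(LegacyConnector())
assert supports_hma(HMAConnector())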
@@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, SupportsHMA)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

@@ -20,7 +20,7 @@
logger = init_logger(__name__)


class LMCacheConnectorV1(KVConnectorBase_V1):
class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
@@ -153,7 +153,7 @@ def build_connector_meta(
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: Union[list[int], tuple[list[int], ...]],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
@@ -3,7 +3,7 @@
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

import torch

@@ -245,7 +245,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
def request_finished(
self,
request: "Request",
blocks: list[int],
blocks: Union[tuple[list[int], ...], list[int]],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
@@ -194,7 +194,7 @@ def build_connector_meta(
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: Union[tuple[list[int], ...], list[int]],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
@@ -397,7 +397,7 @@ def build_connector_meta(
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: Union[tuple[list[int], ...], list[int]],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
@@ -4,7 +4,7 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from itertools import islice
from typing import Any, Optional
from typing import Any, Optional, Union

import torch

@@ -108,7 +108,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
def request_finished(
self,
request: "Request",
block_ids: list[int],
block_ids: Union[list[int], tuple[list[int], ...]],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
@@ -344,7 +344,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
def request_finished(
self,
request: Request,
block_ids: list[int],
block_ids: Union[tuple[list[int], ...], list[int]],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
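Several connectors above now accept either the legacy single-group list or the per-group tuple in request_finished. A hypothetical normalization helper (not part of this PR) shows how an implementation could treat both shapes uniformly:

# Hypothetical helper, not in the PR: normalize the two block_ids shapes
# allowed by the Union[list[int], tuple[list[int], ...]] signature.
from typing import Union


def normalize_block_ids(
    block_ids: Union[list[int], tuple[list[int], ...]],
) -> tuple[list[int], ...]:
    # Treat a bare list as a single KV cache group.
    if isinstance(block_ids, tuple):
        return block_ids
    return (block_ids,)


assert normalize_block_ids([1, 2, 3]) == ([1, 2, 3],)
assert normalize_block_ids(([1, 2], [3, 4])) == ([1, 2], [3, 4])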
33 changes: 26 additions & 7 deletions vllm/v1/core/sched/scheduler.py
@@ -7,14 +7,16 @@
import time
from collections import defaultdict
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Optional, Union

from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
KVConnectorRole,
supports_hma)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats)
from vllm.logger import init_logger
@@ -83,14 +85,22 @@ def __init__(
# KV Connector pushes/pull of remote KVs for P/D and offloading.
self.connector = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors")
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported "
"with KV connectors")

connector_vllm_config = deepcopy(self.vllm_config)
connector_vllm_config.kv_cache_config = kv_cache_config
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER)

# Make sure that the connector supports HMA if HMA is enabled.
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
if not supports_hma(self.connector) and num_kv_cache_groups > 1:
raise NotImplementedError(
f"Connector {self.connector.__class__.__name__} does not"
f" support HMA but HMA is enabled. Please set "
f"`--disable-hybrid-kv-cache-manager`.")

self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
@@ -1231,8 +1241,17 @@ def _connector_finished(
if self.connector is None:
return False, None

(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)

block_ids = self.kv_cache_manager.get_block_ids(request.request_id)

if not supports_hma(self.connector) or num_kv_cache_groups == 1:
# NOTE(Kuntai): this code path is a hack.
# We should remove this code path after all connectors
# support hybrid memory allocator.
return self.connector.request_finished(request, block_ids[0])
else:
return self.connector.request_finished(request, block_ids)
Member

I'd prefer to see this logic in the connectors module ...

def request_finished(
    connector: KVConnectorBase_V1,
    request: "Request",
    block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
    if isinstance(connector, SupportsHMA):
        return connector.request_finished_all_groups(request, block_ids)
    else:  # for backwards compatibility
        return connector.request_finished(request, block_ids[0])

Collaborator Author

This function _connector_finished is already a small wrapper function that contains < 10 LoC besides comments. Building one more wrapper on top of it may feel a bit over-abstracted.


def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
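For reference, here is the reviewer's suggested dispatch helper expanded into a self-contained sketch with stub types. The request_finished_all_groups() method comes from that suggestion and is hypothetical; the merged change instead branches inside Scheduler._connector_finished.

# Self-contained sketch of the reviewer's suggestion; stub classes and
# request_finished_all_groups() are hypothetical, not part of the PR.
from typing import Any, Optional, Union


class SupportsHMA:
    """Marker: the connector understands per-group block ids."""


class LegacyConnector:
    def request_finished(
        self, request: object, block_ids: list[int]
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        return False, {"first_group_blocks": block_ids}


class HMAConnector(SupportsHMA):
    def request_finished_all_groups(
        self, request: object, block_ids: tuple[list[int], ...]
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        return False, {"num_groups": len(block_ids)}


def request_finished(
    connector: Union[LegacyConnector, HMAConnector],
    request: object,
    block_ids: tuple[list[int], ...],
) -> tuple[bool, Optional[dict[str, Any]]]:
    if isinstance(connector, SupportsHMA):
        return connector.request_finished_all_groups(request, block_ids)
    # Backwards compatibility: legacy connectors only see the first group.
    return connector.request_finished(request, block_ids[0])


print(request_finished(LegacyConnector(), None, ([1, 2], [3, 4])))
print(request_finished(HMAConnector(), None, ([1, 2], [3, 4])))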
12 changes: 10 additions & 2 deletions vllm/v1/worker/gpu_worker.py
@@ -5,6 +5,7 @@
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
@@ -315,6 +316,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""

# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to the kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config = deepcopy(self.vllm_config)
connector_vllm_config.kv_cache_config = kv_cache_config
ensure_kv_transfer_initialized(connector_vllm_config)

if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator

@@ -714,5 +724,3 @@ def init_worker_distributed_environment(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size)

ensure_kv_transfer_initialized(vllm_config)
Member

Is it possible that other connectors (out of tree perhaps) might break by initializing earlier?

Collaborator Author

It should be fine, because both locations are still before real model execution and CUDA graph capturing. So in terms of the ability to add extra GPU operations before/after attention and before/after the forward pass, the two locations are the same.
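A minimal sketch of the resulting worker-side ordering, using simplified stand-in types rather than the real vLLM worker API: the connector is handed a copy of the config carrying the scheduler's kv_cache_config before the KV cache is allocated (and before extra groups such as KV-sharing layers are injected).

# Simplified stand-ins, not the real vLLM classes or functions.
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class KVCacheConfigStub:
    kv_cache_groups: list[str] = field(default_factory=list)


@dataclass
class VllmConfigStub:
    kv_cache_config: KVCacheConfigStub = field(default_factory=KVCacheConfigStub)


def ensure_kv_transfer_initialized_stub(config: VllmConfigStub) -> None:
    # Stand-in for the real connector init; it reads the config here.
    print("connector sees groups:", config.kv_cache_config.kv_cache_groups)


def initialize_from_config(vllm_config: VllmConfigStub,
                           kv_cache_config: KVCacheConfigStub) -> None:
    # Give the connector its own copy of the config, with the scheduler's
    # kv_cache_config attached, before the KV cache is allocated.
    connector_vllm_config = deepcopy(vllm_config)
    connector_vllm_config.kv_cache_config = kv_cache_config
    ensure_kv_transfer_initialized_stub(connector_vllm_config)
    # The real initialize_kv_cache runs after this point and may add
    # groups that the connector should not be configured with.


initialize_from_config(VllmConfigStub(),
                       KVCacheConfigStub(["full_attention"]))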

Loading