[Core][Hybrid allocator + kv connector 1/n] Enable hybrid allocator + KV cache connector #25712
`vllm/distributed/kv_transfer/kv_connector/v1/__init__.py` (path inferred from the re-export):

```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorRole)
+    KVConnectorBase_V1, KVConnectorRole, supports_hma)

-__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]
+__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "supports_hma"]
```
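The hunk above only re-exports `supports_hma`; its definition in `base.py` is not part of this diff. As a rough sketch of what such a capability probe can look like (hypothetical — the real helper in `kv_connector/v1/base.py` may use a different mechanism), one common pattern is a marker mixin plus an `isinstance` check:

```python
# Hypothetical sketch of a `supports_hma`-style capability probe.
# The actual definition lives in kv_connector/v1/base.py and is not
# shown in this diff; the classes below are illustrative stubs.


class KVConnectorBase_V1:
    """Stub standing in for the real v1 connector base class."""


class SupportsHMA:
    """Marker mixin: connectors that can consume per-KV-cache-group
    block IDs from the hybrid memory allocator inherit from this."""


def supports_hma(connector: KVConnectorBase_V1) -> bool:
    """Return True iff the connector opted into HMA support."""
    return isinstance(connector, SupportsHMA)


class LegacyConnector(KVConnectorBase_V1):
    pass


class HybridAwareConnector(KVConnectorBase_V1, SupportsHMA):
    pass


assert not supports_hma(LegacyConnector())
assert supports_hma(HybridAwareConnector())
```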
`vllm/v1/core/sched/scheduler.py` (path inferred from the scheduler context):

```diff
@@ -7,14 +7,16 @@
 import time
 from collections import defaultdict
 from collections.abc import Iterable
+from copy import deepcopy
 from typing import Any, Optional, Union

 from vllm.config import VllmConfig
 from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
-                                                          KVConnectorRole)
+                                                          KVConnectorRole,
+                                                          supports_hma)
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
     KVConnectorStats)
 from vllm.logger import init_logger
@@ -83,14 +85,22 @@ def __init__(
         # KV Connector pushes/pull of remote KVs for P/D and offloading.
         self.connector = None
         if self.vllm_config.kv_transfer_config is not None:
-            assert len(self.kv_cache_config.kv_cache_groups) == 1, (
-                "Multiple KV cache groups are not currently supported "
-                "with KV connectors")
             assert not self.is_encoder_decoder, (
                 "Encoder-decoder models are not currently supported "
                 "with KV connectors")
+
+            connector_vllm_config = deepcopy(self.vllm_config)
+            connector_vllm_config.kv_cache_config = kv_cache_config
             self.connector = KVConnectorFactory.create_connector(
-                config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
+                config=connector_vllm_config, role=KVConnectorRole.SCHEDULER)
+
+            # Make sure that the connector supports HMA if HMA is enabled.
+            num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
+            if not supports_hma(self.connector) and num_kv_cache_groups > 1:
+                raise NotImplementedError(
+                    f"Connector {self.connector.__class__.__name__} does not"
+                    f" support HMA but HMA is enabled. Please set "
+                    f"`--disable-hybrid-kv-cache-manager`.")

         self.kv_event_publisher = EventPublisherFactory.create(
             self.kv_events_config,
```
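One detail worth calling out in the hunk above: the scheduler deep-copies `vllm_config` before attaching `kv_cache_config`, since `vllm_config` is shared engine-wide and mutating it in place would leak the connector-specific KV cache config to every other consumer. A minimal illustration of the hazard the copy avoids, using a hypothetical stand-in for `VllmConfig`:

```python
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeVllmConfig:
    """Hypothetical stand-in for vllm.config.VllmConfig."""
    kv_cache_config: Optional[Any] = None


shared_config = FakeVllmConfig()

# Copy first, then attach the connector-specific KV cache config:
connector_config = deepcopy(shared_config)
connector_config.kv_cache_config = {"num_groups": 2}  # illustrative value

# The engine-wide config is left untouched.
assert shared_config.kv_cache_config is None
```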
@@ -1231,8 +1241,17 @@ def _connector_finished( | |
| if self.connector is None: | ||
| return False, None | ||
|
|
||
| (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) | ||
| return self.connector.request_finished(request, block_ids) | ||
| num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups) | ||
|
|
||
| block_ids = self.kv_cache_manager.get_block_ids(request.request_id) | ||
|
|
||
| if not supports_hma(self.connector) or num_kv_cache_groups == 1: | ||
KuntaiDu marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # NOTE(Kuntai): this code path is a hack. | ||
KuntaiDu marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # We should remove this code path after all connectors | ||
| # support hybrid memory allocator. | ||
| return self.connector.request_finished(request, block_ids[0]) | ||
| else: | ||
| return self.connector.request_finished(request, block_ids) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer to see this logic in the connectors module ...
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function |
||
|
|
||
| def _update_waiting_for_remote_kv(self, request: Request) -> bool: | ||
| """ | ||
|
|
||
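The branch above exists because `kv_cache_manager.get_block_ids` returns one list of block IDs per KV cache group, while connectors that predate HMA expect a single flat list. With made-up block IDs, the two payloads differ like this:

```python
# Illustrative (made-up) values showing what each code path above
# passes to connector.request_finished().

# One list of block IDs per KV cache group, e.g. full attention +
# sliding-window layers under the hybrid memory allocator:
block_ids = ([3, 7, 9], [12, 15])

# Legacy path (connector without HMA support, or a single group):
legacy_payload = block_ids[0]   # [3, 7, 9] -- first group's blocks only

# HMA-aware path: the connector sees every group's blocks:
hma_payload = block_ids         # ([3, 7, 9], [12, 15])
```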
`vllm/v1/worker/gpu_worker.py` (path inferred from the worker context):

```diff
@@ -5,6 +5,7 @@
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
@@ -315,6 +316,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""

+        # Init kv cache connector here, because it requires
+        # `kv_cache_config`.
+        # NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
+        # because `initialize_kv_cache` will inject kv cache groups not
+        # related to the kv cache connector (e.g. kv cache sharing layers).
+        connector_vllm_config = deepcopy(self.vllm_config)
+        connector_vllm_config.kv_cache_config = kv_cache_config
+        ensure_kv_transfer_initialized(connector_vllm_config)
+
         if self.vllm_config.model_config.enable_sleep_mode:
             from vllm.device_allocator.cumem import CuMemAllocator

@@ -714,5 +724,3 @@ def init_worker_distributed_environment(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size)
-
-    ensure_kv_transfer_initialized(vllm_config)
```
Member: Is it possible that other connectors (out of tree perhaps) might break by initializing earlier?

Collaborator (Author): It should be fine, because both locations are still before real model execution and CUDA graph capturing. So in terms of the ability to add extra GPU operations before/after attention and before/after forwarding, these two locations are the same.
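To make the author's ordering argument concrete: connector initialization moves from `init_worker_distributed_environment` into `initialize_from_config`, which still runs before any forward pass or CUDA graph capture. A simplified sketch of the worker startup sequence (method names mirror vLLM's worker, but the bodies are stubs and the real call chain has more steps):

```python
# Simplified, stubbed sketch of the worker startup ordering discussed
# above; the real vLLM worker has more steps between these calls.


class WorkerSketch:
    def init_device(self):
        # Old connector init point; after this change, only the
        # distributed environment is set up here.
        print("1. distributed environment initialized")

    def load_model(self):
        print("2. model weights loaded")

    def initialize_from_config(self, kv_cache_config):
        # New connector init point: kv_cache_config is available here,
        # before initialize_kv_cache injects extra KV cache groups
        # (e.g. for KV cache sharing layers).
        print(f"3. connector initialized with {kv_cache_config!r}")

    def compile_or_warm_up_model(self):
        print("4. CUDA graph capture / warm-up -- still after connector init")


w = WorkerSketch()
w.init_device()
w.load_model()
w.initialize_from_config({"num_groups": 2})  # illustrative config
w.compile_or_warm_up_model()
```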