[Core][Hybrid allocator + kv connector 1/n] Enable hybrid allocator + KV cache connector #25712
Changes from 23 commits
**File 1 - KV connector v1 base (where `KVConnectorBase_V1` is defined):**

```diff
@@ -70,6 +70,42 @@
 logger = init_logger(__name__)


+class SupportsHMA:
+    """
+    The class that indicates the corresponding connector supports hybrid memory
+    allocator (HMA).
+    This is required to use the connector together with hybrid memory allocator.
+    """
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: tuple[list[int], ...],
```
**Member:** I understand why it's tempting to do this, but I think this sort of overloading can cause unnecessary confusion - how about making this more explicit by calling the method something like …

**Collaborator (author):** I still need to think about this for a while, but this might be a good idea.

**Collaborator (author):** @heheda12345 @NickLucche I have no preference personally. Do you guys have a preference on having a new function …
```diff
+    ) -> tuple[bool, dict[str, Any] | None]:
+        """
+        Called exactly once when a request has finished, before its blocks are
+        freed.
+
+        The connector may assumes responsibility for freeing the blocks
+        asynchronously by returning True.
+
+        Returns:
+            True if the request is being saved/sent asynchronously and blocks
+            should not be freed until the request_id is returned from
+            get_finished().
+            Optional KVTransferParams to be included in the request outputs
+            returned by the engine.
+        """
+        return False, None
+
+
+def supports_hma(connector: Any) -> bool:
+    if isinstance(connector, type):
+        return issubclass(connector, SupportsHMA)
+    else:
+        return isinstance(connector, SupportsHMA)
+
+
 class KVConnectorRole(enum.Enum):
     # Connector running in the scheduler process
     SCHEDULER = 0
@@ -370,7 +406,7 @@ def request_finished(
         Called exactly once when a request has finished, before its blocks are
         freed.

-        The connector may assumes responsibility for freeing the the blocks
+        The connector may assumes responsibility for freeing the blocks
         asynchronously by returning True.

         Returns:
```
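To make the opt-in pattern concrete, here is a self-contained sketch of the marker-class mechanism from the hunk above; the `LegacyConnector` and `HybridConnector` names are hypothetical:

```python
from typing import Any


class SupportsHMA:
    """Marker mixin: the connector accepts per-KV-cache-group block id tuples."""


def supports_hma(connector: Any) -> bool:
    # Mirrors the helper added in this PR: accept a class or an instance.
    if isinstance(connector, type):
        return issubclass(connector, SupportsHMA)
    return isinstance(connector, SupportsHMA)


class LegacyConnector:  # hypothetical connector without HMA support
    pass


class HybridConnector(SupportsHMA):  # hypothetical connector that opts in
    pass


assert not supports_hma(LegacyConnector)
assert supports_hma(HybridConnector)    # works on the class itself...
assert supports_hma(HybridConnector())  # ...and on an instance
```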
**File 2 - LMCache connector (`LMCacheConnectorV1`):**

```diff
@@ -10,6 +10,7 @@
     KVConnectorBase_V1,
     KVConnectorMetadata,
     KVConnectorRole,
+    SupportsHMA,
 )
 from vllm.logger import init_logger
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -23,8 +24,8 @@
 logger = init_logger(__name__)


-class LMCacheConnectorV1(KVConnectorBase_V1):
-    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+class LMCacheConnectorV1(SupportsHMA, KVConnectorBase_V1):
+    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)

@@ -157,10 +158,10 @@ def build_connector_meta(
         """
         return self._lmcache_engine.build_connector_meta(scheduler_output)

-    def request_finished(
+    def request_finished(  # type: ignore[override]
         self,
         request: "Request",
-        block_ids: list[int],
+        block_ids: tuple[list[int], ...],
     ) -> tuple[bool, dict[str, Any] | None]:
         """
         Called when a request has finished, before its blocks are freed.
@@ -171,5 +172,10 @@ def request_finished(
             get_finished().
             Optional KVTransferParams to be included in the request outputs
             returned by the engine.
+
+        Note:
+            This method intentionally uses tuple[list[int], ...] from
+            SupportsHMA interface instead of list[int] from
+            KVConnectorBase_V1 to support hybrid memory allocation.
         """
         return self._lmcache_engine.request_finished(request, block_ids)
```
**Member:** How do we know whether the new argument format is supported by the lmcache code that's installed? I'd expect a version check or some sort of capability check here.

**Member:** Ok, I see LMCache/LMCache#1436 now - you need to require this version somehow? The old version will blow up if you pass it the tuple?

**Collaborator (author):** LMCache is not relying on …

**Member:** I really don't follow - this PR will work with older lmcache versions?

**Collaborator (author):** Currently …
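On the version question above, one possible guard (purely a sketch, not something this PR adds) would gate the tuple-based call on the installed lmcache version. The `0.3.0` threshold is a made-up placeholder, and this assumes the package is distributed as `lmcache`; the real minimum would be whichever release ships LMCache/LMCache#1436:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# Hypothetical threshold: the real minimum would be whichever LMCache
# release first ships LMCache/LMCache#1436.
_MIN_HMA_LMCACHE = Version("0.3.0")


def lmcache_supports_hma() -> bool:
    """Best-effort check that the installed lmcache accepts tuple block ids."""
    try:
        return Version(version("lmcache")) >= _MIN_HMA_LMCACHE
    except PackageNotFoundError:
        return False
```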
**File 3 - scheduler (connector creation and `_connector_finished`):**

```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import copy
 import itertools
 import time
 from collections import defaultdict
@@ -13,6 +13,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
+    supports_hma,
 )
 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
 from vllm.logger import init_logger
@@ -85,15 +86,14 @@ def __init__(
         # KV Connector pushes/pull of remote KVs for P/D and offloading.
         self.connector = None
         if self.vllm_config.kv_transfer_config is not None:
-            assert len(self.kv_cache_config.kv_cache_groups) == 1, (
-                "Multiple KV cache groups are not currently supported "
-                "with KV connectors"
-            )
```
**Member:** It would be good to move this assertion into the backwards compat code.

**Collaborator (author):** A similar assertion is done during connector initialization. It is better to assert during initialization instead of …
```diff
             assert not self.is_encoder_decoder, (
                 "Encoder-decoder models are not currently supported with KV connectors"
             )

+            connector_vllm_config = copy.copy(self.vllm_config)
+            connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
             self.connector = KVConnectorFactory.create_connector(
-                config=self.vllm_config, role=KVConnectorRole.SCHEDULER
+                config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
             )

         self.kv_event_publisher = EventPublisherFactory.create(
@@ -1296,8 +1296,14 @@ def _connector_finished(
         if self.connector is None:
             return False, None

-        (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
-        return self.connector.request_finished(request, block_ids)
+        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
+
+        if not supports_hma(self.connector):
+            # NOTE(Kuntai): We should remove this code path after we enforce
+            # all connectors to support HMA.
+            return self.connector.request_finished(request, block_ids[0])
+        else:
+            return self.connector.request_finished(request, block_ids)

     def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
```
**Member:** I'd prefer to see this logic in the connectors module ...

**Collaborator (author):** This function …
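As a sketch of the reviewer's suggestion to keep the compat branching in the connectors module rather than the scheduler, the dispatch could live in a helper next to `supports_hma`; the `request_finished_compat` name is illustrative, not part of the PR:

```python
from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1 import supports_hma


def request_finished_compat(
    connector: Any,
    request: Any,
    block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
    """Dispatch request_finished() in whichever block-id format fits."""
    if supports_hma(connector):
        # HMA-aware connectors get one list of block ids per KV cache group.
        return connector.request_finished(request, block_ids)
    # Legacy connectors understand only a single flat list; without the
    # hybrid allocator there is exactly one KV cache group, so pass group 0.
    return connector.request_finished(request, block_ids[0])
```

The scheduler's `_connector_finished` would then reduce to a single call to `request_finished_compat(self.connector, request, block_ids)`.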
**File 4 - worker (connector initialization moved into `initialize_from_config`):**

```diff
@@ -327,6 +327,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""

+        # Init kv cache connector here, because it requires
+        # `kv_cache_config`.
+        # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
+        # because `initialize_kv_cache` will inject kv cache groups not
+        # related to kv cache connector (e.g. kv cache sharing layers).
+        connector_vllm_config = copy.copy(self.vllm_config)
+        connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
```
**Member:** We're using a copy of … Why not just make … Or could we instead supply the connector the layer ID/name to KV cache group ID mapping? Will all connectors need this mapping to support HMA?

**Collaborator (author):** Two reasons for not initializing using …

**Member:** Backwards compat is tricky, but we can't allow ourselves to be trapped in a situation where any new data … For e.g. we could add an …

**Collaborator (author):** It should be stuffed into … I understand that allowing arbitrary field injection is generally not ideal, but it aligns with the design goal of …

**Member:** I'm only just seeing this now, but I agree with @markmc that this is hacky and I don't think we should have merged it. Either include the changes to add that field to VllmConfig, or we could use introspection on the connector for backwards compatibility.
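To illustrate what the `copy.copy` pattern relies on, here is a minimal sketch with stand-in classes (hypothetical simplifications of `VllmConfig` and `KVCacheConfig`): the shallow copy creates a new top-level config object, so attaching `kv_cache_config` to it never mutates the global config, while sub-configs stay shared rather than duplicated.

```python
import copy
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class KVCacheConfig:  # stand-in for vLLM's KVCacheConfig
    kv_cache_groups: list = field(default_factory=list)


@dataclass
class VllmConfig:
    # Stand-in; the real class lacks a kv_cache_config field, which is
    # exactly the injection the reviewers object to.
    model_config: object = None
    kv_cache_config: Optional[KVCacheConfig] = None


shared_model_config = object()
global_cfg = VllmConfig(model_config=shared_model_config)

connector_cfg = copy.copy(global_cfg)            # new top-level object
connector_cfg.kv_cache_config = KVCacheConfig()  # per-connector state

assert global_cfg.kv_cache_config is None        # global config untouched
assert connector_cfg.model_config is shared_model_config  # sub-objects shared
```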
```diff
+        ensure_kv_transfer_initialized(connector_vllm_config)

         if self.vllm_config.model_config.enable_sleep_mode:
             from vllm.device_allocator.cumem import CuMemAllocator

@@ -779,5 +788,3 @@ def init_worker_distributed_environment(
         parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size,
     )
-
-    ensure_kv_transfer_initialized(vllm_config)
```
**Member:** Is it possible that other connectors (out of tree, perhaps) might break by initializing earlier?

**Collaborator (author):** It should be fine, because both locations are still before real model execution and CUDA graph capturing. In terms of the ability to add extra GPU operations before/after attention and before/after the forward pass, the two locations are equivalent.