-
-
Notifications
You must be signed in to change notification settings - Fork 13k
[KV offload][5/N] Add CPUOffloadingSpec
#24251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import time | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm import LLM, SamplingParams | ||
| from vllm.config import KVTransferConfig | ||
|
|
||
| CPU_BLOCK_SIZES = [16, 48] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) | ||
| def test_cpu_offloading(cpu_block_size: int) -> None: | ||
| """ | ||
| Tests OffloadingConnector with CPUOffloadingSpec. | ||
| """ | ||
|
|
||
| # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default) | ||
| kv_transfer_config = KVTransferConfig( | ||
| kv_connector="OffloadingConnector", | ||
| kv_role="kv_both", | ||
| kv_connector_extra_config={ | ||
| "num_cpu_blocks": 100, | ||
| "block_size": cpu_block_size | ||
| }, | ||
| ) | ||
|
|
||
| llm = LLM( | ||
| model="meta-llama/Llama-3.2-1B-Instruct", | ||
| gpu_memory_utilization=0.5, | ||
| kv_transfer_config=kv_transfer_config, | ||
| ) | ||
|
|
||
| prompts = ["Hi " * 100] | ||
| sampling_params = SamplingParams(temperature=0, max_tokens=20) | ||
|
|
||
| # run generation - this should trigger saving KV cache | ||
| start_time = time.time() | ||
| llm.generate(prompts, sampling_params, use_tqdm=False) | ||
| cold_time = time.time() - start_time | ||
|
|
||
| # run generation again - should hit the GPU prefix cache | ||
| start_time = time.time() | ||
| llm.generate(prompts, sampling_params, use_tqdm=False) | ||
| gpu_hit_time = time.time() - start_time | ||
|
|
||
| # reset prefix cache to avoid GPU hit. | ||
| llm.reset_prefix_cache() | ||
|
|
||
| # sleep for a sec to make sure CPU finished storing | ||
| time.sleep(1) | ||
|
|
||
| # run generation again - this should trigger loading from CPU | ||
| start_time = time.time() | ||
| llm.generate(prompts, sampling_params, use_tqdm=False) | ||
| cpu_hit_time = time.time() - start_time | ||
|
|
||
| print("Generation times:") | ||
| print(f" Cold: {cold_time * 1000:.2f}ms") | ||
| print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms") | ||
| print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from collections.abc import Iterator | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.config import VllmConfig, get_layers_from_vllm_config | ||
| from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase | ||
| from vllm.platforms import current_platform | ||
| from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager | ||
| from vllm.v1.kv_offload.backends.cpu import CPUBackend | ||
| from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager | ||
| from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec | ||
| from vllm.v1.kv_offload.spec import OffloadingSpec | ||
| from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler | ||
| from vllm.v1.kv_offload.worker.worker import OffloadingHandler | ||
|
|
||
|
|
||
| class CPUOffloadingSpec(OffloadingSpec): | ||
|
|
||
| def __init__(self, vllm_config: VllmConfig): | ||
| super().__init__(vllm_config) | ||
|
|
||
| num_cpu_blocks = self.extra_config.get("num_cpu_blocks") | ||
| if not num_cpu_blocks: | ||
| raise Exception("num_cpu_blocks must be specified " | ||
| "in kv_connector_extra_config") | ||
| self.num_cpu_blocks: int = num_cpu_blocks | ||
|
|
||
| # scheduler-side | ||
| self._manager: Optional[OffloadingManager] = None | ||
|
|
||
| # worker-side | ||
| self._handler: Optional[OffloadingHandler] = None | ||
|
|
||
| def get_manager(self) -> OffloadingManager: | ||
| if not self._manager: | ||
| kv_events_config = self.vllm_config.kv_events_config | ||
| enable_events = (kv_events_config is not None | ||
| and kv_events_config.enable_kv_cache_events) | ||
| self._manager = LRUOffloadingManager(CPUBackend( | ||
| block_size=self.offloaded_block_size, | ||
| num_blocks=self.num_cpu_blocks), | ||
| enable_events=enable_events) | ||
| return self._manager | ||
|
|
||
| def get_handlers( | ||
| self, kv_caches: dict[str, torch.Tensor] | ||
| ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], | ||
| OffloadingHandler]]: | ||
| if not self._handler: | ||
| if not current_platform.is_cuda(): | ||
| raise Exception("CPU Offloading is currently only supported" | ||
| " on CUDA GPUs") | ||
|
|
||
| layer_names = list(kv_caches.keys()) | ||
| layers = get_layers_from_vllm_config(self.vllm_config, | ||
| AttentionLayerBase, | ||
| layer_names) | ||
| attn_backends = { | ||
| layer_name: layers[layer_name].get_attn_backend() | ||
| for layer_name in layer_names | ||
| } | ||
|
|
||
| self._handler = CpuGpuOffloadingHandler( | ||
| attn_backends=attn_backends, | ||
| gpu_block_size=self.gpu_block_size, | ||
| cpu_block_size=self.offloaded_block_size, | ||
| num_cpu_blocks=self.num_cpu_blocks, | ||
| gpu_caches=kv_caches) | ||
|
|
||
| assert self._handler is not None | ||
| yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler | ||
| yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to add some assertions to the test such that it will fail if the offload is not working? Should probably also verify correctness in conjunction with this.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a correctness unit test for the transfer function (
test_cpu_gpu.py).Also there is a correctness unit test that the offloading connector generates the correct transfer addresses of the GPU and the offloaded medium.
I don't know how to we can test correctness e2e.
Currently the test here just checks that prompt generation does not crash when using cpu offloading.
It does not verify that any offloading actually occurs.
One way we can verify this is by adding a
kv_events_config(like intest_kv_cache_events) that will check for KVEvents with the CPU medium.I actually started coding that but saw that it is a bit cumbersome, so I decided to defer this to see if others think it's worthwhile.
Another option is to verify latency decreases when we're supposed to hit the cpu cache (after resetting the GPU prefix cache).
We can decrease the variance by, say, repeat this 100 times and verify that at least 70 times the latency decreased.
This will actually be easy to implement (comparing to the KVEvents test).
My concern is that even when repeating the test multiple (e.g. 100) times, it can still be flakey.
Your thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @orozery yes I was thinking perhaps at least some kind of latency comparison but I agree timing tests are fragile / generally not a good idea. If the magnitude of the difference is large enough perhaps it wouldn't need so many attempts, maybe just a handful?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Merging this to ensure that it makes the release but we might want to think a bit more about the e2e CI tests.
Thanks again for all of your hard work @orozery!