6 changes: 6 additions & 0 deletions docs/features/disagg_prefill.md
@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
```

- **OffloadingConnector**: enables offloading of KV data to CPU memory, with a configurable CPU block size (in tokens) and number of CPU blocks to allocate per worker (see the CLI flag and the Python sketch below):

```bash
--kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
```
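
The same configuration can also be passed through the Python API; the following sketch mirrors the offline-inference test added in this PR (the model name is only an illustration):

```python
# Sketch: configuring the OffloadingConnector via the Python API.
# Mirrors the test added in this PR; the model name is illustrative.
from vllm import LLM
from vllm.config import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="OffloadingConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"block_size": 64, "num_cpu_blocks": 1000},
)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    kv_transfer_config=kv_transfer_config,
)
```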

## Benchmarks

Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.
62 changes: 62 additions & 0 deletions tests/v1/kv_offload/test_cpu_offloading.py
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import pytest

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

CPU_BLOCK_SIZES = [16, 48]


@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
def test_cpu_offloading(cpu_block_size: int) -> None:
**Member:**
Is there a way to add some assertions to the test such that it will fail if the offload is not working? Should probably also verify correctness in conjunction with this.

**@orozery (Collaborator, Author), Sep 21, 2025:**
There is a correctness unit test for the transfer function (test_cpu_gpu.py).
There is also a correctness unit test checking that the offloading connector generates the correct transfer addresses for the GPU and the offloaded medium.
I don't know how we can test correctness end-to-end.

Currently the test here just checks that prompt generation does not crash when using CPU offloading.
It does not verify that any offloading actually occurs.

One way we can verify this is by adding a kv_events_config (like in test_kv_cache_events) that checks for KVEvents with the CPU medium.
I actually started coding that but found it a bit cumbersome, so I decided to defer it and see if others think it's worthwhile.

Another option is to verify that latency decreases when we're supposed to hit the CPU cache (after resetting the GPU prefix cache).
We can reduce the variance by, say, repeating this 100 times and verifying that the latency decreased in at least 70 of them.
This would actually be easy to implement (compared to the KVEvents test).
My concern is that even when repeating the test multiple (e.g. 100) times, it could still be flaky.
Your thoughts?
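
A minimal sketch of that repeated-latency check, assuming the LLM/SamplingParams API used in the test below; the helper names, repeat count, and win threshold are illustrative and not part of the PR:

```python
# Sketch of the repeated latency comparison suggested above.
# Assumes the LLM API used in the test below; the repeat count
# and win threshold are illustrative.
import time


def timed_generate(llm, prompts, sampling_params) -> float:
    start = time.time()
    llm.generate(prompts, sampling_params, use_tqdm=False)
    return time.time() - start


def cpu_hit_usually_faster(llm, sampling_params,
                           repeats: int = 100, min_wins: int = 70) -> bool:
    wins = 0
    for i in range(repeats):
        # A fresh prompt per iteration keeps the first run cold.
        prompts = [f"Request {i}: " + "Hi " * 100]
        cold = timed_generate(llm, prompts, sampling_params)
        llm.reset_prefix_cache()  # drop the GPU prefix cache
        time.sleep(1)             # give the CPU store time to finish
        cpu_hit = timed_generate(llm, prompts, sampling_params)
        wins += cpu_hit < cold
    return wins >= min_wins
```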

**Member:**
Thanks @orozery, yes, I was thinking of perhaps at least some kind of latency comparison, but I agree timing tests are fragile and generally not a good idea. If the magnitude of the difference is large enough, perhaps it wouldn't need so many attempts, maybe just a handful?

**Member:**
Merging this to ensure that it makes the release, but we might want to think a bit more about the e2e CI tests.

Thanks again for all of your hard work @orozery!

"""
Tests OffloadingConnector with CPUOffloadingSpec.
"""

# configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"num_cpu_blocks": 100,
"block_size": cpu_block_size
},
)

llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
)

prompts = ["Hi " * 100]
sampling_params = SamplingParams(temperature=0, max_tokens=20)

# run generation - this should trigger saving KV cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cold_time = time.time() - start_time

# run generation again - should hit the GPU prefix cache
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
gpu_hit_time = time.time() - start_time

# reset prefix cache to avoid GPU hit.
llm.reset_prefix_cache()

# sleep for a sec to make sure CPU finished storing
time.sleep(1)

# run generation again - this should trigger loading from CPU
start_time = time.time()
llm.generate(prompts, sampling_params, use_tqdm=False)
cpu_hit_time = time.time() - start_time

print("Generation times:")
print(f" Cold: {cold_time * 1000:.2f}ms")
print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms")
print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms")
75 changes: 75 additions & 0 deletions vllm/v1/kv_offload/cpu.py
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from typing import Optional

import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
from vllm.v1.kv_offload.worker.worker import OffloadingHandler


class CPUOffloadingSpec(OffloadingSpec):

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
        if not num_cpu_blocks:
            raise Exception("num_cpu_blocks must be specified "
                            "in kv_connector_extra_config")
        self.num_cpu_blocks: int = num_cpu_blocks

        # scheduler-side
        self._manager: Optional[OffloadingManager] = None

        # worker-side
        self._handler: Optional[OffloadingHandler] = None

    def get_manager(self) -> OffloadingManager:
        if not self._manager:
            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (kv_events_config is not None
                             and kv_events_config.enable_kv_cache_events)
            self._manager = LRUOffloadingManager(CPUBackend(
                block_size=self.offloaded_block_size,
                num_blocks=self.num_cpu_blocks),
                                                 enable_events=enable_events)
        return self._manager

    def get_handlers(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:
        if not self._handler:
            if not current_platform.is_cuda():
                raise Exception("CPU Offloading is currently only supported"
                                " on CUDA GPUs")

            layer_names = list(kv_caches.keys())
            layers = get_layers_from_vllm_config(self.vllm_config,
                                                 AttentionLayerBase,
                                                 layer_names)
            attn_backends = {
                layer_name: layers[layer_name].get_attn_backend()
                for layer_name in layer_names
            }

            self._handler = CpuGpuOffloadingHandler(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
                cpu_block_size=self.offloaded_block_size,
                num_cpu_blocks=self.num_cpu_blocks,
                gpu_caches=kv_caches)

        assert self._handler is not None
        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler
        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler
3 changes: 3 additions & 0 deletions vllm/v1/kv_offload/factory.py
@@ -51,3 +51,6 @@ def create_spec(


# Register various specs here.
OffloadingSpecFactory.register_spec("CPUOffloadingSpec",
"vllm.v1.kv_offload.cpu",
"CPUOffloadingSpec")