
Commit 7ea3f64

[DCP] Support dcp kv_cache interleave size > 1
Signed-off-by: zhangsicheng5 <[email protected]>
1 parent 782505e commit 7ea3f64

9 files changed, +109 -11 lines changed

tests/distributed/test_context_parallel.py
Lines changed: 7 additions & 0 deletions

@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
+    cp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
+        cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
                 tp_size=tp_base,
                 pp_size=pp_multiplier * pp_base,
                 dcp_size=int(dcp_multiplier * tp_base),
+                cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
                 eager_mode=eager_mode_val,
                 chunked_prefill=chunked_prefill_val,
             )
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
+        cp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
         str(pp_size),
         "--decode-context-parallel-size",
         str(dcp_size),
+        "--cp-kv-cache-interleave-size",
+        str(cp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]
@@ -208,6 +214,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
+        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
     ],
 }

vllm/config/parallel.py
Lines changed: 11 additions & 0 deletions

@@ -204,6 +204,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or cp > 1:
+    store interleave_size tokens on (d)cp rank i, then store the next
+    interleave_size tokens on (d)cp rank i+1.
+    interleave_size=1: token-level alignment; token i is stored on
+    rank i % (d)cp_size.
+    interleave_size=block_size: block-level alignment; block j of rank i+1
+    only starts to fill once block j of rank i is full.
+    block_size must be >= and divisible by cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
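
Illustration (not part of the commit): a minimal sketch of which (d)cp rank
owns a given token position under the interleave scheme described in the
docstring above.

    def rank_of_token(pos: int, dcp_world_size: int, interleave_size: int) -> int:
        # Tokens are dealt out in chunks of interleave_size, round-robin
        # across ranks: chunk 0 -> rank 0, chunk 1 -> rank 1, ..., wrapping.
        return (pos // interleave_size) % dcp_world_size

    # interleave_size=1 is token-level alignment: rank = pos % dcp_world_size.
    assert [rank_of_token(p, 4, 1) for p in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
    # interleave_size=4: each rank takes 4 consecutive tokens before the next.
    assert [rank_of_token(p, 2, 4) for p in range(10)] == [0, 0, 0, 0, 1, 1, 1, 1, 0, 0]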

vllm/config/vllm.py
Lines changed: 17 additions & 0 deletions

@@ -472,6 +472,23 @@ def __post_init__(self):
             )
         current_platform.check_and_update_config(self)

+        assert (
+            self.parallel_config.cp_kv_cache_interleave_size
+            <= self.cache_config.block_size
+            and self.cache_config.block_size
+            % self.parallel_config.cp_kv_cache_interleave_size
+            == 0
+        ), (
+            f"Block_size({self.cache_config.block_size}) should be "
+            "greater than or equal to and divisible by cp_kv_cache_interleave_size "
+            f"({self.parallel_config.cp_kv_cache_interleave_size})."
+        )
+
+        assert (
+            self.parallel_config.cp_kv_cache_interleave_size == 1
+            or self.speculative_config is None
+        ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
+
         # Do this after all the updates to compilation_config.level
         if (
             envs.VLLM_USE_V1
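
Illustration (not part of the commit): the first assertion reduces to a simple
predicate; a hedged standalone rendering with a hypothetical helper name.

    def interleave_size_ok(block_size: int, interleave_size: int) -> bool:
        # Each kv-cache block must hold a whole number of interleave chunks:
        # the chunk may not exceed the block and must tile it evenly.
        return interleave_size <= block_size and block_size % interleave_size == 0

    assert interleave_size_ok(16, 16) and interleave_size_ok(16, 4)
    assert not interleave_size_ok(16, 32) and not interleave_size_ok(16, 3)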

vllm/engine/arg_utils.py
Lines changed: 6 additions & 0 deletions

@@ -362,6 +362,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: int | None = None
     data_parallel_start_rank: int | None = None
@@ -715,6 +716,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "-dcp",
             **parallel_kwargs["decode_context_parallel_size"],
         )
+        parallel_group.add_argument(
+            "--cp-kv-cache-interleave-size",
+            **parallel_kwargs["cp_kv_cache_interleave_size"],
+        )
         parallel_group.add_argument(
             "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
         )
@@ -1470,6 +1475,7 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
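
Illustration (not part of the commit): besides the CLI flag, the new field can
be set programmatically; a minimal sketch assuming the usual EngineArgs path,
with placeholder model and sizes.

    from vllm.engine.arg_utils import EngineArgs

    # Defaults to ParallelConfig.cp_kv_cache_interleave_size (1); the value
    # flows into ParallelConfig via create_engine_config().
    args = EngineArgs(
        model="deepseek-ai/DeepSeek-V2-Lite-Chat",
        tensor_parallel_size=2,
        decode_context_parallel_size=2,
        cp_kv_cache_interleave_size=64,
    )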

vllm/v1/attention/backends/mla/common.py
Lines changed: 0 additions & 9 deletions

@@ -749,15 +749,6 @@ def build(
             )
         )

-        # Note(hc): update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            assert dcp_local_seq_lens is not None
-            dcp_local_seq_lens[:num_decodes] = seq_lens[
-                :num_decodes
-            ] // self.dcp_world_size + (
-                self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
-            )
-
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens

(This inline per-rank seq_len computation is superseded by the interleave-aware
get_dcp_local_seq_lens helper, now called from gpu_model_runner._prepare_inputs;
see the diffs below.)

vllm/v1/attention/backends/utils.py
Lines changed: 32 additions & 0 deletions

@@ -991,3 +991,35 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
         nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

     return nums_dict, batch_ptr, token_chunk_offset_ptr
+
+
+def get_dcp_local_seq_lens(
+    seq_lens: torch.Tensor,
+    dcp_world_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1,
+) -> torch.Tensor:
+    """When using dcp, the kv_cache stored on each rank may differ in size;
+    use this function to compute the per-rank split of decode seq_lens.
+    Only dcp is considered for now; the cp case can be extended from this.
+    """
+    num_requests = seq_lens.size(0)
+    seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, dcp_world_size)
+    rank_offsets = (
+        torch.arange(dcp_world_size, dtype=torch.int32)
+        .unsqueeze(0)
+        .repeat(num_requests, 1)
+    )
+    base = (
+        seq_lens_tiled
+        // cp_kv_cache_interleave_size
+        // dcp_world_size
+        * cp_kv_cache_interleave_size
+    )
+    remainder = seq_lens_tiled - base * dcp_world_size
+    remainder = torch.clip(
+        remainder - rank_offsets * cp_kv_cache_interleave_size,
+        0,
+        cp_kv_cache_interleave_size,
+    )
+    dcp_local_seq_lens = base + remainder
+    return dcp_local_seq_lens
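
Illustration (not part of the commit): a worked example of the helper with
hypothetical sizes.

    import torch

    from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

    # Two requests of 10 and 8 tokens, dcp_world_size=2, interleave chunks of 4.
    seq_lens = torch.tensor([10, 8], dtype=torch.int32)
    local = get_dcp_local_seq_lens(
        seq_lens, dcp_world_size=2, cp_kv_cache_interleave_size=4
    )
    # Request 0: rank 0 holds chunks [0:4] and [8:10] -> 6 tokens,
    #            rank 1 holds chunk [4:8]             -> 4 tokens.
    # Request 1: one full chunk per rank              -> 4 tokens each.
    print(local)  # -> [[6, 4], [4, 4]]

Each row sums back to the request's seq_len; column i is the number of that
request's tokens resident in rank i's kv_cache.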

vllm/v1/worker/block_table.py
Lines changed: 16 additions & 2 deletions

@@ -22,6 +22,7 @@ def __init__(
         pin_memory: bool,
         device: torch.device,
         kernel_block_size: int,
+        cp_kv_cache_interleave_size: int,
     ):
         """
         Args:
@@ -86,6 +87,7 @@ def __init__(
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

     def append_row(
         self,
@@ -144,9 +146,19 @@ def compute_slot_mapping(
             # Use virtual_block_size for mask calculation, which marks local
             # tokens.
             virtual_block_offsets = positions % virtual_block_size
-            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            mask = (
+                virtual_block_offsets
+                // self.cp_kv_cache_interleave_size
+                % self.dcp_world_size
+                == self.dcp_rank
+            )
             # Calculate local block_offsets
-            block_offsets = virtual_block_offsets // self.dcp_world_size
+            block_offsets = (
+                virtual_block_offsets
+                // (self.dcp_world_size * self.cp_kv_cache_interleave_size)
+                * self.cp_kv_cache_interleave_size
+                + virtual_block_offsets % self.cp_kv_cache_interleave_size
+            )
             # Calculate slot_mapping
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Write final slots, use -1 for not-local
@@ -234,6 +246,7 @@ def __init__(
         block_sizes: list[int],
         kernel_block_sizes: list[int],
         num_speculative_tokens: int = 0,
+        cp_kv_cache_interleave_size: int = 1,
     ) -> None:
         # Note(hc): each dcp rank only store
         # (max_model_len//dcp_world_size) tokens in kvcache,
@@ -263,6 +276,7 @@ def __init__(
                 pin_memory,
                 device,
                 kernel_block_size,
+                cp_kv_cache_interleave_size,
             )
             for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
         ]
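
Illustration (not part of the commit): tracing the new mask and offset
arithmetic with hypothetical values (virtual_block_size = block_size *
dcp_world_size, as in the surrounding code).

    block_size, dcp_world_size, interleave = 16, 2, 4
    virtual_block_size = block_size * dcp_world_size  # 32 positions

    for pos in (0, 5, 21):
        vbo = pos % virtual_block_size
        # Owner rank: which interleave chunk the position falls in, mod ranks.
        owner = vbo // interleave % dcp_world_size
        # Local offset: each completed round-robin cycle contributes
        # interleave local slots, plus the position within the current chunk.
        local = vbo // (dcp_world_size * interleave) * interleave + vbo % interleave
        print(pos, owner, local)
    # pos 0  -> rank 0, local offset 0
    # pos 5  -> rank 1, local offset 1  (2nd token of chunk [4:8])
    # pos 21 -> rank 1, local offset 9  (chunk [20:24] is rank 1's 3rd chunk)

With interleave = 1 this reduces to the previous behavior:
owner = vbo % dcp_world_size and local = vbo // dcp_world_size.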

vllm/v1/worker/gpu_input_batch.py
Lines changed: 2 additions & 0 deletions

@@ -83,6 +83,7 @@ def __init__(
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
+        cp_kv_cache_interleave_size: int = 1,
     ):
         self.is_pooling_model = is_pooling_model
         self.is_spec_decode = is_spec_decode
@@ -135,6 +136,7 @@ def __init__(
             block_sizes=block_sizes,
             kernel_block_sizes=kernel_block_sizes,
             num_speculative_tokens=num_speculative_tokens,
+            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
         )

         # Sampling-related.

vllm/v1/worker/gpu_model_runner.py
Lines changed: 18 additions & 0 deletions

@@ -35,6 +35,7 @@
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_pp_group,
     get_tp_group,
     graph_capture,
@@ -92,6 +93,7 @@
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
+    get_dcp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )
@@ -256,6 +258,11 @@ def __init__(
         self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
+        try:
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_rank = 0
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

@@ -372,6 +379,8 @@ def __init__(
             # uses output token ids so we set this conservatively.
             logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
             is_pooling_model=self.is_pooling_model,
+            cp_kv_cache_interleave_size=
+            self.parallel_config.cp_kv_cache_interleave_size,
         )

         self.use_async_scheduling = self.scheduler_config.async_scheduling
@@ -1276,6 +1285,15 @@ def _prepare_inputs(
                 logits_indices
             )

+        # update seq_lens of decode reqs under DCP.
+        if self.dcp_world_size > 1:
+            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
+                self.seq_lens.cpu[:num_reqs],
+                self.dcp_world_size,
+                self.parallel_config.cp_kv_cache_interleave_size,
+            )[:, self.dcp_rank]
+            self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
+
         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
             attn_metadata = [dict() for _ in range(len(ubatch_slices))]
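
Illustration (not part of the commit): get_dcp_local_seq_lens returns a
(num_reqs, dcp_world_size) matrix, and each rank keeps only its own column,
mirroring the [:, self.dcp_rank] indexing above (hypothetical values).

    import torch

    from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

    seq_lens = torch.tensor([10, 8], dtype=torch.int32)
    per_rank = get_dcp_local_seq_lens(seq_lens, 2, 4)  # -> [[6, 4], [4, 4]]
    dcp_rank = 1                                       # rank within the DCP group
    local_seq_lens = per_rank[:, dcp_rank]             # -> [4, 4]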
