
Commit 58cbd8f

pisceskkk, FENP, LookAround0301, Jingchun Gao, and zhenwenqi2024 committed
[PCP] common supports for PCP
Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>
1 parent 30873d6 commit 58cbd8f

File tree: 7 files changed (+455, -354 lines)


tests/distributed/test_context_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ def _compare_cp_with_tp(
         str(pp_size),
         "--decode-context-parallel-size",
         str(dcp_size),
-        "--dcp-kv-cache-interleave-size",
+        "--cp-kv-cache-interleave-size",
         str(cp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,

vllm/config/parallel.py

Lines changed: 0 additions & 5 deletions
@@ -325,11 +325,6 @@ def _validate_parallel_config(self) -> Self:
                 "num_redundant_experts."
             )

-        if self.prefill_context_parallel_size > 1:
-            raise ValueError(
-                "Prefill context parallelism is not fully supported. "
-                "Please set prefill_context_parallel_size to 1."
-            )
         return self

     @property
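
With this guard gone, setting prefill_context_parallel_size above 1 is no longer rejected at validation time, which is what the rest of the commit builds on. A minimal sketch of the behavior change (assuming ParallelConfig can be constructed with defaults, as vLLM config dataclasses typically can; the field name is taken from the diff):

    from vllm.config.parallel import ParallelConfig

    # Before this commit, validation raised "Prefill context parallelism
    # is not fully supported. ..." for any value > 1; now it is accepted.
    cfg = ParallelConfig(prefill_context_parallel_size=2)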

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    get_dcp_local_seq_lens,
+    get_cp_local_seq_lens,
     get_kv_cache_layout,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -384,7 +384,7 @@ def schedule(
             )
             dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

-            dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
+            dcp_context_kv_lens_cpu = get_cp_local_seq_lens(
                 dcp_context_kv_lens_cpu,
                 self.dcp_world_size,
                 self.dcp_rank,
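
For context, this hunk first computes how much KV is already cached per request (full sequence length minus the tokens scheduled this step) and then narrows that to the local rank's share via the renamed helper. A toy illustration of the first step with made-up lengths (the per-rank split itself is the get_cp_local_seq_lens helper shown under vllm/v1/attention/backends/utils.py below):

    import torch

    seq_lens_cpu = torch.tensor([100, 37])     # total tokens per request
    query_kv_lens_cpu = torch.tensor([4, 37])  # tokens being computed now
    # KV already sitting in the cache, spread across the dcp ranks:
    dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
    assert dcp_context_kv_lens_cpu.tolist() == [96, 0]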

vllm/v1/attention/backends/mla/common.py

Lines changed: 2 additions & 2 deletions
@@ -225,7 +225,7 @@
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    get_dcp_local_seq_lens,
+    get_cp_local_seq_lens,
     get_per_layer_parameters,
     infer_global_hyperparameters,
     split_decodes_and_prefills,
@@ -831,7 +831,7 @@ def build(
         )

         if self.dcp_world_size > 1:
-            local_context_lens_allranks = get_dcp_local_seq_lens(
+            local_context_lens_allranks = get_cp_local_seq_lens(
                 context_lens_cpu,
                 self.dcp_world_size,
                 None,

vllm/v1/attention/backends/utils.py

Lines changed: 22 additions & 8 deletions
@@ -48,6 +48,19 @@
 def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)

+@dataclass
+class PrefillContextParallelMetadata:
+    """
+    Attention metadata for prefill context parallel
+    """
+    q_head_indices: torch.Tensor
+    q_tail_indices: torch.Tensor
+    q_head_start_loc: torch.Tensor
+    kv_for_head_indices: torch.Tensor
+    kv_for_tail_indices: torch.Tensor
+    kv_for_head_indptr: torch.Tensor
+    kv_for_tail_indptr: torch.Tensor
+    q_full_indices: torch.Tensor

 @dataclass
 class CommonAttentionMetadata:
@@ -94,6 +107,7 @@ class CommonAttentionMetadata:
     dcp_local_seq_lens: torch.Tensor | None = None
     """Sequence lengths of the local rank in decode context parallelism world"""

+    pcp_metadata: PrefillContextParallelMetadata | None = None

 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
@@ -1077,35 +1091,35 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
     return nums_dict, batch_ptr, token_chunk_offset_ptr


-def get_dcp_local_seq_lens(
+def get_cp_local_seq_lens(
     seq_lens: torch.Tensor,
-    dcp_size: int = 1,
-    dcp_rank: int | None = None,
+    cp_size: int = 1,
+    cp_rank: int | None = None,
     cp_kv_cache_interleave_size: int = 1,
 ) -> torch.Tensor:
     """While using dcp, kv_cache size stored on each rank may be different,
     use this function to calculate split decode seq_lens of each dcp rank.
     Only consider dcp now, we can extend the case of cp based on this.
     """
     num_requests = seq_lens.size(0)
-    if dcp_rank is None:
+    if cp_rank is None:
         rank_offsets = (
-            torch.arange(dcp_size, dtype=torch.int32)
+            torch.arange(cp_size, dtype=torch.int32)
             .unsqueeze(0)
             .repeat(num_requests, 1)
         )
     else:
-        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
+        rank_offsets = torch.Tensor([[cp_rank]]).to(dtype=torch.int32)
     seq_lens_tiled = (
         seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
     )
     base = (
         seq_lens_tiled
         // cp_kv_cache_interleave_size
-        // dcp_size
+        // cp_size
         * cp_kv_cache_interleave_size
     )
-    remainder = seq_lens_tiled - base * dcp_size
+    remainder = seq_lens_tiled - base * cp_size
     remainder = torch.clip(
         remainder - rank_offsets * cp_kv_cache_interleave_size,
         0,
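
The renamed helper deals each sequence's tokens to the cp ranks round-robin in blocks of cp_kv_cache_interleave_size, so per-rank KV lengths can differ by at most one block. The diff is cut off just before the clip's upper bound and the return, so the following is a standalone sketch of the same layout for a single request, assuming the remainder is clipped to one interleave block per rank and added onto the even base (which matches the visible arithmetic):

    def split_seq_len_across_cp_ranks(
        seq_len: int, cp_size: int, interleave: int = 1
    ) -> list[int]:
        # `base` is each rank's share from the full rounds of cp_size
        # blocks; the leftover blocks go to the lowest ranks first.
        base = seq_len // interleave // cp_size * interleave
        leftover = seq_len - base * cp_size
        return [
            base + min(max(leftover - r * interleave, 0), interleave)
            for r in range(cp_size)
        ]

    # 10 tokens over 4 ranks with interleave 1: two full rounds give
    # base=2, and the 2 leftover tokens land on ranks 0 and 1.
    assert split_seq_len_across_cp_ranks(10, 4) == [3, 3, 2, 2]
    # With interleave 4 over 2 ranks, blocks [0:4] and [8:10] sit on
    # rank 0 and block [4:8] on rank 1.
    assert split_seq_len_across_cp_ranks(10, 2, interleave=4) == [6, 4]

Passing cp_rank=None to the real function vectorizes this over all requests and ranks at once (a [num_requests, cp_size] tensor), which is how mla/common.py calls it above; passing a concrete rank returns just that rank's column, as flash_attn.py does.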

vllm/v1/core/kv_cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1222,7 +1222,7 @@ def _report_kv_cache_config(
     dcp_size = vllm_config.parallel_config.decode_context_parallel_size
     pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
     if pcp_size * dcp_size > 1:
-        num_tokens *= (pcp_size * dcp_size)
+        num_tokens *= pcp_size * dcp_size
         logger.info(
             "Multiplying the GPU KV cache size by the cp_world_size %d "
             "(pcp_world_size %d * dcp_world_size %d).",
