
Commit 22eb266

pisceskkk authored and LookAround0301 committed
[Feature] Prefill Context Parallel (PCP) basic support (vllm-project#28718)
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
(cherry picked from commit 2fd893b)
1 parent 2918c1b commit 22eb266

22 files changed: +428 -62 lines changed

tests/kernels/moe/modular_kernel_tools/common.py

Lines changed: 6 additions & 1 deletion
@@ -15,7 +15,11 @@
 )
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_dp_group,
+    get_pcp_group,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
@@ -561,6 +565,7 @@ def next_power_of_2(x):
     # make moe config
     moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
         tp_size_=get_tensor_model_parallel_world_size(),
+        pcp_size_=get_pcp_group().world_size,
         dp_size_=get_dp_group().world_size,
         vllm_parallel_config=vllm_config.parallel_config,
     )

vllm/attention/backends/abstract.py

Lines changed: 17 additions & 0 deletions
@@ -127,6 +127,12 @@ class AttentionImpl(ABC, Generic[T]):
     dcp_world_size: int
     dcp_rank: int
 
+    pcp_world_size: int
+    pcp_rank: int
+
+    total_cp_world_size: int
+    total_cp_rank: int
+
     def __new__(cls, *args, **kwargs):
         # use __new__ so that all subclasses will call this
         self = super().__new__(cls)
@@ -139,6 +145,17 @@ def __new__(cls, *args, **kwargs):
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        try:
+            from vllm.distributed.parallel_state import get_pcp_group
+
+            self.pcp_world_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group
+        except AssertionError:
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
+
         self.need_to_return_lse_for_decode = (
             self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
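The combined context-parallel indexing above flattens the (pcp_rank, dcp_rank) pair into a single rank, with DCP ranks contiguous inside each PCP rank. A minimal standalone illustration (plain Python, not vLLM code; the 2x2 group sizes are assumed for the example):

# Illustration of total_cp_rank = pcp_rank * dcp_world_size + dcp_rank.
pcp_world_size, dcp_world_size = 2, 2                    # assumed sizes for this example
total_cp_world_size = pcp_world_size * dcp_world_size    # 4

for pcp_rank in range(pcp_world_size):
    for dcp_rank in range(dcp_world_size):
        total_cp_rank = pcp_rank * dcp_world_size + dcp_rank
        print((pcp_rank, dcp_rank), "->", total_cp_rank)
# (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3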

vllm/attention/ops/common.py

Lines changed: 37 additions & 3 deletions
@@ -168,12 +168,11 @@ def correct_attn_out(
     return out, lse
 
 
-def cp_lse_ag_out_rs(
+def _cp_lse_common(
     cp_attn_out: torch.Tensor,
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
-    ctx: CPTritonContext = None,
-    return_lse=False,
+    ctx: CPTritonContext | None = None,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
     assert out.is_contiguous()
+    assert out.is_contiguous()
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.reduce_scatter(out, dim=1)
 
     if return_lse:
@@ -205,6 +220,25 @@ def cp_lse_ag_out_rs(
     return out
 
 
+def cp_lse_ag_out_ar(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out = cp_group.all_reduce(out)
+
+    if return_lse:
+        return out, lse
+    return out
+
+
 @triton.jit
 def _pack_seq_kernel(
     x_ptr,  # [N, D]
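Both public helpers share _cp_lse_common: each rank all-gathers the per-rank LSE values and rescales its partial attention output, then cp_lse_ag_out_rs reduce-scatters the corrected outputs along the head dimension while cp_lse_ag_out_ar all-reduces them. Conceptually, the combination is the standard log-sum-exp merge of partial attention results; a self-contained sketch of that math (illustration only; merge_cp_partials is a made-up name, not a vLLM API):

import torch

def merge_cp_partials(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    # outs: [CP, B, H, D] per-rank partial attention outputs
    # lses: [CP, B, H]    per-rank log-sum-exp values
    lse_total = torch.logsumexp(lses, dim=0)          # [B, H]
    weights = torch.exp(lses - lse_total)             # [CP, B, H]
    return (weights.unsqueeze(-1) * outs).sum(dim=0)  # [B, H, D]

# Tiny shape check with made-up sizes.
cp, b, h, d = 2, 3, 4, 8
merged = merge_cp_partials(torch.randn(cp, b, h, d), torch.randn(cp, b, h))
assert merged.shape == (b, h, d)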

vllm/config/parallel.py

Lines changed: 27 additions & 1 deletion
@@ -72,6 +72,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    prefill_context_parallel_size: int = 1
+    """Number of prefill context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
@@ -227,6 +229,21 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
 
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using DCP or PCP.
+    For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
+    and `total_cp_world_size = pcp_world_size * dcp_world_size`,
+    store interleave_size tokens on total_cp_rank i,
+    then store the next interleave_size tokens on total_cp_rank i+1.
+    Interleave_size=1: token-level alignment, where token `i` is stored on
+    total_cp_rank `i % total_cp_world_size`.
+    Interleave_size=block_size: block-level alignment, where tokens are
+    first populated to the preceding ranks. Tokens are then stored
+    in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+    Block_size should be divisible by cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
@@ -289,6 +306,11 @@ def _validate_parallel_config(self) -> Self:
                 "num_redundant_experts."
             )
 
+        if self.prefill_context_parallel_size > 1:
+            raise ValueError(
+                "Prefill context parallelism is not fully supported. "
+                "Please set prefill_context_parallel_size to 1."
+            )
         return self
 
     @property
@@ -468,7 +490,11 @@ def __post_init__(self) -> None:
         )
 
         # Continue with the rest of the initialization
-        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        self.world_size = (
+            self.pipeline_parallel_size
+            * self.tensor_parallel_size
+            * self.prefill_context_parallel_size
+        )
 
         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")
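Per the cp_kv_cache_interleave_size docstring above, the rank that stores a token's KV cache is (token_idx // interleave_size) % total_cp_world_size. A small illustration (the helper name and sizes are made up, not vLLM code):

def cp_rank_of_token(token_idx: int, interleave_size: int, total_cp_world_size: int) -> int:
    # Every interleave_size consecutive tokens land on the same CP rank,
    # then placement moves to the next rank, wrapping around.
    return (token_idx // interleave_size) % total_cp_world_size

ranks = 4
print([cp_rank_of_token(i, 1, ranks) for i in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
print([cp_rank_of_token(i, 2, ranks) for i in range(8)])  # [0, 0, 1, 1, 2, 2, 3, 3]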

vllm/config/vllm.py

Lines changed: 39 additions & 0 deletions
@@ -377,6 +377,14 @@ def __post_init__(self):
                 "Overriding cudagraph_mode to PIECEWISE."
             )
             self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+        # prefill context parallel does not support full cudagraphs
+        elif self.parallel_config.prefill_context_parallel_size > 1:
+            logger.warning_once(
+                "Prefill context parallel (PCP) is enabled, which is "
+                "incompatible with full CUDA graphs. "
+                "Overriding cudagraph_mode to PIECEWISE."
+            )
+            self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
         elif self.model_config is not None:
             if self.model_config.pooler_config is not None:
                 logger.warning_once(
@@ -519,6 +527,37 @@ def __post_init__(self):
             )
         current_platform.check_and_update_config(self)
 
+        # If DCP, ensure the block size is right.
+        if self.parallel_config.decode_context_parallel_size > 1:
+            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
+                self.parallel_config.cp_kv_cache_interleave_size
+                != self.parallel_config.dcp_kv_cache_interleave_size
+            ):
+                self.parallel_config.cp_kv_cache_interleave_size = (
+                    self.parallel_config.dcp_kv_cache_interleave_size
+                )
+                logger.warning_once(
+                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
+                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
+                    "deprecated when PCP is fully supported."
+                )
+            assert (
+                self.parallel_config.cp_kv_cache_interleave_size
+                <= self.cache_config.block_size
+                and self.cache_config.block_size
+                % self.parallel_config.cp_kv_cache_interleave_size
+                == 0
+            ), (
+                f"Block_size({self.cache_config.block_size}) should be greater "
+                "than or equal to and divisible by cp_kv_cache_interleave_size "
+                f"({self.parallel_config.cp_kv_cache_interleave_size})."
+            )
+
+        assert (
+            self.parallel_config.cp_kv_cache_interleave_size == 1
+            or self.speculative_config is None
+        ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
+
         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
             self.compilation_config.set_splitting_ops_for_v1()
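The assertion added above reduces to a single constraint between the KV-cache block size and the interleave size; restated as a standalone check (illustrative only, the function name is made up):

def check_cp_interleave(block_size: int, cp_kv_cache_interleave_size: int) -> None:
    # block_size must be >= the interleave size and divisible by it.
    assert (
        cp_kv_cache_interleave_size <= block_size
        and block_size % cp_kv_cache_interleave_size == 0
    ), (
        f"Block_size({block_size}) should be greater than or equal to and "
        f"divisible by cp_kv_cache_interleave_size({cp_kv_cache_interleave_size})."
    )

check_cp_interleave(block_size=16, cp_kv_cache_interleave_size=8)    # passes
# check_cp_interleave(block_size=16, cp_kv_cache_interleave_size=3)  # would fail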
