
Commit 30873d6

pisceskkk, FENP, LookAround0301, Jingchun Gao, and zhenwenqi2024 committed

init PCP basic support

Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: Jingchun Gao <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>
1 parent 5bb1da5 commit 30873d6

File tree

25 files changed: +390 −101 lines


tests/distributed/test_context_parallel.py

Lines changed: 6 additions & 6 deletions
@@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
-    dcp_kv_cache_interleave_size: int
+    cp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -55,7 +55,7 @@ def detailed(
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
-        dcp_kv_cache_interleave_size: int = 1,
+        cp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,

@@ -71,7 +71,7 @@ def detailed(
                 tp_size=tp_base,
                 pp_size=pp_multiplier * pp_base,
                 dcp_size=int(dcp_multiplier * tp_base),
-                dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+                cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
                 eager_mode=eager_mode_val,
                 chunked_prefill=chunked_prefill_val,
             )

@@ -116,7 +116,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
-        dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup

@@ -197,7 +197,7 @@ def _compare_cp_with_tp(
         "--decode-context-parallel-size",
         str(dcp_size),
         "--dcp-kv-cache-interleave-size",
-        str(dcp_kv_cache_interleave_size),
+        str(cp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]

@@ -227,7 +227,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
+        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 2 deletions
@@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
     max_num_reqs = 10
     max_num_blocks_per_req = 20
     max_num_batched_tokens = 512
-    dcp_kv_cache_interleave_size = 8
+    cp_kv_cache_interleave_size = 8

     block_table = BlockTable(
         block_size=block_size,

@@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
         pin_memory=False,
         device=torch.device(DEVICE),
         kernel_block_size=kernel_block_sizes[0],
-        dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
     )

     # Verify hybrid block configuration

vllm/attention/backends/abstract.py

Lines changed: 17 additions & 0 deletions
@@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
     dcp_world_size: int
     dcp_rank: int

+    pcp_world_size: int
+    pcp_rank: int
+
+    total_cp_world_size: int
+    total_cp_rank: int
+
     def __new__(cls, *args, **kwargs):
         # use __new__ so that all subclasses will call this
         self = super().__new__(cls)

@@ -278,6 +284,17 @@ def __new__(cls, *args, **kwargs):
             # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
+        try:
+            from vllm.distributed.parallel_state import get_pcp_group
+
+            self.pcp_world_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group
+        except AssertionError:
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
+
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
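
A note on the rank math above: PCP and DCP compose into one flattened context-parallel rank space, with DCP as the inner (fast-moving) dimension. A minimal standalone sketch of the mapping, using hypothetical example sizes:

    def flatten_cp_rank(pcp_rank: int, dcp_rank: int, dcp_world_size: int) -> int:
        # DCP is the inner dimension: consecutive total ranks sweep the DCP
        # ranks within one PCP rank before advancing to the next PCP rank.
        return pcp_rank * dcp_world_size + dcp_rank

    # e.g. pcp_world_size=2, dcp_world_size=4 -> total_cp_world_size=8:
    # (pcp=0, dcp=0..3) -> total ranks 0..3; (pcp=1, dcp=0..3) -> 4..7
    assert flatten_cp_rank(pcp_rank=1, dcp_rank=2, dcp_world_size=4) == 6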

vllm/attention/ops/common.py

Lines changed: 37 additions & 3 deletions
@@ -169,12 +169,11 @@
     return out, lse


-def cp_lse_ag_out_rs(
+def _cp_lse_common(
     cp_attn_out: torch.Tensor,
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
-    ctx: CPTritonContext = None,
-    return_lse=False,
+    ctx: CPTritonContext | None = None,
 ):
     """
     cp_attn_out: [ B, H, D ]

@@ -195,6 +194,22 @@
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:

@@ -205,6 +220,25 @@
     return out


+def cp_lse_ag_out_ar(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out = cp_group.all_reduce(out)
+
+    if return_lse:
+        return out, lse
+    return out
+
+
 @triton.jit
 def _pack_seq_kernel(
     x_ptr,  # [N, D]
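
Both helpers share the standard log-sum-exp correction for merging partial attention results; cp_lse_ag_out_rs then reduce-scatters the corrected output across heads, while the new cp_lse_ag_out_ar all-reduces it instead. For intuition, a self-contained sketch of the underlying math (plain PyTorch, not the Triton path used by correct_attn_out):

    import torch

    def merge_attn_partials(outs: torch.Tensor, lses: torch.Tensor):
        # outs: [CP, B, H, D] per-rank partial attention outputs
        # lses: [CP, B, H] per-rank log-sum-exp of attention scores
        lse = torch.logsumexp(lses, dim=0)            # [B, H] global LSE
        weights = torch.exp(lses - lse.unsqueeze(0))  # [CP, B, H] rescaling factors
        out = (weights.unsqueeze(-1) * outs).sum(0)   # [B, H, D] exact merged output
        return out, lse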

vllm/config/parallel.py

Lines changed: 31 additions & 9 deletions
@@ -72,6 +72,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    prefill_context_parallel_size: int = 1
+    """Number of prefill context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""

@@ -240,14 +242,25 @@ class is dynamically inherited by the worker class. This is used to inject
     needs to be divisible by dcp_size."""

     dcp_kv_cache_interleave_size: int = 1
-    """Interleave size of kv_cache storage while using dcp or cp > 1,
-    store interleave_size tokens on (d)cp i,
-    then store next interleave_size tokens on (d)cp i+1.
-    Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
-    Interleave_size=block_size: block-level align, first fill the block on first rank,
-    token is stored on rank i+1 block j after rank i block j is full.
-    Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
-    Block_size should be divisible by dcp_kv_cache_interleave_size.
+    """
+    Interleave size of kv_cache storage while using DCP.
+    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size
+    and will be deprecated when PCP is fully supported.
+
+    """
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using DCP or PCP.
+    With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
+    and `total_cp_world_size = pcp_world_size * dcp_world_size`,
+    store interleave_size tokens on total_cp_rank i,
+    then store the next interleave_size tokens on total_cp_rank i+1.
+    Interleave_size=1: token-level alignment, where token `i` is stored on
+    total_cp_rank `i % total_cp_world_size`.
+    Interleave_size=block_size: block-level alignment, where tokens are
+    first populated to the preceding ranks. Tokens are then stored
+    in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+    Block_size should be divisible by cp_kv_cache_interleave_size.
     """

     _api_process_count: int = Field(default=1, gt=0)

@@ -312,6 +325,11 @@ def _validate_parallel_config(self) -> Self:
                 "num_redundant_experts."
             )

+        if self.prefill_context_parallel_size > 1:
+            raise ValueError(
+                "Prefill context parallelism is not fully supported. "
+                "Please set prefill_context_parallel_size to 1."
+            )
         return self

     @property

@@ -508,7 +526,11 @@ def __post_init__(self) -> None:
         )

         # Continue with the rest of the initialization
-        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        self.world_size = (
+            self.pipeline_parallel_size
+            * self.tensor_parallel_size
+            * self.prefill_context_parallel_size
+        )

         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")
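
To make the interleave semantics concrete, a hypothetical helper (not part of this diff) mapping a token index to the CP rank holding its KV cache:

    def kv_cache_rank(token_idx: int, interleave_size: int, total_cp_world_size: int) -> int:
        # Tokens are dealt out in runs of `interleave_size`, round-robin across ranks.
        return (token_idx // interleave_size) % total_cp_world_size

    # interleave_size=1 (token-level): tokens 0,1,2,3,4,5 -> ranks 0,1,2,3,0,1
    assert kv_cache_rank(5, interleave_size=1, total_cp_world_size=4) == 1
    # interleave_size=block_size (block-level): a block fills on rank i before
    # the same block index starts filling on rank i+1
    assert kv_cache_rank(5, interleave_size=4, total_cp_world_size=4) == 1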

vllm/config/vllm.py

Lines changed: 26 additions & 6 deletions
@@ -481,6 +481,14 @@ def __post_init__(self):
                     "Overriding cudagraph_mode to PIECEWISE."
                 )
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel does not support full cudagraphs
+            elif self.parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
             elif self.model_config is not None:
                 if self.model_config.pooler_config is not None:
                     logger.warning_once(

@@ -610,22 +618,34 @@

         # If DCP, ensure the block size is right.
         if self.parallel_config.decode_context_parallel_size > 1:
+            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
+                self.parallel_config.cp_kv_cache_interleave_size
+                != self.parallel_config.dcp_kv_cache_interleave_size
+            ):
+                self.parallel_config.cp_kv_cache_interleave_size = (
+                    self.parallel_config.dcp_kv_cache_interleave_size
+                )
+                logger.warning_once(
+                    "cp_kv_cache_interleave_size is overridden by "
+                    "dcp_kv_cache_interleave_size, which will be "
+                    "deprecated when PCP is fully supported."
+                )
             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size
+                self.parallel_config.cp_kv_cache_interleave_size
                 <= self.cache_config.block_size
                 and self.cache_config.block_size
-                % self.parallel_config.dcp_kv_cache_interleave_size
+                % self.parallel_config.cp_kv_cache_interleave_size
                 == 0
             ), (
                 f"Block_size({self.cache_config.block_size}) should be greater "
-                "than or equal to and divisible by dcp_kv_cache_interleave_size "
-                f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+                "than or equal to and divisible by cp_kv_cache_interleave_size "
+                f"({self.parallel_config.cp_kv_cache_interleave_size})."
             )

             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size == 1
+                self.parallel_config.cp_kv_cache_interleave_size == 1
                 or self.speculative_config is None
-            ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
+            ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."

         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
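
The two assertions amount to a simple compatibility rule between the cache block size and the interleave size, illustrated here with hypothetical values:

    # block_size must be >= cp_kv_cache_interleave_size and divisible by it.
    block_size = 16
    for interleave in (1, 4, 8, 16):  # all valid for block_size=16
        assert interleave <= block_size and block_size % interleave == 0
    # interleave=32 (exceeds block_size) or interleave=3 (16 % 3 != 0) would
    # trip the first assertion; any interleave > 1 combined with speculative
    # decoding (MTP) trips the second.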
