Commit 3cbf550

pisceskkk, FENP, LookAround0301, Jingchun Gao, and zhenwenqi2024 committed

init PCP basic support

Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: FENP <[email protected]>
Co-authored-by: LookAround <[email protected]>
Co-authored-by: Jingchun Gao <[email protected]>
Co-authored-by: zhenwenqi2024 <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: FENP <[email protected]>
Signed-off-by: LookAround <[email protected]>
Signed-off-by: zhenwenqi2024 <[email protected]>

1 parent 5c9ad13 · commit 3cbf550
File tree

25 files changed: +385 −108 lines changed

tests/distributed/test_context_parallel.py

Lines changed: 6 additions & 6 deletions

@@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
-    dcp_kv_cache_interleave_size: int
+    cp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -54,7 +54,7 @@ def detailed(
        tp_base: int = 4,
        pp_base: int = 1,
        dcp_base: int = 1,
-       dcp_kv_cache_interleave_size: int = 1,
+       cp_kv_cache_interleave_size: int = 1,
        multi_node_only: bool = False,
        runner: RunnerOption = "auto",
        load_format: str | None = None,

@@ -69,7 +69,7 @@ def detailed(
                     tp_size=tp_base,
                     pp_size=pp_multiplier * pp_base,
                     dcp_size=int(dcp_multiplier * tp_base),
-                    dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+                    cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
                     eager_mode=eager_mode_val,
                     chunked_prefill=chunked_prefill_val,
                 )

@@ -112,7 +112,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
-        dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup

@@ -186,7 +186,7 @@ def _compare_cp_with_tp(
         "--decode-context-parallel-size",
         str(dcp_size),
         "--dcp-kv-cache-interleave-size",
-        str(dcp_kv_cache_interleave_size),
+        str(cp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]

@@ -214,7 +214,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
-        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
+        CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 2 deletions

@@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
     max_num_reqs = 10
     max_num_blocks_per_req = 20
     max_num_batched_tokens = 512
-    dcp_kv_cache_interleave_size = 8
+    cp_kv_cache_interleave_size = 8

     block_table = BlockTable(
         block_size=block_size,
@@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization():
         pin_memory=False,
         device=torch.device(DEVICE),
         kernel_block_size=kernel_block_sizes[0],
-        dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
+        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
     )

     # Verify hybrid block configuration

vllm/attention/backends/abstract.py

Lines changed: 17 additions & 0 deletions

@@ -252,6 +252,12 @@ class AttentionImpl(ABC, Generic[T]):
     dcp_world_size: int
     dcp_rank: int

+    pcp_world_size: int
+    pcp_rank: int
+
+    total_cp_world_size: int
+    total_cp_rank: int
+
     def __new__(cls, *args, **kwargs):
         # use __new__ so that all subclasses will call this
         self = super().__new__(cls)
@@ -264,6 +270,17 @@ def __new__(cls, *args, **kwargs):
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        try:
+            from vllm.distributed.parallel_state import get_pcp_group
+
+            self.pcp_world_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group
+        except AssertionError:
+            self.pcp_world_size = 1
+            self.pcp_rank = 0
+        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
+        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
+
         self.need_to_return_lse_for_decode = (
             self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
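
Reviewer note on the rank bookkeeping introduced above: the combined context-parallel rank enumerates DCP ranks fastest within each PCP rank, so the PCP × DCP grid flattens into a gapless range 0..total_cp_world_size-1. A minimal standalone sketch of that mapping (plain Python loop variables, not the vLLM API):

# Sketch: flattening the (pcp_rank, dcp_rank) grid as abstract.py does.
pcp_world_size, dcp_world_size = 2, 4
total_cp_world_size = pcp_world_size * dcp_world_size  # 8

for pcp_rank in range(pcp_world_size):
    for dcp_rank in range(dcp_world_size):
        # DCP ranks vary fastest: (0,0)->0, (0,1)->1, ..., (1,0)->4, ...
        total_cp_rank = pcp_rank * dcp_world_size + dcp_rank
        print((pcp_rank, dcp_rank), "->", total_cp_rank)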

vllm/attention/ops/common.py

Lines changed: 37 additions & 3 deletions

@@ -169,12 +169,11 @@
     return out, lse


-def cp_lse_ag_out_rs(
+def _cp_lse_common(
     cp_attn_out: torch.Tensor,
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
-    ctx: CPTritonContext = None,
-    return_lse=False,
+    ctx: CPTritonContext | None = None,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -195,6 +194,22 @@
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
@@ -205,6 +220,25 @@
     return out


+def cp_lse_ag_out_ar(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext | None = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
+    out = cp_group.all_reduce(out)
+
+    if return_lse:
+        return out, lse
+    return out
+
+
 @triton.jit
 def _pack_seq_kernel(
     x_ptr,  # [N, D]
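
Both wrappers above share one correction step: every CP rank holds a partial attention output plus its log-sum-exp (LSE), and the exact combined result weights each partial by a softmax over the per-rank LSEs. A minimal NumPy sketch of that math (illustrative only; in this file `correct_attn_out` performs it with a Triton kernel, followed by reduce-scatter in `cp_lse_ag_out_rs` or all-reduce in `cp_lse_ag_out_ar`):

import numpy as np

def combine_partials(outs: np.ndarray, lses: np.ndarray) -> np.ndarray:
    """Combine per-rank partial attention outputs via their log-sum-exps.

    outs: [R, B, H, D] partial outputs from R context-parallel ranks.
    lses: [R, B, H] matching log-sum-exp values.
    """
    # Softmax over the rank axis gives each partial's exact weight:
    # w_r = exp(lse_r - logsumexp(lse)).
    w = np.exp(lses - lses.max(axis=0, keepdims=True))  # subtract max for stability
    w /= w.sum(axis=0, keepdims=True)                   # [R, B, H]
    return (outs * w[..., None]).sum(axis=0)            # [B, H, D]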

vllm/config/parallel.py

Lines changed: 25 additions & 10 deletions

@@ -72,6 +72,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    prefill_context_parallel_size: int = 1
+    """Number of prefill context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""

@@ -227,15 +229,19 @@
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

-    dcp_kv_cache_interleave_size: int = 1
-    """Interleave size of kv_cache storage while using dcp or cp > 1,
-    store interleave_size tokens on (d)cp i,
-    then store next interleave_size tokens on (d)cp i+1.
-    Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
-    Interleave_size=block_size: block-level align, first fill the block on first rank,
-    token is stored on rank i+1 block j after rank i block j is full.
-    Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
-    Block_size should be divisible by dcp_kv_cache_interleave_size.
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or pcp.
+    For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
+    and `total_cp_world_size = pcp_world_size * dcp_world_size`,
+    store interleave_size tokens on total_cp_rank i,
+    then store the next interleave_size tokens on total_cp_rank i+1.
+    Interleave_size=1: token-level alignment, where token `i` is stored on
+    total_cp_rank `i % total_cp_world_size`.
+    Interleave_size=block_size: block-level alignment, where tokens are
+    first populated to the preceding ranks. Tokens are then stored
+    in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+    Block_size should be divisible by cp_kv_cache_interleave_size.
     """

     _api_process_count: int = Field(default=1, gt=0)

@@ -300,6 +306,11 @@ def _validate_parallel_config(self) -> Self:
                 "num_redundant_experts."
             )

+        if self.prefill_context_parallel_size > 1:
+            raise ValueError(
+                "Prefill context parallelism is not fully supported. "
+                "Please set prefill_context_parallel_size to 1."
+            )
         return self

     @property

@@ -479,7 +490,11 @@ def __post_init__(self) -> None:
         )

         # Continue with the rest of the initialization
-        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        self.world_size = (
+            self.pipeline_parallel_size
+            * self.tensor_parallel_size
+            * self.prefill_context_parallel_size
+        )

         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")
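
The interleave rule documented in the `cp_kv_cache_interleave_size` docstring maps a token index to its storage rank in closed form. A small sketch under the stated constraint that block_size is a multiple of the interleave size (hypothetical helper name, not part of this commit):

def token_to_cp_rank(token_idx: int, interleave_size: int, total_cp_world_size: int) -> int:
    # Tokens are dealt out in runs of `interleave_size` per rank, round-robin
    # across the flattened PCP x DCP ranks.
    return (token_idx // interleave_size) % total_cp_world_size

# interleave_size=1: token-level alignment, token i -> rank i % world_size.
assert token_to_cp_rank(5, 1, 4) == 1
# interleave_size=block_size (e.g. 16): block-level alignment, rank 0's block
# fills completely before tokens spill over to rank 1.
assert token_to_cp_rank(5, 16, 4) == 0
assert token_to_cp_rank(16, 16, 4) == 1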

vllm/config/vllm.py

Lines changed: 14 additions & 6 deletions

@@ -470,6 +470,14 @@ def __post_init__(self):
                     "Overriding cudagraph_mode to PIECEWISE."
                 )
                 self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel does not support full cudagraphs
+            elif self.parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
             elif self.model_config is not None:
                 if self.model_config.pooler_config is not None:
                     logger.warning_once(

@@ -615,21 +623,21 @@ def __post_init__(self):
         # If DCP, ensure the block size is right.
         if self.parallel_config.decode_context_parallel_size > 1:
             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size
+                self.parallel_config.cp_kv_cache_interleave_size
                 <= self.cache_config.block_size
                 and self.cache_config.block_size
-                % self.parallel_config.dcp_kv_cache_interleave_size
+                % self.parallel_config.cp_kv_cache_interleave_size
                 == 0
             ), (
                 f"Block_size({self.cache_config.block_size}) should be greater "
-                "than or equal to and divisible by dcp_kv_cache_interleave_size "
-                f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+                "than or equal to and divisible by cp_kv_cache_interleave_size "
+                f"({self.parallel_config.cp_kv_cache_interleave_size})."
             )

             assert (
-                self.parallel_config.dcp_kv_cache_interleave_size == 1
+                self.parallel_config.cp_kv_cache_interleave_size == 1
                 or self.speculative_config is None
-            ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
+            ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."

         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
