Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/distributed/test_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
"--cp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
Expand Down
5 changes: 0 additions & 5 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,6 @@ def _validate_parallel_config(self) -> Self:
"num_redundant_experts."
)

if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self

@property
Expand Down
10 changes: 10 additions & 0 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,16 @@ def get_decode_context_model_parallel_rank():
return get_dcp_group().rank_in_group


def get_prefill_context_model_parallel_world_size():
    """Return world size for the prefill context model parallel group."""
    return get_pcp_group().world_size


def get_prefill_context_model_parallel_rank():
    """Return my rank within the prefill context model parallel group."""
    return get_pcp_group().rank_in_group


def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment."""
assert _NODE_COUNT is not None, "distributed environment is not initialized"
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_cp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
Expand Down Expand Up @@ -384,7 +384,7 @@ def schedule(
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu = get_cp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
Expand Down
Loading