Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/distributed/test_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
"--cp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
Expand Down
5 changes: 0 additions & 5 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,6 @@ def _validate_parallel_config(self) -> Self:
"num_redundant_experts."
)

if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self

@property
Expand Down
10 changes: 10 additions & 0 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,16 @@ def get_decode_context_model_parallel_rank():
return get_dcp_group().rank_in_group


def get_prefill_context_model_parallel_world_size():
    """Return world size for the prefill context model parallel group."""
    return get_pcp_group().world_size


def get_prefill_context_model_parallel_rank():
    """Return my rank for the prefill context model parallel group."""
    return get_pcp_group().rank_in_group


def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment."""
assert _NODE_COUNT is not None, "distributed environment is not initialized"
Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_cp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
Expand Down Expand Up @@ -384,7 +384,7 @@ def schedule(
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu = get_cp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
Expand Down
Loading