12 changes: 6 additions & 6 deletions tests/distributed/test_context_parallel.py
@@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
cp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

@@ -55,7 +55,7 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@@ -71,7 +71,7 @@ def detailed(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -116,7 +116,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -197,7 +197,7 @@ def _compare_cp_with_tp(
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
str(cp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -227,7 +227,7 @@ def _compare_cp_with_tp(
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
7 changes: 6 additions & 1 deletion tests/kernels/moe/modular_kernel_tools/common.py
@@ -15,7 +15,11 @@
)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -561,6 +565,7 @@ def next_power_of_2(x):
# make moe config
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=get_tensor_model_parallel_world_size(),
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config,
)
4 changes: 2 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8
cp_kv_cache_interleave_size = 8

block_table = BlockTable(
block_size=block_size,
@@ -966,7 +966,7 @@
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)

# Verify hybrid block configuration
17 changes: 17 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int

pcp_world_size: int
pcp_rank: int

total_cp_world_size: int
total_cp_rank: int

def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
@@ -278,6 +284,17 @@ def __new__(cls, *args, **kwargs):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
try:
from vllm.distributed.parallel_state import get_pcp_group

self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
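# PCP might not be initialized in testing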
self.pcp_world_size = 1
self.pcp_rank = 0
self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)
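The new `total_cp_world_size` / `total_cp_rank` attributes flatten the prefill (PCP) and decode (DCP) context-parallel coordinates into a single index, with DCP as the fastest-varying dimension. A minimal standalone sketch of that mapping (illustrative only, not vLLM code):

def total_cp_coords(pcp_rank, pcp_world_size, dcp_rank, dcp_world_size):
    """Return (total_cp_rank, total_cp_world_size) for one worker."""
    # DCP is the inner (fastest-varying) dimension: workers that share a
    # pcp_rank but have consecutive dcp_ranks are adjacent in the flat order.
    total_cp_world_size = pcp_world_size * dcp_world_size
    total_cp_rank = pcp_rank * dcp_world_size + dcp_rank
    return total_cp_rank, total_cp_world_size

# e.g. pcp_world_size=2, dcp_world_size=4: (pcp_rank=1, dcp_rank=2) -> rank 6 of 8.
assert total_cp_coords(1, 2, 2, 4) == (6, 8)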
40 changes: 37 additions & 3 deletions vllm/attention/ops/common.py
@@ -169,12 +169,11 @@ def correct_attn_out(
return out, lse


def cp_lse_ag_out_rs(
def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse=False,
ctx: CPTritonContext | None = None,
):
"""
cp_attn_out: [ B, H, D ]
@@ -195,6 +194,22 @@ def cp_lse_ag_out_rs(
cp_attn_lse = cp_attn_lse.contiguous()
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
return out, lse


def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.reduce_scatter(out, dim=1)

if return_lse:
@@ -205,6 +220,25 @@
return out


def cp_lse_ag_out_ar(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext | None = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.all_reduce(out)

if return_lse:
return out, lse
return out


@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
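`cp_lse_ag_out_rs` and the new `cp_lse_ag_out_ar` share `_cp_lse_common`, which all-gathers the per-rank LSEs and rescales the local partial attention output; they differ only in the final collective (reduce-scatter along the head dimension versus all-reduce). The rescaling follows the usual log-sum-exp combination; a small PyTorch sketch of that math, assuming each CP rank attends to a disjoint KV shard (illustrative, not the Triton-backed `correct_attn_out` used above):

import torch


def rescale_local_attn_out(
    local_out: torch.Tensor,  # [B, H, D] partial attention output on this CP rank
    lses: torch.Tensor,       # [CP, B, H] log-sum-exp gathered from all CP ranks
    rank: int,
) -> torch.Tensor:
    # Global normalizer over the context-parallel dimension.
    lse_total = torch.logsumexp(lses, dim=0)                  # [B, H]
    # This rank's share of the global softmax mass.
    weight = torch.exp(lses[rank] - lse_total).unsqueeze(-1)  # [B, H, 1]
    # Summing the rescaled outputs across ranks (all-reduce or reduce-scatter)
    # reconstructs the attention output over the full KV sequence.
    return weight * local_out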
40 changes: 31 additions & 9 deletions vllm/config/parallel.py
@@ -72,6 +72,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@@ -240,14 +242,25 @@ class is dynamically inherited by the worker class. This is used to inject
needs to be divisible by dcp_size."""

dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""
Interleave size of kv_cache storage while using DCP.
dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
and will be deprecated when PCP is fully supported.
"""
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
and `total_cp_world_size = pcp_world_size * dcp_world_size`,
store interleave_size tokens on total_cp_rank i,
then store the next interleave_size tokens on total_cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
first populated to the preceding ranks. Tokens are then stored
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
Block_size should be greater than or equal to cp_kv_cache_interleave_size.
Block_size should be divisible by cp_kv_cache_interleave_size.
"""

_api_process_count: int = Field(default=1, gt=0)
@@ -312,6 +325,11 @@ def _validate_parallel_config(self) -> Self:
"num_redundant_experts."
)

if self.prefill_context_parallel_size > 1:
raise ValueError(
"Prefill context parallelism is not fully supported. "
"Please set prefill_context_parallel_size to 1."
)
return self

@property
@@ -508,7 +526,11 @@ def __post_init__(self) -> None:
)

# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
self.world_size = (
self.pipeline_parallel_size
* self.tensor_parallel_size
* self.prefill_context_parallel_size
)

if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
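The interleaving described in the `cp_kv_cache_interleave_size` docstring maps each token position to a flattened CP rank in chunks of `interleave_size`. A small sketch of that mapping, assuming the semantics described above (hypothetical helper, not part of this PR):

def token_cp_rank(token_idx: int, interleave_size: int, total_cp_world_size: int) -> int:
    """Flattened CP rank whose KV cache stores token `token_idx`."""
    return (token_idx // interleave_size) % total_cp_world_size

# interleave_size=1: token-level round-robin, token i -> rank i % world_size.
assert [token_cp_rank(i, 1, 4) for i in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
# interleave_size=block_size (here 4): fill rank 0's block before moving to rank 1.
assert [token_cp_rank(i, 4, 4) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]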
32 changes: 26 additions & 6 deletions vllm/config/vllm.py
@@ -481,6 +481,14 @@ def __post_init__(self):
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# prefill context parallel does not support full cudagraphs
elif self.parallel_config.prefill_context_parallel_size > 1:
logger.warning_once(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. "
"Overriding cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
elif self.model_config is not None:
if self.model_config.pooler_config is not None:
logger.warning_once(
@@ -610,22 +618,34 @@

# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by "
"dcp_kv_cache_interleave_size, which will be deprecated "
"when PCP is fully supported."
)
assert (
self.parallel_config.dcp_kv_cache_interleave_size
self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
% self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
self.parallel_config.cp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
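The assertion in the hunk above reduces to two constraints between the cache block size and the interleave size. Restated as a standalone check (illustrative only, mirroring the condition enforced in `VllmConfig.__post_init__`):

def check_interleave_size(block_size: int, cp_kv_cache_interleave_size: int) -> None:
    """block_size must be >= cp_kv_cache_interleave_size and divisible by it."""
    if not (cp_kv_cache_interleave_size <= block_size
            and block_size % cp_kv_cache_interleave_size == 0):
        raise ValueError(
            f"Block_size({block_size}) should be greater than or equal to and "
            f"divisible by cp_kv_cache_interleave_size({cp_kv_cache_interleave_size})."
        )

check_interleave_size(block_size=64, cp_kv_cache_interleave_size=8)   # OK
# check_interleave_size(block_size=16, cp_kv_cache_interleave_size=64) would raise.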