29 changes: 25 additions & 4 deletions tests/distributed/test_context_parallel.py
@@ -30,13 +30,15 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
pcp_size: int
eager_mode: bool
chunked_prefill: bool


class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str = "FLASH_ATTN"


@dataclass
@@ -52,20 +54,25 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
pcp_base: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str = "FLASH_ATTN",
):
parallel_setups = []
for eager_mode_val in [False]:
for pp_multiplier in [1]:
for dcp_multiplier in [0.5, 1]:
# TODO(qcs): Test enabling PCP and DCP together
# once they are compatible.
for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]):
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
pcp_size=int(pcp_multiplier * pcp_base),
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -75,7 +82,9 @@ def detailed(
distributed_backends=["mp"],
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)

@@ -108,11 +117,12 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
pcp_size,
eager_mode,
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options
multi_node_only, load_format, attn_backend = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
@@ -155,7 +165,7 @@ def _compare_cp_with_tp(
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
"16",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
@@ -172,6 +182,10 @@ def _compare_cp_with_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}

cp_args = [
*common_args,
"--tensor-parallel-size",
@@ -180,6 +194,8 @@
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--prefill-context-parallel-size",
str(pcp_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -198,19 +214,24 @@
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)


CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(attn_backend="FLASHINFER"),
CPTestSettings.detailed(tp_base=2, attn_backend="FLASHINFER"),
],
}

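For context, the new `zip([1, 2, 1], [0.5, 1, 1])` loop in `CPTestSettings.detailed()` replaces the old single `dcp_multiplier` loop and pins down which PCP/DCP size combinations get exercised. A standalone sketch (not part of the test file) of what it produces with the default bases `tp_base=4`, `pcp_base=1`:

```python
# Standalone sketch: enumerate the (pcp_size, dcp_size) pairs generated by the
# new zip-based loop in CPTestSettings.detailed() with tp_base=4, pcp_base=1.
tp_base, pcp_base = 4, 1
for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]):
    pcp_size = int(pcp_multiplier * pcp_base)
    dcp_size = int(dcp_multiplier * tp_base)
    print(f"tp={tp_base} pcp={pcp_size} dcp={dcp_size}")
# tp=4 pcp=1 dcp=2
# tp=4 pcp=2 dcp=4
# tp=4 pcp=1 dcp=4
```

So each model entry now runs three setups per `tp_base`: (pcp=1, dcp=2), (pcp=2, dcp=4), and (pcp=1, dcp=4).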
13 changes: 13 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -127,6 +127,9 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int

pcp_world_size: int
pcp_rank: int

def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
@@ -139,6 +142,16 @@ def __new__(cls, *args, **kwargs):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
try:
from vllm.distributed.parallel_state import get_pcp_group

self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0

self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)
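Because the lookup happens in `__new__`, every `AttentionImpl` subclass picks up `pcp_world_size`/`pcp_rank` without changing its own `__init__`, and falls back to single-rank defaults when the PCP group is not initialized (e.g. in unit tests), mirroring the existing DCP handling. A self-contained sketch of that fallback pattern, with a hypothetical `get_group()` standing in for `get_pcp_group()`:

```python
# Self-contained sketch of the try/except-in-__new__ fallback used above.
# get_group() is a hypothetical stand-in for vllm's get_pcp_group().
def get_group():
    # Simulates a non-distributed test run: the group was never initialized.
    raise AssertionError("process group not initialized")


class BaseImpl:
    pcp_world_size: int
    pcp_rank: int

    def __new__(cls, *args, **kwargs):
        # Runs for every subclass, so each implementation gets the attributes
        # without touching its own __init__.
        self = super().__new__(cls)
        try:
            group = get_group()
            self.pcp_world_size = group.world_size
            self.pcp_rank = group.rank_in_group
        except AssertionError:
            # Group not initialized (e.g. in testing): act as a single rank.
            self.pcp_world_size = 1
            self.pcp_rank = 0
        return self


class FlashLikeImpl(BaseImpl):
    pass


impl = FlashLikeImpl()
assert (impl.pcp_world_size, impl.pcp_rank) == (1, 0)
```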
33 changes: 31 additions & 2 deletions vllm/attention/ops/common.py
@@ -168,12 +168,11 @@ def correct_attn_out(
return out, lse


def cp_lse_ag_out_rs(
def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse=False,
):
"""
cp_attn_out: [ B, H, D ]
@@ -195,6 +194,21 @@
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
return out, lse


def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.reduce_scatter(out, dim=1)

if return_lse:
@@ -205,6 +219,21 @@
return out


def cp_lse_ag_out_ar(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.all_reduce(out)
return out


@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
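Both public entry points now share `_cp_lse_common`, which all-gathers the per-rank LSEs and rescales the local partial output via `correct_attn_out`, leaving the final combine to either `reduce_scatter` along dim 1 (as in `cp_lse_ag_out_rs`) or `all_reduce` (as in the new `cp_lse_ag_out_ar`). As a sanity check of the math being distributed here, a minimal single-process sketch of the LSE-weighted merge that the rescale-plus-sum is computing (plain PyTorch, not the Triton path):

```python
# Minimal single-process sketch: merging per-rank partial attention outputs
# using their log-sum-exps. Each "rank" i holds a partial output out_i over
# its own KV shard plus lse_i; the exact full-KV output is
#     out = sum_i exp(lse_i - logsumexp_over_i(lse_i)) * out_i
# which is what the per-rank rescale followed by the cross-rank sum computes.
import torch

B, H, D, world_size = 2, 4, 8, 2
outs = [torch.randn(B, H, D) for _ in range(world_size)]  # per-rank partial outputs
lses = [torch.randn(B, H) for _ in range(world_size)]     # per-rank log-sum-exps

lse_total = torch.logsumexp(torch.stack(lses), dim=0)     # [B, H]
merged = sum(
    torch.exp(lse - lse_total).unsqueeze(-1) * out        # rescale each partial output
    for out, lse in zip(outs, lses)
)                                                          # then sum across "ranks"
print(merged.shape)  # torch.Size([2, 4, 8])
```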
8 changes: 7 additions & 1 deletion vllm/config/parallel.py
@@ -71,6 +71,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@@ -467,7 +469,11 @@ def __post_init__(self) -> None:
)

# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
self.world_size = (
self.pipeline_parallel_size
* self.tensor_parallel_size
* self.prefill_context_parallel_size
)

if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
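Note that `prefill_context_parallel_size` now enters the world-size product alongside PP and TP, so enabling PCP multiplies the number of ranks a single model replica needs. Illustrative arithmetic only (example values, not from the PR):

```python
# Illustrative arithmetic for the new world_size product in
# ParallelConfig.__post_init__ (example values, not from the PR).
pipeline_parallel_size = 1
tensor_parallel_size = 4
prefill_context_parallel_size = 2

world_size = (
    pipeline_parallel_size
    * tensor_parallel_size
    * prefill_context_parallel_size
)
print(world_size)  # 8
```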
9 changes: 9 additions & 0 deletions vllm/config/vllm.py
@@ -359,6 +359,15 @@ def __post_init__(self):
):
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# prefill context parallel does not support full cudagraphs now.
if self.parallel_config.prefill_context_parallel_size > 1:
logger.warning(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. Set "
"cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# decode context parallel does not support full cudagraphs now.
if self.parallel_config.decode_context_parallel_size > 1:
logger.warning(
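Putting the pieces together, a PCP-enabled launch configured like one of the generated test setups (tp=4, pp=1, pcp=2, dcp=4) passes both the existing DCP flag and the new PCP flag, and per the config change above ends up with `cudagraph_mode` downgraded to `PIECEWISE`. The argument list below is illustrative only, mirroring `cp_args` in the updated test; sizes and the attention backend are placeholders, not recommendations:

```python
# Illustrative only: server arguments mirroring cp_args in the updated test.
# Sizes and the attention backend are placeholders, not recommendations.
cp_env = {"VLLM_ATTENTION_BACKEND": "FLASH_ATTN"}
cp_args = [
    "--max-model-len", "2048",
    "--max-num-seqs", "16",
    "--enable-chunked-prefill",
    "--tensor-parallel-size", "4",
    "--pipeline-parallel-size", "1",
    "--decode-context-parallel-size", "4",
    "--prefill-context-parallel-size", "2",
    "--distributed-executor-backend", "mp",
]
```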