diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index b16fd0d06b14..7e4713b8aece 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -31,7 +31,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int dcp_size: int - dcp_kv_cache_interleave_size: int + cp_kv_cache_interleave_size: int eager_mode: bool chunked_prefill: bool @@ -55,7 +55,7 @@ def detailed( tp_base: int = 4, pp_base: int = 1, dcp_base: int = 1, - dcp_kv_cache_interleave_size: int = 1, + cp_kv_cache_interleave_size: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", load_format: str | None = None, @@ -71,7 +71,7 @@ def detailed( tp_size=tp_base, pp_size=pp_multiplier * pp_base, dcp_size=int(dcp_multiplier * tp_base), - dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -116,7 +116,7 @@ def _compare_cp_with_tp( tp_size, pp_size, dcp_size, - dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size, eager_mode, chunked_prefill, ) = parallel_setup @@ -197,7 +197,7 @@ def _compare_cp_with_tp( "--decode-context-parallel-size", str(dcp_size), "--dcp-kv-cache-interleave-size", - str(dcp_kv_cache_interleave_size), + str(cp_kv_cache_interleave_size), "--distributed-executor-backend", distributed_backend, ] @@ -227,7 +227,7 @@ def _compare_cp_with_tp( "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), - CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64), + CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64), ], "bigcode/gpt_bigcode-santacoder": [ CPTestSettings.detailed(), diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 1d925dc1bea8..d95c22fdf0a5 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -15,7 +15,11 @@ ) from tests.kernels.utils import torch_experts from vllm.config import VllmConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_dp_group, + get_pcp_group, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -561,6 +565,7 @@ def next_power_of_2(x): # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( tp_size_=get_tensor_model_parallel_world_size(), + pcp_size_=get_pcp_group().world_size, dp_size_=get_dp_group().world_size, vllm_parallel_config=vllm_config.parallel_config, ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b95c8df3469b..824e45897835 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -956,7 +956,7 @@ def test_hybrid_block_table_initialization(): max_num_reqs = 10 max_num_blocks_per_req = 20 max_num_batched_tokens = 512 - dcp_kv_cache_interleave_size = 8 + cp_kv_cache_interleave_size = 8 block_table = BlockTable( block_size=block_size, @@ -966,7 +966,7 @@ def test_hybrid_block_table_initialization(): pin_memory=False, device=torch.device(DEVICE), kernel_block_size=kernel_block_sizes[0], - dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, + 
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, ) # Verify hybrid block configuration diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 9275d70fd86a..d28bc065852d 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -266,6 +266,12 @@ class AttentionImpl(ABC, Generic[T]): dcp_world_size: int dcp_rank: int + pcp_world_size: int + pcp_rank: int + + total_cp_world_size: int + total_cp_rank: int + def __new__(cls, *args, **kwargs): # use __new__ so that all subclasses will call this self = super().__new__(cls) @@ -278,6 +284,17 @@ def __new__(cls, *args, **kwargs): # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + try: + from vllm.distributed.parallel_state import get_pcp_group + + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size + self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + self.need_to_return_lse_for_decode = ( self.dcp_world_size > 1 and self.can_return_lse_for_decode ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 2cbb5c91cc3b..67c5f7dbba9c 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -169,12 +169,11 @@ def correct_attn_out( return out, lse -def cp_lse_ag_out_rs( +def _cp_lse_common( cp_attn_out: torch.Tensor, cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, - ctx: CPTritonContext = None, - return_lse=False, + ctx: CPTritonContext | None = None, ): """ cp_attn_out: [ B, H, D ] @@ -195,6 +194,22 @@ def cp_lse_ag_out_rs( cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) + assert out.is_contiguous() + return out, lse + + +def cp_lse_ag_out_rs( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext | None = None, + return_lse: bool = False, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) out = cp_group.reduce_scatter(out, dim=1) if return_lse: @@ -205,6 +220,25 @@ def cp_lse_ag_out_rs( return out +def cp_lse_ag_out_ar( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext | None = None, + return_lse: bool = False, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) + out = cp_group.all_reduce(out) + + if return_lse: + return out, lse + return out + + @triton.jit def _pack_seq_kernel( x_ptr, # [N, D] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 9a6326d62e82..e3ed426d86fe 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -72,6 +72,8 @@ class ParallelConfig: """Number of pipeline parallel groups.""" tensor_parallel_size: int = 1 """Number of tensor parallel groups.""" + prefill_context_parallel_size: int = 1 + """Number of prefill context parallel groups.""" data_parallel_size: int = 1 """Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.""" @@ -240,14 +242,25 @@ class is dynamically inherited by the worker class. 
    This is used to inject
    needs to be divisible by dcp_size."""
    dcp_kv_cache_interleave_size: int = 1
-    """Interleave size of kv_cache storage while using dcp or cp > 1,
-    store interleave_size tokens on (d)cp i,
-    then store next interleave_size tokens on (d)cp i+1.
-    Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
-    Interleave_size=block_size: block-level align, first fill the block on first rank,
-    token is stored on rank i+1 block j after rank i block j is full.
-    Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
-    Block_size should be divisible by dcp_kv_cache_interleave_size.
+    """
+    Interleave size of kv_cache storage while using DCP.
+    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
+    and will be deprecated when PCP is fully supported.
+
+    """
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using DCP or PCP.
+    With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
+    and `total_cp_world_size = pcp_world_size * dcp_world_size`,
+    store interleave_size tokens on total_cp_rank i,
+    then store the next interleave_size tokens on total_cp_rank i+1.
+    Interleave_size=1: token-level alignment, where token `i` is stored on
+    total_cp_rank `i % total_cp_world_size`.
+    Interleave_size=block_size: block-level alignment, where tokens fill the
+    blocks on earlier ranks first. Tokens are then stored
+    in (rank i+1, block j) only after (rank i, block j) is fully occupied.
+    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
+    Block_size should be divisible by cp_kv_cache_interleave_size.
    """

    _api_process_count: int = Field(default=1, gt=0)
@@ -312,6 +325,11 @@ def _validate_parallel_config(self) -> Self:
                "num_redundant_experts."
            )

+        if self.prefill_context_parallel_size > 1:
+            raise ValueError(
+                "Prefill context parallelism is not fully supported. "
+                "Please set prefill_context_parallel_size to 1."
+            )
        return self

    @property
@@ -508,7 +526,11 @@ def __post_init__(self) -> None:
            )

        # Continue with the rest of the initialization
-        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
+        self.world_size = (
+            self.pipeline_parallel_size
+            * self.tensor_parallel_size
+            * self.prefill_context_parallel_size
+        )

        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 672b004c4aa5..d64e315b4fe3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py
@@ -481,6 +481,14 @@ def __post_init__(self):
                    "Overriding cudagraph_mode to PIECEWISE."
                )
                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            # prefill context parallel does not support full cudagraphs
+            elif self.parallel_config.prefill_context_parallel_size > 1:
+                logger.warning_once(
+                    "Prefill context parallel (PCP) is enabled, which is "
+                    "incompatible with full CUDA graphs. "
+                    "Overriding cudagraph_mode to PIECEWISE."
+                )
+                self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
            elif self.model_config is not None:
                if self.model_config.pooler_config is not None:
                    logger.warning_once(
@@ -610,22 +618,34 @@ def __post_init__(self):

        # If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1: + if self.parallel_config.dcp_kv_cache_interleave_size > 1 and ( + self.parallel_config.cp_kv_cache_interleave_size + != self.parallel_config.dcp_kv_cache_interleave_size + ): + self.parallel_config.cp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + logger.warning_once( + "cp_kv_cache_interleave_size is overridden by dcp_kv_cache" + "_interleave_size. And dcp-kv-cache-interleave-size will be " + "deprecated when PCP is fully supported." + ) assert ( - self.parallel_config.dcp_kv_cache_interleave_size + self.parallel_config.cp_kv_cache_interleave_size <= self.cache_config.block_size and self.cache_config.block_size - % self.parallel_config.dcp_kv_cache_interleave_size + % self.parallel_config.cp_kv_cache_interleave_size == 0 ), ( f"Block_size({self.cache_config.block_size}) should be greater " - "than or equal to and divisible by dcp_kv_cache_interleave_size " - f"({self.parallel_config.dcp_kv_cache_interleave_size})." + "than or equal to and divisible by cp_kv_cache_interleave_size " + f"({self.parallel_config.cp_kv_cache_interleave_size})." ) assert ( - self.parallel_config.dcp_kv_cache_interleave_size == 1 + self.parallel_config.cp_kv_cache_interleave_size == 1 or self.speculative_config is None - ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now." + ), "MTP with cp_kv_cache_interleave_size > 1 is not supported now." # Do this after all the updates to compilation_config.mode if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 852c4c644433..f81612fd1f4a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1098,6 +1098,12 @@ def get_dcp_group() -> GroupCoordinator: _PP: GroupCoordinator | None = None + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, "pipeline model parallel group is not initialized" + return _PP + + _DP: GroupCoordinator | None = None @@ -1114,9 +1120,12 @@ def get_ep_group() -> GroupCoordinator: return _EP -def get_pp_group() -> GroupCoordinator: - assert _PP is not None, "pipeline model parallel group is not initialized" - return _PP +_PCP: GroupCoordinator | None = None + + +def get_pcp_group() -> GroupCoordinator: + assert _PCP is not None, "prefill context parallel group is not initialized" + return _PCP @deprecated( @@ -1276,6 +1285,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1325,7 +1335,11 @@ def initialize_model_parallel( # to get group_ranks for each dimension, transpose that dimension to the # last dimension, then reshape to 2D, then unbind the last dimension all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size + -1, + data_parallel_size, + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, ) # noqa # Build the tensor model-parallel groups. 
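Note on the hunk above: every parallel group is derived from the 5-D `all_ranks` grid built in `initialize_model_parallel`. A minimal standalone sketch of that derivation, not part of the patch and using assumed toy sizes (DP=1, PP=1, PCP=2, TP=2 on a 4-rank world):

import torch

# Grid layout mirrors initialize_model_parallel:
# (external_dp, dp, pp, pcp, tp), with TP as the innermost dimension.
dp_size, pp_size, pcp_size, tp_size = 1, 1, 2, 2
world_size = dp_size * pp_size * pcp_size * tp_size
all_ranks = torch.arange(world_size).reshape(
    -1, dp_size, pp_size, pcp_size, tp_size
)

# TP groups come from unbinding along the innermost dimension: [[0, 1], [2, 3]].
tp_groups = [x.tolist() for x in all_ranks.reshape(-1, tp_size).unbind(0)]

# PCP groups first transpose the PCP dimension to the end: [[0, 2], [1, 3]].
pcp_groups = [
    x.tolist() for x in all_ranks.transpose(3, 4).reshape(-1, pcp_size).unbind(0)
]

print(tp_groups, pcp_groups)

The PP and DP groups in the next hunk are built the same way, via transpose(2, 4) and transpose(1, 4).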
@@ -1360,11 +1374,23 @@ def initialize_model_parallel( group_name="dcp", ) + global _PCP + assert _PCP is None, "prefill context parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(3, 4) + .reshape(-1, prefill_context_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] + _PCP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="pcp" + ) + # Build the pipeline model-parallel groups. global _PP assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = ( - all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0) + all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( @@ -1373,7 +1399,7 @@ def initialize_model_parallel( global _DP assert _DP is None, "data parallel group is already initialized" - group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) + group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] _DP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="dp" @@ -1383,7 +1409,12 @@ def initialize_model_parallel( assert _EP is None, "expert parallel group is already initialized" group_ranks = ( all_ranks.transpose(1, 2) - .reshape(-1, data_parallel_size * tensor_model_parallel_size) + .reshape( + -1, + data_parallel_size + * prefill_context_model_parallel_size + * tensor_model_parallel_size, + ) .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] @@ -1393,11 +1424,13 @@ def initialize_model_parallel( logger.info_once( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s, EP rank %s", + "DP rank %s, PP rank %s, PCP rank %s, " + "TP rank %s, EP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, + _PCP.rank_in_group, _TP.rank_in_group, _EP.rank_in_group, ) @@ -1406,6 +1439,7 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1418,6 +1452,7 @@ def ensure_model_parallel_initialized( initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, + prefill_context_model_parallel_size, decode_context_model_parallel_size, backend, ) @@ -1434,6 +1469,12 @@ def ensure_model_parallel_initialized( f"got: {pp_world_size=} vs. " f"wanted: {pipeline_model_parallel_size=}" ) + pcp_world_size = get_pcp_group().world_size + assert pcp_world_size == prefill_context_model_parallel_size, ( + "prefill context parallel group already initialized, but of unexpected size: " + f"{pcp_world_size=} vs. 
" + f"{prefill_context_model_parallel_size=}" + ) def prepare_communication_buffer_for_model(model: torch.nn.Module): @@ -1445,6 +1486,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): """ if _TP is not None: _TP.prepare_communication_buffer_for_model(model) + if _PCP is not None: + _PCP.prepare_communication_buffer_for_model(model) if _PP is not None: _PP.prepare_communication_buffer_for_model(model) if _DP is not None: @@ -1520,16 +1563,21 @@ def destroy_model_parallel(): _TP.destroy() _TP = None - global _PP - if _PP: - _PP.destroy() - _PP = None - global _DCP if _DCP: _DCP.destroy() _DCP = None + global _PCP + if _PCP: + _PCP.destroy() + _PCP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + global _DP if _DP: _DP.destroy() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e2f7326448b3..68205b6079d7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -389,8 +389,10 @@ class EngineArgs: nnodes: int = ParallelConfig.nnodes node_rank: int = ParallelConfig.node_rank tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size + cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: int | None = None data_parallel_start_rank: int | None = None @@ -770,6 +772,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--dcp-kv-cache-interleave-size", **parallel_kwargs["dcp_kv_cache_interleave_size"], ) + parallel_group.add_argument( + "--cp-kv-cache-interleave-size", + **parallel_kwargs["cp_kv_cache_interleave_size"], + ) + parallel_group.add_argument( + "--prefill-context-parallel-size", + "-pcp", + **parallel_kwargs["prefill_context_parallel_size"], + ) parallel_group.add_argument( "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] ) @@ -1600,6 +1611,7 @@ def create_engine_config( parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, + prefill_context_parallel_size=self.prefill_context_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_rank=self.data_parallel_rank or 0, data_parallel_external_lb=data_parallel_external_lb, @@ -1631,6 +1643,7 @@ def create_engine_config( worker_extension_cls=self.worker_extension_cls, decode_context_parallel_size=self.decode_context_parallel_size, dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size, _api_process_count=self._api_process_count, _api_process_rank=self._api_process_rank, ) @@ -1952,6 +1965,15 @@ def _set_default_args( default_prefix_caching, ) = self.get_chunked_prefill_prefix_caching_defaults(model_config) + if self.prefill_context_parallel_size > 1: + default_chunked_prefill = False + default_prefix_caching = False + logger.warning( + "--prefill-context-parallel-size > 1 is not compatible with " + "chunked prefill and prefix caching now. Chunked prefill " + "and prefix caching have been disabled by default." 
+ ) + if self.enable_chunked_prefill is None: self.enable_chunked_prefill = default_chunked_prefill diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a7bd64b1c65e..21eb4d590a7d 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -8,7 +8,11 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.distributed import ( + get_dp_group, + get_pcp_group, + get_tensor_model_parallel_rank, +) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_DTYPES, @@ -684,9 +688,11 @@ def biased_moe_quant_config( @dataclass class FusedMoEParallelConfig: tp_size: int + pcp_size: int dp_size: int ep_size: int tp_rank: int + pcp_rank: int dp_rank: int ep_rank: int @@ -713,19 +719,22 @@ def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" @staticmethod - def flatten_tp_across_dp( - tp_size: int, dp_size: int, dp_rank: int + def flatten_tp_across_dp_and_pcp( + tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int ) -> tuple[int, int]: tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size * tp_size devices. Update tp_size - # and tp_rank so we shard across all devices. - flatten_tp_size = dp_size * tp_size - flatten_tp_rank = dp_rank * tp_size + tp_rank + # There are actually dp_size * pcp_size * tp_size devices. + # Update tp_size and tp_rank so we shard across all devices. + flatten_tp_size = dp_size * pcp_size * tp_size + flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank return flatten_tp_size, flatten_tp_rank @staticmethod def make( - tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig + tp_size_: int, + pcp_size_: int, + dp_size_: int, + vllm_parallel_config: ParallelConfig, ) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input `tp_size_`, @@ -734,19 +743,22 @@ def make( Args: tp_size_ (int): `tp_size` passed into the FusedMoE constructor. + pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor. dp_size_ (int): `dp_size` passed into the FusedMoE constructor. vllm_parallel_config (ParallelConfig): vLLM's parallel config object which contains the `enable_expert_parallel` flag. Examples: When there is no parallelism requested, - i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes + i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes unaltered and the ranks set to 0. - Expert Parallelism is considered only when either `dp_size_` or + Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or `tp_size_` is non trivial. - When TP = 2, DP = 1 and EP = False, the configuration on different + Note that PCP serves the same function as DP here. + + When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different devices: - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // @@ -754,7 +766,7 @@ def make( - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - Comment : Tensors are sharded across 2 devices. 
- When TP = 1, DP = 2 and EP = False, the configuration on different + When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different devices: - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} @@ -762,7 +774,7 @@ def make( - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices. - When TP = 2, DP = 2 and EP = False, the configuration on different + When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different devices: - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} @@ -772,14 +784,14 @@ def make( - Comment: There are 2 engine instances and the tensors are sharded across 4 devices. - When, TP = 2, DP = 1 and EP = True, the configuration on different + When, TP = 2, DP(PCP) = 1 and EP = True, the configuration on different devices: - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - Comment: The experts are split between the 2 devices. - When, TP = 1, DP = 2 and EP = True, the configuration on different + When, TP = 1, DP(PCP) = 2 and EP = True, the configuration on different devices: - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} @@ -787,7 +799,7 @@ def make( - Comment: There are 2 engine instances and the experts are split between the 2 devices. - When TP = 2, DP = 2 and EP = True, the configuration on different + When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different devices: - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} @@ -798,18 +810,25 @@ def make( between the 4 devices. """ - use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + use_ep = ( + dp_size_ * pcp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel + ) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( - tp_size_, dp_size_, dp_rank + pcp_size = pcp_size_ + pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0 + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( + tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank ) if not use_ep: return FusedMoEParallelConfig( tp_size=tp_size, tp_rank=tp_rank, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, @@ -826,6 +845,8 @@ def make( return FusedMoEParallelConfig( tp_size=1, tp_rank=0, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 023132acfed3..fe9e478253ea 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -18,6 +18,7 @@ from vllm.distributed import ( get_dp_group, get_ep_group, + get_pcp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) @@ -307,6 +308,7 @@ def __init__( tp_size: int | None = None, ep_size: int | None = None, dp_size: int | None = None, + pcp_size: int | None = None, prefix: str = "", custom_routing_function: Callable | None = None, scoring_func: str = "softmax", @@ -362,12 +364,14 @@ def __init__( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size + pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size self.is_sequence_parallel = is_sequence_parallel self.sp_size = tp_size_ if is_sequence_parallel else 1 self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( 
tp_size_=tp_size_, + pcp_size_=pcp_size_, dp_size_=dp_size_, vllm_parallel_config=vllm_config.parallel_config, ) @@ -646,6 +650,10 @@ def tp_size(self): def dp_size(self): return self.moe_parallel_config.dp_size + @property + def pcp_size(self): + return self.moe_parallel_config.pcp_size + @property def ep_size(self): return self.moe_parallel_config.ep_size @@ -658,6 +666,10 @@ def tp_rank(self): def dp_rank(self): return self.moe_parallel_config.dp_rank + @property + def pcp_rank(self): + return self.moe_parallel_config.pcp_rank + @property def ep_rank(self): return self.moe_parallel_config.ep_rank @@ -1753,6 +1765,19 @@ def forward_impl( hidden_states, router_logits, self.is_sequence_parallel ) + # NOTE: Similar with DP, PCP also needs dispatch and combine. For + # simplicity, AgRsAll2All was added separately for PCP here. Maybe + # we should modify All2AllManager abstract to better support PCP. + if self.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + router_logits, + dim=0, + ) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1809,6 +1834,13 @@ def forward_impl( def combine_output(states: torch.Tensor) -> torch.Tensor: if do_naive_dispatch_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) + + if self.pcp_size > 1: + states = get_pcp_group().reduce_scatter( + states, + dim=0, + ) + return states if self.shared_experts is not None: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 7df3b087ccb8..614ca1993b78 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -13,6 +13,7 @@ from vllm.distributed import ( get_dp_group, get_ep_group, + get_pcp_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -323,10 +324,12 @@ def _load_weights_mxfp4( # In MoE, we need to flatten the tensor parallel size across the data # parallel size when EP is disabled. - tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( tp_size=get_tensor_model_parallel_world_size(), dp_size=get_dp_group().world_size, dp_rank=get_dp_group().rank_in_group, + pcp_size=get_pcp_group().world_size, + pcp_rank=get_pcp_group().rank_in_group, ) intermediate_size = self.config.intermediate_size @@ -508,10 +511,12 @@ def _load_weights_other( # In MoE, we need to flatten the tensor parallel size across the data # parallel size when EP is disabled. 
- tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp( + tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( tp_size=get_tensor_model_parallel_world_size(), dp_size=get_dp_group().world_size, dp_rank=get_dp_group().rank_in_group, + pcp_size=get_pcp_group().world_size, + pcp_rank=get_pcp_group().rank_in_group, ) intermediate_size = self.config.intermediate_size diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fdc99a0df1c8..cf3c1d05f5b3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -265,8 +265,8 @@ def __init__( self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = ( - self.parallel_config.dcp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size = ( + self.parallel_config.cp_kv_cache_interleave_size ) self.use_full_cuda_graph = ( @@ -388,7 +388,7 @@ def schedule( dcp_context_kv_lens_cpu, self.dcp_world_size, self.dcp_rank, - self.dcp_kv_cache_interleave_size, + self.cp_kv_cache_interleave_size, ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2ccdd1f143ce..7a2b83d06917 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -536,7 +536,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size + self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size # Don't try to access the runner on AMD @@ -1286,8 +1286,8 @@ def __init__(self, *args, **kwargs) -> None: get_current_vllm_config() ) ) - self.dcp_kv_cache_interleave_size: int = ( - get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size ) def _flash_attn_varlen_diff_headdims( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 578153cda786..e23bcde85703 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1079,9 +1079,9 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): def get_dcp_local_seq_lens( seq_lens: torch.Tensor, - dcp_world_size: int = 1, + dcp_size: int = 1, dcp_rank: int | None = None, - dcp_kv_cache_interleave_size: int = 1, + cp_kv_cache_interleave_size: int = 1, ) -> torch.Tensor: """While using dcp, kv_cache size stored on each rank may be different, use this function to calculate split decode seq_lens of each dcp rank. 
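The batched arithmetic in `get_dcp_local_seq_lens` (continued in the hunk below) is easier to follow for a single request. A scalar sketch under the same round-robin-by-chunk assumption; `local_kv_len` is a hypothetical helper for illustration only, not an API added by this patch:

def local_kv_len(seq_len: int, rank: int, cp_size: int, interleave: int) -> int:
    # Tokens are dealt out round-robin in chunks of `interleave` tokens per rank.
    base = seq_len // interleave // cp_size * interleave
    remainder = seq_len - base * cp_size
    # Each rank receives at most one extra, possibly partial, chunk.
    extra = min(max(remainder - rank * interleave, 0), interleave)
    return base + extra

# 13 tokens, 2 CP ranks, interleave size 4:
# rank 0 holds tokens 0-3 and 8-11, rank 1 holds tokens 4-7 and 12.
assert [local_kv_len(13, r, 2, 4) for r in (0, 1)] == [8, 5]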
@@ -1090,7 +1090,7 @@ def get_dcp_local_seq_lens( num_requests = seq_lens.size(0) if dcp_rank is None: rank_offsets = ( - torch.arange(dcp_world_size, dtype=torch.int32) + torch.arange(dcp_size, dtype=torch.int32) .unsqueeze(0) .repeat(num_requests, 1) ) @@ -1101,15 +1101,15 @@ def get_dcp_local_seq_lens( ) base = ( seq_lens_tiled - // dcp_kv_cache_interleave_size - // dcp_world_size - * dcp_kv_cache_interleave_size + // cp_kv_cache_interleave_size + // dcp_size + * cp_kv_cache_interleave_size ) - remainder = seq_lens_tiled - base * dcp_world_size + remainder = seq_lens_tiled - base * dcp_size remainder = torch.clip( - remainder - rank_offsets * dcp_kv_cache_interleave_size, + remainder - rank_offsets * cp_kv_cache_interleave_size, 0, - dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size, ) dcp_local_seq_lens = base + remainder return dcp_local_seq_lens.squeeze(1) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 137e5e0cdb6d..1531b61f88fe 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -27,6 +27,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -44,6 +45,7 @@ def __init__( block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) @@ -210,6 +212,7 @@ def __init__( use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -218,6 +221,7 @@ def __init__( False, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_single_type_manager = len(self.single_type_managers) @@ -250,6 +254,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -258,12 +263,16 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size + self.pcp_world_size = pcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size + if pcp_world_size > 1: + self.block_size *= pcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) @@ -281,6 +290,7 @@ def find_longest_cache_hit( kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -302,6 +312,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -310,8 +321,10 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) assert dcp_world_size == 1, "DCP not support hybrid attn now." + assert pcp_world_size == 1, "PCP not support hybrid attn now." 
self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -452,6 +465,7 @@ def get_kv_cache_coordinator( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -460,6 +474,7 @@ def get_kv_cache_coordinator( use_eagle, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -469,6 +484,7 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -477,4 +493,5 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 7f405fc248ac..2012c3fef88b 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -100,6 +100,7 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -124,12 +125,9 @@ def __init__( 0 ].kv_cache_spec.block_size - if dcp_world_size > 1: + if dcp_world_size * pcp_world_size > 1: assert len(kv_cache_config.kv_cache_groups) == 1 - # Note(hc): need revisit. When both DCP and any future - # PCP are enabled, the block_size may need to be scaled - # by a factor of dcp_size × pcp_size? - self.block_size *= dcp_world_size + self.block_size *= dcp_world_size * pcp_world_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -138,6 +136,7 @@ def __init__( enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6e026215d402..01ecd881115d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1219,11 +1219,16 @@ def _report_kv_cache_config( // len(kv_cache_config.kv_cache_groups) * min_block_size ) - if vllm_config.parallel_config.decode_context_parallel_size > 1: - num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + dcp_size = vllm_config.parallel_config.decode_context_parallel_size + pcp_size = vllm_config.parallel_config.prefill_context_parallel_size + if pcp_size * dcp_size > 1: + num_tokens *= pcp_size * dcp_size logger.info( - "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size, + "Multiplying the GPU KV cache size by the cp_world_size %d " + "(pcp_world_size %d * dcp_world_size %d).", + pcp_size * dcp_size, + pcp_size, + dcp_size, ) num_tokens_str = f"{num_tokens:,}" logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4323141c435b..4cc4c29591cc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -121,6 +121,7 @@ def __init__( self.block_size = block_size self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + self.pcp_world_size = 
vllm_config.parallel_config.prefill_context_parallel_size # req_id -> Request self.requests: dict[str, Request] = {} @@ -183,6 +184,7 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 14ac83028ee4..d90ec550f766 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -32,6 +32,7 @@ def __init__( block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -42,8 +43,9 @@ def __init__( """ self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size - if self.dcp_world_size > 1: - self.block_size *= dcp_world_size + self.pcp_world_size = pcp_world_size + if dcp_world_size * pcp_world_size > 1: + self.block_size *= dcp_world_size * pcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -212,6 +214,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -303,6 +306,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -314,8 +318,8 @@ def find_longest_cache_hit( [] for _ in range(len(kv_cache_group_ids)) ) block_size = kv_cache_spec.block_size - if dcp_world_size > 1: - block_size *= dcp_world_size + if dcp_world_size * pcp_world_size > 1: + block_size *= dcp_world_size * pcp_world_size max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not @@ -362,11 +366,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups" ) assert dcp_world_size == 1, "DCP not support sliding window attn now." + assert pcp_world_size == 1, "PCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -476,6 +482,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -516,6 +523,7 @@ def find_longest_cache_hit( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now." + assert pcp_world_size == 1, "PCP not support chunked local attn now." 
max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -611,11 +619,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, MambaSpec), ( "MambaManager can only be used for mamba groups" ) assert dcp_world_size == 1, "DCP not support mamba now." + assert pcp_world_size == 1, "PCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids)) ) @@ -705,6 +715,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3a25827cec38..6be19894d332 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -128,6 +128,7 @@ def __init__( scheduler_block_size = ( vllm_config.cache_config.block_size * vllm_config.parallel_config.decode_context_parallel_size + * vllm_config.parallel_config.prefill_context_parallel_size ) self.scheduler: SchedulerInterface = Scheduler( diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ad2ece50f981..7e8ebe25c460 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -35,6 +35,7 @@ get_dp_group, get_ep_group, get_inner_dp_world_group, + get_pcp_group, get_pp_group, get_tp_group, ) @@ -110,12 +111,14 @@ def _init_executor(self) -> None: f"({self.parallel_config.nnodes_within_dp}). " ) self.local_world_size = self.parallel_config.local_world_size - tensor_parallel_size = self.parallel_config.tensor_parallel_size - pp_parallel_size = self.parallel_config.pipeline_parallel_size - assert self.world_size == tensor_parallel_size * pp_parallel_size, ( + tp_size = self.parallel_config.tensor_parallel_size + pp_size = self.parallel_config.pipeline_parallel_size + pcp_size = self.parallel_config.prefill_context_parallel_size + assert self.world_size == tp_size * pp_size * pcp_size, ( f"world_size ({self.world_size}) must be equal to the " - f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). " + f"tensor_parallel_size ({tp_size}) x pipeline" + f"_parallel_size ({pp_size}) x prefill_context" + f"_parallel_size ({pcp_size}). " ) # Set multiprocessing envs @@ -424,7 +427,11 @@ def _get_output_rank(self) -> int: # 16-23, PP rank 2 # 24-31, PP rank 3 # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 
3) - return self.world_size - self.parallel_config.tensor_parallel_size + return ( + self.world_size + - self.parallel_config.tensor_parallel_size + * self.parallel_config.prefill_context_parallel_size + ) @dataclass @@ -828,6 +835,8 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: dp_rank = get_dp_group().rank_in_group pp_size = get_pp_group().world_size pp_rank = get_pp_group().rank_in_group + pcp_size = get_pcp_group().world_size + pcp_rank = get_pcp_group().rank_in_group tp_size = get_tp_group().world_size tp_rank = get_tp_group().rank_in_group dcp_size = get_dcp_group().world_size @@ -837,6 +846,8 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: process_name += f"_DP{dp_rank}" if pp_size > 1: process_name += f"_PP{pp_rank}" + if pcp_size > 1: + process_name += f"_PCP{pcp_rank}" if tp_size > 1: process_name += f"_TP{tp_rank}" if dcp_size > 1: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 7f33eb7e699c..751862aa9c76 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -95,10 +95,11 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # Note(hc): each dcp rank only need save # (max_model_len//dcp_world_size) tokens locally. - if dcp_world_size > 1: - max_model_len = cdiv(max_model_len, dcp_world_size) + if dcp_world_size * pcp_world_size > 1: + max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9f6c19e46430..76e17f3797a1 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,7 +4,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group +from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer @@ -22,7 +22,7 @@ def __init__( pin_memory: bool, device: torch.device, kernel_block_size: int, - dcp_kv_cache_interleave_size: int, + cp_kv_cache_interleave_size: int, ): """ Args: @@ -80,6 +80,13 @@ def __init__( else: self._kernel_block_arange = None + try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # DCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -87,7 +94,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size + self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size def append_row( self, @@ -131,14 +138,16 @@ def compute_slot_mapping( # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. 
- if self.dcp_world_size > 1: + total_cp_world_size = self.pcp_world_size * self.dcp_world_size + total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank + if total_cp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is # always stored on the GPU whose dcp_rank equals i % cp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size + virtual_block_size = self.block_size * total_cp_world_size block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size @@ -150,16 +159,16 @@ def compute_slot_mapping( virtual_block_offsets = positions % virtual_block_size mask = ( virtual_block_offsets - // self.dcp_kv_cache_interleave_size - % self.dcp_world_size - == self.dcp_rank + // self.cp_kv_cache_interleave_size + % total_cp_world_size + == total_cp_rank ) # Calculate local block_offsets block_offsets = ( virtual_block_offsets - // (self.dcp_world_size * self.dcp_kv_cache_interleave_size) - * self.dcp_kv_cache_interleave_size - + virtual_block_offsets % self.dcp_kv_cache_interleave_size + // (total_cp_world_size * self.cp_kv_cache_interleave_size) + * self.cp_kv_cache_interleave_size + + virtual_block_offsets % self.cp_kv_cache_interleave_size ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets @@ -253,7 +262,7 @@ def __init__( block_sizes: list[int], kernel_block_sizes: list[int], num_speculative_tokens: int = 0, - dcp_kv_cache_interleave_size: int = 1, + cp_kv_cache_interleave_size: int = 1, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -283,7 +292,7 @@ def __init__( pin_memory, device, kernel_block_size, - dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 023b5edb2c34..d7f31425b607 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -87,7 +87,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, - dcp_kv_cache_interleave_size: int = 1, + cp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -141,7 +141,7 @@ def __init__( block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, - dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size=cp_kv_cache_interleave_size, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 506118d2d762..4eff05ca6247 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -426,7 +426,7 @@ def __init__( # uses output token ids so we set this conservatively. 
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, - dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, + cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1435,7 +1435,7 @@ def _build_attention_metadata( self.seq_lens.cpu[:num_reqs], self.dcp_world_size, self.dcp_rank, - self.parallel_config.dcp_kv_cache_interleave_size, + self.parallel_config.cp_kv_cache_interleave_size, ) self.dcp_local_seq_lens.copy_to_gpu(num_reqs) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 315f01b68499..b8339fc4dc8b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -26,6 +26,7 @@ has_kv_transfer_group, ) from vllm.distributed.parallel_state import ( + get_pcp_group, get_pp_group, get_tp_group, ) @@ -733,6 +734,7 @@ def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int): module.global_num_experts = module.moe_config.num_experts module.moe_parallel_config = FusedMoEParallelConfig.make( tp_size_=get_tp_group().world_size, + pcp_size_=get_pcp_group().world_size, dp_size_=get_dp_group().world_size, vllm_parallel_config=parallel_config, ) @@ -886,6 +888,7 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, + parallel_config.prefill_context_parallel_size, parallel_config.decode_context_parallel_size, )
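Taken together, the `compute_slot_mapping` changes in vllm/v1/worker/block_table.py above reduce to one ownership rule for KV entries across PCP and DCP ranks. A self-contained sketch of that rule; the toy sizes are assumptions and `owner_and_local_offset` is a hypothetical helper mirroring the interleaving arithmetic, not an API added by this patch:

block_size = 16
pcp_world_size, dcp_world_size = 2, 2
interleave = 4  # cp_kv_cache_interleave_size
total_cp_world_size = pcp_world_size * dcp_world_size
virtual_block_size = block_size * total_cp_world_size


def owner_and_local_offset(position: int) -> tuple[int, int]:
    """Return (total_cp_rank, local block offset) for a token position."""
    virtual_offset = position % virtual_block_size
    owner = virtual_offset // interleave % total_cp_world_size
    local_offset = (
        virtual_offset // (total_cp_world_size * interleave) * interleave
        + virtual_offset % interleave
    )
    return owner, local_offset


# Positions 0-3 land on total_cp_rank 0, 4-7 on rank 1, 8-11 on rank 2
# (pcp_rank=1, dcp_rank=0), 12-15 on rank 3, then 16-19 wrap back to rank 0.
assert [owner_and_local_offset(p)[0] for p in (0, 4, 8, 12, 16)] == [0, 1, 2, 3, 0]
# Rank 0's second chunk starts at local slot 4 of its physical block.
assert owner_and_local_offset(16) == (0, 4)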