diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 5495640af07e..71df5c81b320 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int dcp_size: int + pcp_size: int eager_mode: bool chunked_prefill: bool @@ -37,6 +38,7 @@ class ParallelSetup(NamedTuple): class CPTestOptions(NamedTuple): multi_node_only: bool load_format: str | None = None + attn_backend: str = "FLASH_ATTN" @dataclass @@ -52,20 +54,25 @@ def detailed( tp_base: int = 4, pp_base: int = 1, dcp_base: int = 1, + pcp_base: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", load_format: str | None = None, + attn_backend: str = "FLASH_ATTN", ): parallel_setups = [] for eager_mode_val in [False]: for pp_multiplier in [1]: - for dcp_multiplier in [0.5, 1]: + # TODO(qcs): Test the effect of mixed activation + # when PCP and DCP are compatible. + for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]): for chunked_prefill_val in [True]: parallel_setups.append( ParallelSetup( tp_size=tp_base, pp_size=pp_multiplier * pp_base, dcp_size=int(dcp_multiplier * tp_base), + pcp_size=int(pcp_multiplier * pcp_base), eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -75,7 +82,9 @@ def detailed( distributed_backends=["mp"], runner=runner, test_options=CPTestOptions( - multi_node_only=multi_node_only, load_format=load_format + multi_node_only=multi_node_only, + load_format=load_format, + attn_backend=attn_backend, ), ) @@ -108,11 +117,12 @@ def _compare_cp_with_tp( tp_size, pp_size, dcp_size, + pcp_size, eager_mode, chunked_prefill, ) = parallel_setup - multi_node_only, load_format = test_options + multi_node_only, load_format, attn_backend = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") @@ -155,7 +165,7 @@ def _compare_cp_with_tp( "--max-model-len", "2048", "--max-num-seqs", - "8", + "16", ] if chunked_prefill: common_args.append("--enable-chunked-prefill") @@ -172,6 +182,10 @@ def _compare_cp_with_tp( if hf_overrides: common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + cp_env = tp_env = { + "VLLM_ATTENTION_BACKEND": attn_backend, + } + cp_args = [ *common_args, "--tensor-parallel-size", @@ -180,6 +194,8 @@ def _compare_cp_with_tp( str(pp_size), "--decode-context-parallel-size", str(dcp_size), + "--prefill-context-parallel-size", + str(pcp_size), "--distributed-executor-backend", distributed_backend, ] @@ -198,12 +214,15 @@ def _compare_cp_with_tp( model_id, cp_args, tp_args, + cp_env, + tp_env, method=method, max_wait_seconds=720, ) CP_TEXT_GENERATION_MODELS = { + # [MLA attention only] "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), @@ -211,6 +230,8 @@ def _compare_cp_with_tp( "bigcode/gpt_bigcode-santacoder": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), + CPTestSettings.detailed(attn_backend="FLASHINFER"), + CPTestSettings.detailed(tp_base=2, attn_backend="FLASHINFER"), ], } diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index e9c6a278a941..3a96bd7d6fd8 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -127,6 +127,9 @@ class AttentionImpl(ABC, Generic[T]): dcp_world_size: int dcp_rank: int + pcp_world_size: int + pcp_rank: int + def __new__(cls, *args, **kwargs): # use __new__ 
so that all subclasses will call this self = super().__new__(cls) @@ -139,6 +142,16 @@ def __new__(cls, *args, **kwargs): # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + try: + from vllm.distributed.parallel_state import get_pcp_group + + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.need_to_return_lse_for_decode = ( self.dcp_world_size > 1 and self.can_return_lse_for_decode ) diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index b6b7ecd2552a..1b7c1dabc10c 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -168,12 +168,11 @@ def correct_attn_out( return out, lse -def cp_lse_ag_out_rs( +def _cp_lse_common( cp_attn_out: torch.Tensor, cp_attn_lse: torch.Tensor, cp_group: GroupCoordinator, ctx: CPTritonContext = None, - return_lse=False, ): """ cp_attn_out: [ B, H, D ] @@ -195,6 +194,21 @@ def cp_lse_ag_out_rs( lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) assert out.is_contiguous() + return out, lse + + +def cp_lse_ag_out_rs( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, + return_lse: bool = False, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) out = cp_group.reduce_scatter(out, dim=1) if return_lse: @@ -205,6 +219,21 @@ def cp_lse_ag_out_rs( return out +def cp_lse_ag_out_ar( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, +): + """ + cp_attn_out: [ B, H, D ] + cp_attn_lse: [ B, H ] + """ + out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx) + out = cp_group.all_reduce(out) + return out + + @triton.jit def _pack_seq_kernel( x_ptr, # [N, D] diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 953aa1a147de..760e01686dd5 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -71,6 +71,8 @@ class ParallelConfig: """Number of pipeline parallel groups.""" tensor_parallel_size: int = 1 """Number of tensor parallel groups.""" + prefill_context_parallel_size: int = 1 + """Number of prefill context parallel groups.""" data_parallel_size: int = 1 """Number of data parallel groups. MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.""" @@ -467,7 +469,11 @@ def __post_init__(self) -> None: ) # Continue with the rest of the initialization - self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size + self.world_size = ( + self.pipeline_parallel_size + * self.tensor_parallel_size + * self.prefill_context_parallel_size + ) if self.distributed_executor_backend == "external_launcher": logger.info("Using external launcher for distributed inference.") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7ee522ea9f0c..67eae381a58d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -359,6 +359,15 @@ def __post_init__(self): ): self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # prefill context parallel do not support full cudagraphs now. 
+ if self.parallel_config.prefill_context_parallel_size > 1: + logger.warning( + "Prefill context parallel (PCP) is enabled, which is " + "incompatible with full CUDA graphs. Set " + "cudagraph_mode to PIECEWISE." + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # decode context parallel do not support full cudagraphs now. if self.parallel_config.decode_context_parallel_size > 1: logger.warning( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 132fb9049163..6924464872a3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1085,6 +1085,14 @@ def get_pp_group() -> GroupCoordinator: return _PP +_PCP: GroupCoordinator | None = None + + +def get_pcp_group() -> GroupCoordinator: + assert _PCP is not None, "prefill context parallel group is not initialized" + return _PCP + + @deprecated( "`get_pipeline_model_parallel_group` has been replaced with " "`get_pp_group` and may be removed in v0.12. Please use " @@ -1207,6 +1215,7 @@ def init_distributed_environment( def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1256,7 +1265,11 @@ def initialize_model_parallel( # to get group_ranks for each dimension, transpose that dimension to the # last dimension, then reshape to 2D, then unbind the last dimension all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size + -1, + data_parallel_size, + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, ) # noqa # Build the tensor model-parallel groups. 
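As a reading aid for the group construction in initialize_model_parallel above, the following standalone sketch reproduces the reshape/transpose arithmetic for a hypothetical 8-GPU layout (dp=1, pp=2, pcp=2, tp=2). It only mirrors the rank-grid logic; the sizes are illustrative and the TP grouping shown is the existing, unchanged path, not part of this patch.

```python
import torch

# Hypothetical layout: data=1, pipeline=2, prefill-context=2, tensor=2.
dp, pp, pcp, tp = 1, 2, 2, 2
world_size = dp * pp * pcp * tp

# Same 5-D rank grid as above: [outer, DP, PP, PCP, TP].
all_ranks = torch.arange(world_size).reshape(-1, dp, pp, pcp, tp)

tp_groups = [g.tolist() for g in all_ranks.reshape(-1, tp).unbind(0)]
pcp_groups = [g.tolist() for g in all_ranks.transpose(3, 4).reshape(-1, pcp).unbind(0)]
pp_groups = [g.tolist() for g in all_ranks.transpose(2, 4).reshape(-1, pp).unbind(0)]

print(tp_groups)   # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pcp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(pp_groups)   # [[0, 4], [2, 6], [1, 5], [3, 7]]
```

Adjacent ranks share a TP group, PCP partners sit tp ranks apart, and PP stages are pcp * tp ranks apart, which is consistent with the transpose(2, 4) and transpose(1, 4) updates in the hunks that follow.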
@@ -1295,7 +1308,7 @@ def initialize_model_parallel( global _PP assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = ( - all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0) + all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( @@ -1304,7 +1317,7 @@ def initialize_model_parallel( global _DP assert _DP is None, "data parallel group is already initialized" - group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) + group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] _DP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="dp" @@ -1314,7 +1327,12 @@ def initialize_model_parallel( assert _EP is None, "expert parallel group is already initialized" group_ranks = ( all_ranks.transpose(1, 2) - .reshape(-1, data_parallel_size * tensor_model_parallel_size) + .reshape( + -1, + data_parallel_size + * tensor_model_parallel_size + * prefill_context_model_parallel_size, + ) .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] @@ -1322,21 +1340,33 @@ def initialize_model_parallel( group_ranks, get_world_group().local_rank, backend, group_name="ep" ) + global _PCP + assert _PCP is None, "prefill context parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(3, 4).reshape(-1, prefill_context_model_parallel_size).unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] + _PCP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="pcp" + ) + logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s, EP rank %s", + "DP rank %s, PP rank %s, TP rank %s, EP rank %s, PCP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, _EP.rank_in_group, + _PCP.rank_in_group, ) def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, + prefill_context_model_parallel_size: int = 1, decode_context_model_parallel_size: int | None = 1, backend: str | None = None, ) -> None: @@ -1349,6 +1379,7 @@ def ensure_model_parallel_initialized( initialize_model_parallel( tensor_model_parallel_size, pipeline_model_parallel_size, + prefill_context_model_parallel_size, decode_context_model_parallel_size, backend, ) @@ -1365,6 +1396,12 @@ def ensure_model_parallel_initialized( f"got: {pp_world_size=} vs. " f"wanted: {pipeline_model_parallel_size=}" ) + pcp_world_size = get_pcp_group().world_size + assert pcp_world_size == prefill_context_model_parallel_size, ( + "prefill context parallel group already initialized, but of unexpected size: " + f"{pcp_world_size=} vs. 
" + f"{prefill_context_model_parallel_size=}" + ) def prepare_communication_buffer_for_model(model: torch.nn.Module): @@ -1382,6 +1419,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): _DP.prepare_communication_buffer_for_model(model) if _EP is not None: _EP.prepare_communication_buffer_for_model(model) + if _PCP is not None: + _PCP.prepare_communication_buffer_for_model(model) def model_parallel_is_initialized(): @@ -1427,16 +1466,6 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group -def get_decode_context_model_parallel_world_size(): - """Return world size for the decode context model parallel group.""" - return get_dcp_group().world_size - - -def get_decode_context_model_parallel_rank(): - """Return my rank for the decode context model parallel group.""" - return get_dcp_group().rank_in_group - - def get_node_count() -> int: """Return the total number of nodes in the distributed environment.""" assert _NODE_COUNT is not None, "distributed environment is not initialized" @@ -1471,6 +1500,11 @@ def destroy_model_parallel(): _EP.destroy() _EP = None + global _PCP + if _PCP: + _PCP.destroy() + _PCP = None + def destroy_distributed_environment(): global _WORLD, _NODE_COUNT diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 917d0ec9f7f3..fe288d5281bf 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -371,6 +371,7 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size + prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: int | None = None @@ -722,14 +723,19 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) + parallel_group.add_argument( + "--prefill-context-parallel-size", + "-pcp", + **parallel_kwargs["prefill_context_parallel_size"], + ) parallel_group.add_argument( "--decode-context-parallel-size", "-dcp", **parallel_kwargs["decode_context_parallel_size"], ) - parallel_group.add_argument( - "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] - ) parallel_group.add_argument( "--data-parallel-rank", "-dpn", @@ -1466,6 +1472,7 @@ def create_engine_config( parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, + prefill_context_parallel_size=self.prefill_context_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_rank=self.data_parallel_rank or 0, data_parallel_external_lb=data_parallel_external_lb, @@ -1727,6 +1734,15 @@ def _set_default_args( self.enable_prefix_caching = False else: self.enable_prefix_caching = True + + if self.prefill_context_parallel_size > 1: + self.enable_chunked_prefill = False + self.enable_prefix_caching = False + logger.warning( + "--prefill-context-parallel-size > 1 is not compatible with " + "chunked prefill and prefix caching now. Chunked prefill " + "and prefix caching have been disabled." 
+ ) else: pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 38ea6acc0fc5..13a0942fa54b 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,7 +7,11 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank +from vllm.distributed import ( + get_dp_group, + get_pcp_group, + get_tensor_model_parallel_rank, +) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_DTYPES, @@ -634,9 +638,11 @@ def biased_moe_quant_config( @dataclass class FusedMoEParallelConfig: tp_size: int + pcp_size: int dp_size: int ep_size: int tp_rank: int + pcp_rank: int dp_rank: int ep_rank: int @@ -664,7 +670,10 @@ def use_deepep_ll_kernels(self): @staticmethod def make( - tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig + tp_size_: int, + dp_size_: int, + pcp_size_: int, + vllm_parallel_config: ParallelConfig, ) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. Based on the input `tp_size_`, @@ -737,24 +746,31 @@ def make( between the 4 devices. """ - def flatten_tp_across_dp(dp_rank: int): + def flatten_tp_across_dp_and_pcp(dp_rank: int, pcp_rank: int): tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() - # There are actually dp_size_ * tp_size_ devices. Update tp_size - # and tp_rank so we shard across all devices. - tp_size = dp_size_ * tp_size_ - tp_rank = dp_rank * tp_size_ + tp_rank + # There are actually dp_size_ * pcp_size_ * tp_size_ devices. + # Update tp_size and tp_rank so we shard across all devices. 
+ tp_size = dp_size_ * pcp_size_ * tp_size_ + tp_rank = dp_rank * pcp_size_ * tp_size_ + pcp_rank * tp_size_ + tp_rank return tp_size, tp_rank - use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + use_ep = ( + dp_size_ * tp_size_ * pcp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel + ) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 - tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + pcp_size = pcp_size_ + pcp_rank = get_pcp_group().rank_in_group if pcp_size_ > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp_and_pcp(dp_rank, pcp_rank) if not use_ep: return FusedMoEParallelConfig( tp_size=tp_size, tp_rank=tp_rank, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=1, @@ -771,6 +787,8 @@ def flatten_tp_across_dp(dp_rank: int): return FusedMoEParallelConfig( tp_size=1, tp_rank=0, + pcp_size=pcp_size, + pcp_rank=pcp_rank, dp_size=dp_size, dp_rank=dp_rank, ep_size=ep_size, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3bb544a49f3a..f929a5cee604 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -18,6 +18,7 @@ from vllm.distributed import ( get_dp_group, get_ep_group, + get_pcp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) @@ -1061,6 +1062,7 @@ def __init__( tp_size: int | None = None, ep_size: int | None = None, dp_size: int | None = None, + pcp_size: int | None = None, prefix: str = "", custom_routing_function: Callable | None = None, scoring_func: str = "softmax", @@ -1098,6 +1100,11 @@ def __init__( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() ) dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size + pcp_size_ = ( + pcp_size + if pcp_size is not None + else get_pcp_group().world_size + ) self.is_sequence_parallel = is_sequence_parallel self.sp_size = tp_size_ if is_sequence_parallel else 1 @@ -1105,6 +1112,7 @@ def __init__( self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( tp_size_=tp_size_, dp_size_=dp_size_, + pcp_size_=pcp_size_, vllm_parallel_config=vllm_config.parallel_config, ) @@ -1334,6 +1342,10 @@ def tp_size(self): @property def dp_size(self): return self.moe_parallel_config.dp_size + + @property + def pcp_size(self): + return self.moe_parallel_config.pcp_size @property def ep_size(self): @@ -1346,6 +1358,10 @@ def tp_rank(self): @property def dp_rank(self): return self.moe_parallel_config.dp_rank + + @property + def pcp_rank(self): + return self.moe_parallel_config.pcp_rank @property def ep_rank(self): @@ -2332,6 +2348,19 @@ def forward_impl( hidden_states, router_logits, self.is_sequence_parallel ) + # NOTE: Similar with DP, PCP also needs dispatch and combine. For + # simplicity, AgRsAll2All was added separately for PCP here. Maybe + # we should modify All2AllManager abstract to better support PCP. + if self.pcp_size > 1: + hidden_states = get_pcp_group().all_gather( + hidden_states, + dim=0, + ) + router_logits = get_pcp_group().all_gather( + router_logits, + dim=0, + ) + # Matrix multiply. 
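To make the dispatch/combine NOTE above concrete, here is a single-process sketch of the round trip that forward_impl (above) and reduce_output (below) now perform when pcp_size > 1. torch.cat and sum + chunk stand in for the all_gather and reduce_scatter collectives of get_pcp_group(), and the partial sums model the flattened TP/PCP weight sharding from flatten_tp_across_dp_and_pcp; it is an illustration under those assumptions, not the actual code path.

```python
import torch

pcp_size, tokens_per_rank, hidden = 2, 3, 8
w = torch.randn(hidden, hidden)                      # stand-in for expert weights
shards = [torch.randn(tokens_per_rank, hidden) for _ in range(pcp_size)]

# Dispatch: all_gather(hidden_states, dim=0) -> every rank sees all tokens.
full = torch.cat(shards, dim=0)

# With weights sharded over the flattened TP*PCP dimension, each rank
# contracts over its slice of the hidden dimension and produces a partial sum.
step = hidden // pcp_size
partials = [full[:, r * step:(r + 1) * step] @ w[r * step:(r + 1) * step]
            for r in range(pcp_size)]

# Combine: reduce_scatter(states, dim=0) sums the partials and hands each
# rank back only its own token shard.
summed = torch.stack(partials).sum(dim=0)
out_shards = list(summed.chunk(pcp_size, dim=0))

torch.testing.assert_close(torch.cat(out_shards), full @ w)
```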
final_hidden_states = self.quant_method.apply( layer=self, @@ -2375,6 +2404,12 @@ def reduce_output( if do_naive_dispatch_combine and do_combine: states = get_ep_group().combine(states, self.is_sequence_parallel) + if self.pcp_size > 1: + states = get_pcp_group().reduce_scatter( + states, + dim=0, + ) + if ( not self.is_sequence_parallel and self.reduce_results diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cd54b964c41f..96c1b1c220db 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -10,6 +10,7 @@ from flashinfer import ( BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, MultiLevelCascadeAttentionWrapper, ) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache @@ -22,7 +23,10 @@ AttentionType, MultipleOf, ) +from vllm.attention.ops.common import cp_lse_ag_out_ar +from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.config import CUDAGraphMode, VllmConfig +from vllm.distributed.parallel_state import get_pcp_group from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -48,6 +52,7 @@ get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, + PrefillContextParallelMetadata, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -269,6 +274,10 @@ class FlashInferMetadata: qo_indptr_gpu: torch.Tensor | None = None paged_kv_indptr_gpu: torch.Tensor | None = None + # For context parallel + pcp_allgather_restore_idx: torch.Tensor | None = None + pcp_metadata: PrefillContextParallelMetadata | None = None + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = ( @@ -326,6 +335,14 @@ def __init__( self.compilation_config.max_capture_size, ) + try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group + except AssertionError: + # PCP might not be initialized in testing + self.pcp_world_size = 1 + self.pcp_rank = 0 + self.num_qo_heads = self.model_config.get_num_attention_heads( self.vllm_config.parallel_config ) @@ -411,8 +428,15 @@ def _get_workspace_buffer(self): ) return self._workspace_buffer - def _get_prefill_wrapper(self): - if self._prefill_wrapper is None: + def _get_prefill_wrapper(self, attn_metadata): + # if self._prefill_wrapper is None: + if self.pcp_world_size > 1: + self._prefill_wrapper = {} + for key in ["head", "tail"]: + self._prefill_wrapper[key] = BatchPrefillWithRaggedKVCacheWrapper( + self._get_workspace_buffer(), get_kv_cache_layout() + ) + else: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._get_workspace_buffer(), get_kv_cache_layout() ) @@ -461,6 +485,28 @@ def _get_cascade_wrapper(self): ) return self._cascade_wrapper + def _get_pcp_custom_mask( + self, + qo_indptr_cpu: torch.Tensor, + q_pos: torch.Tensor, + kv_lens: torch.Tensor, + ) -> torch.Tensor: + max_q_lens = int(q_pos.max().item()) + 1 + max_kv_lens = int(kv_lens.max().item()) + mask = torch.ones( + max_q_lens, + max_kv_lens, + dtype=torch.bool, + device=q_pos.device, + ).tril() + custom_mask_lst = [ + mask[q_pos[q_pos_start_loc:q_pos_end_loc], :kv_len].flatten() + for kv_len, q_pos_start_loc, q_pos_end_loc in + zip(kv_lens, qo_indptr_cpu[:-1], qo_indptr_cpu[1:]) + ] + custom_mask = torch.cat(custom_mask_lst) + return custom_mask + def build( self, 
common_prefix_len: int, @@ -482,7 +528,12 @@ def build( max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + if self.pcp_world_size > 1: + seq_lens_cpu = seq_lens_cpu // self.pcp_world_size + ( + self.pcp_rank < seq_lens_cpu % self.pcp_world_size + ) seq_lens_np = seq_lens_cpu.numpy() + num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu block_table_tensor = common_attn_metadata.block_table_tensor num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size @@ -573,6 +624,11 @@ def build( has_sinks=self.has_sinks, has_spec=uses_spec_reorder, ) + if self.pcp_world_size > 1 and (prefill_use_trtllm or decode_use_trtllm): + raise NotImplementedError( + "Trtllm not support lse, please use flash attention " + "or disable attention sinks." + ) if not (prefill_use_trtllm and decode_use_trtllm): if self.has_sinks: @@ -615,6 +671,8 @@ def build( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, + pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx, + pcp_metadata=common_attn_metadata.pcp_metadata, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu @@ -647,7 +705,7 @@ def build( if num_prefills > 0: # Decodes are first so prefills start after the last decode prefill_start = num_decodes - attn_metadata.prefill_wrapper = self._get_prefill_wrapper() + attn_metadata.prefill_wrapper = self._get_prefill_wrapper(common_attn_metadata) assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 assert ( @@ -660,7 +718,6 @@ def build( qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] ) paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] - # Recompute max_q_len for the slice of requests we are using # for prefills. 
This can be different from max_q_len when # we have a non-uniform batch with some short decodes offloaded @@ -669,24 +726,61 @@ def build( attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: - attn_metadata.prefill_wrapper.plan( - qo_indptr_cpu, - paged_kv_indptr_cpu, - paged_kv_indices, - paged_kv_last_page_len_cpu[prefill_start:], - self.num_qo_heads, - self.num_kv_heads, - self.head_dim, - self.page_size, - causal=True, - sm_scale=self.sm_scale, - window_left=self.window_left, - logits_soft_cap=self.logits_soft_cap, - q_data_type=self.q_data_type, - kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, - ) + if self.pcp_world_size > 1: + assert common_attn_metadata.pcp_metadata is not None + assert common_attn_metadata.query_positions is not None + + pcp_metadata = common_attn_metadata.pcp_metadata + qo_indptr_cpu = pcp_metadata.q_head_start_loc + kv_for_head_indptr = pcp_metadata.kv_for_head_indptr + kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr + + attn_metadata.prefill_wrapper["head"].plan( + qo_indptr_cpu.to(self.device), + kv_for_head_indptr.to(self.device), + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + # tail + attn_metadata.prefill_wrapper["tail"].plan( + qo_indptr_cpu.to(self.device), + kv_for_tail_indptr.to(self.device), + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + ) + else: + attn_metadata.prefill_wrapper.plan( + qo_indptr_cpu, + paged_kv_indptr_cpu, + paged_kv_indices, + paged_kv_last_page_len_cpu[prefill_start:], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + causal=True, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, + q_data_type=self.q_data_type, + kv_data_type=self.kv_cache_dtype, + fixed_split_size=self.prefill_fixed_split_size, + disable_split_kv=self.disable_split_kv, + ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( self.device, non_blocking=True @@ -757,6 +851,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. 
return False + if self.pcp_world_size > 1: + return False # TODO: Cascade attention doesn't work, disable it for now # return use_cascade_attention(*args, **kwargs) return False @@ -838,6 +934,32 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): if self.sinks is not None and self.sinks.dtype != torch.float32: self.sinks = self.sinks.to(torch.float32) + def _attention_with_head_and_tail(self, + q_head: torch.Tensor, + q_tail: torch.Tensor, + k_head: torch.Tensor, + v_head: torch.Tensor, + k_tail: torch.Tensor, + v_tail: torch.Tensor, + prefill_wrapper: BatchPrefillWithRaggedKVCacheWrapper, + ): + output_head = torch.empty_like(q_head) + prefill_wrapper["head"].run( + q_head, + k_head, + v_head, + out=output_head, + ) + + output_tail = torch.empty_like(q_tail) + prefill_wrapper["tail"].run( + q_tail, + k_tail, + v_tail, + out=output_tail, + ) + return output_head, output_tail + def forward( self, layer: torch.nn.Module, @@ -926,6 +1048,25 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens + if (self.pcp_world_size > 1): + assert attn_metadata.pcp_allgather_restore_idx is not None + # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. To be optimized for performance! + key_across_cp = get_pcp_group().all_gather( + key[:num_actual_tokens].contiguous(), dim=0 + ) + value_across_cp = get_pcp_group().all_gather( + value[:num_actual_tokens].contiguous(), dim=0 + ) + # Reorder kv after pcp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. + key = torch.index_select( + key_across_cp, 0, attn_metadata.pcp_allgather_restore_idx + ) + value = torch.index_select( + value_across_cp, 0, attn_metadata.pcp_allgather_restore_idx + ) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. @@ -981,17 +1122,52 @@ def forward( assert prefill_wrapper is not None if not attn_metadata.prefill_use_trtllm: - assert prefill_wrapper._causal - assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) - assert prefill_wrapper._sm_scale == self.scale - prefill_wrapper.run( - prefill_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[num_decode_tokens:], - ) + if self.pcp_world_size > 1: + assert type(prefill_wrapper) == dict + for _, prefill_wrapper_i in prefill_wrapper.items(): + assert prefill_wrapper_i._window_left == self.window_left + assert prefill_wrapper_i._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert prefill_wrapper_i._sm_scale == self.scale + assert attn_metadata.pcp_metadata is not None + pcp_metadata = attn_metadata.pcp_metadata + q_head_indices = pcp_metadata.q_head_indices + q_tail_indices = pcp_metadata.q_tail_indices + kv_for_head_indices = pcp_metadata.kv_for_head_indices + kv_for_tail_indices = pcp_metadata.kv_for_tail_indices + q_full_indices = pcp_metadata.q_full_indices + + # NOTE(qcs): Allgather causes duplicate decoding tokens. 
+ prefill_key = key[num_decode_tokens * self.pcp_world_size :] + prefill_value = value[num_decode_tokens * self.pcp_world_size :] + + output_head, output_tail = self._attention_with_head_and_tail( + torch.index_select(prefill_query, 0, q_head_indices), + torch.index_select(prefill_query, 0, q_tail_indices), + torch.index_select(prefill_key, 0, kv_for_head_indices), + torch.index_select(prefill_value, 0, kv_for_head_indices), + torch.index_select(prefill_key, 0, kv_for_tail_indices), + torch.index_select(prefill_value, 0, kv_for_tail_indices), + prefill_wrapper, + ) + + output_full = torch.index_select( + torch.cat([output_head, output_tail], dim=0), + 0, + q_full_indices + ) + output[num_decode_tokens:] = output_full + else: + assert prefill_wrapper._window_left == self.window_left + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) + assert prefill_wrapper._sm_scale == self.scale + assert prefill_wrapper._causal + prefill_wrapper.run( + prefill_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[num_decode_tokens:], + ) else: # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() @@ -1067,13 +1243,25 @@ def forward( assert decode_wrapper._window_left == self.window_left assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache_permute, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) + if self.pcp_world_size > 1: + out, lse = decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + return_lse=True, + ) + output[:num_decode_tokens] = cp_lse_ag_out_ar( + out, lse, get_pcp_group() + ) + else: + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) else: # decode_query may be non-contiguous decode_query = decode_query.contiguous() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 51a9032f4269..567b61db1bc7 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -976,9 +976,9 @@ def build( def reorg_kvcache( allgatered_kv_c_normed: torch.Tensor, allgatered_k_pe: torch.Tensor, - cp_chunk_seq_lens_lst: list[int], + dcp_chunk_seq_lens_lst: list[int], origin_context_lens: list[int], - cp_world_size: int, + dcp_world_size: int, sum_seq_len: int, max_seq_len: int, chunk_size: int, @@ -986,14 +986,14 @@ def reorg_kvcache( toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ - reorg kvcache after cp local gather to tp layout for attn kernel. + reorg kvcache after dcp local gather to tp layout for attn kernel. Args: - cp_chunk_seq_lens_lst: chunk context lengths under CP. - origin_context_lens: origin full context lengths under CP. - cp_world_size: CP size. - sum_seq_len: the sum of cp_chunk_seq_lens_lst. - max_seq_len: the max value of cp_chunk_seq_lens_lst. + dcp_chunk_seq_lens_lst: chunk context lengths under DCP. + origin_context_lens: origin full context lengths under DCP. + dcp_world_size: DCP size. + sum_seq_len: the sum of dcp_chunk_seq_lens_lst. + max_seq_len: the max value of dcp_chunk_seq_lens_lst. chunk_size: equals to max_context_chunk from chunked_context_metadata building. chunk_idx: chunk idx of chunked_prefill. 
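The FlashInfer decode path changed above (decode_wrapper.run(..., return_lse=True) followed by cp_lse_ag_out_ar) relies on the standard log-sum-exp recombination: each PCP rank attends its query against only the KV shard it owns, and the per-rank (out, lse) pairs are then merged. The short self-contained check below verifies that identity with plain torch; it does not call the vLLM kernels or collectives.

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 16)                 # one query head, head_dim = 16
k, v = torch.randn(10, 16), torch.randn(10, 16)

def partial_attn(q, k, v):
    scores = q @ k.T                            # [1, n_kv]
    lse = torch.logsumexp(scores, dim=-1)       # log-sum-exp of the scores
    return torch.softmax(scores, dim=-1) @ v, lse

out_a, lse_a = partial_attn(q, k[:6], v[:6])    # KV shard held by rank 0
out_b, lse_b = partial_attn(q, k[6:], v[6:])    # KV shard held by rank 1

lse = torch.logaddexp(lse_a, lse_b)
merged = (torch.exp(lse_a - lse)[:, None] * out_a
          + torch.exp(lse_b - lse)[:, None] * out_b)

reference, _ = partial_attn(q, k, v)            # attention over the full KV
torch.testing.assert_close(merged, reference)
```

In the distributed version each rank only rescales its own partial output (correct_attn_out) and the sum is carried out by the collective: all-reduce in cp_lse_ag_out_ar, reduce-scatter over heads in cp_lse_ag_out_rs.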
@@ -1003,37 +1003,37 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip( - cp_chunk_seq_lens_lst, origin_context_lens + for dcp_chunk_seq_len, origin_context_len in zip( + dcp_chunk_seq_lens_lst, origin_context_lens ): chunk_context_len = chunk_size - if cp_chunk_seq_len != 0: + if dcp_chunk_seq_len != 0: chunk_context_len = min( chunk_context_len, origin_context_len - chunk_size * chunk_idx ) - cp_target_rank = (chunk_context_len - 1) % cp_world_size + dcp_target_rank = (chunk_context_len - 1) % dcp_world_size cur_seq_len = 0 - for rank in range(cp_world_size): - if rank > cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + for rank in range(dcp_world_size): + if rank > dcp_target_rank and dcp_chunk_seq_len: + real_dcp_chunk_seq_len = dcp_chunk_seq_len - 1 else: - real_cp_chunk_seq_len = cp_chunk_seq_len - if real_cp_chunk_seq_len: + real_dcp_chunk_seq_len = dcp_chunk_seq_len + if real_dcp_chunk_seq_len: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + real_dcp_chunk_seq_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + real_dcp_chunk_seq_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) - cur_seq_len += real_cp_chunk_seq_len + cur_seq_len += real_dcp_chunk_seq_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) - src_token_idx += cp_chunk_seq_len + src_token_idx += dcp_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_k_pe = torch.cat(k_pe_segments, dim=0) assert reorganized_kv_c_normed.shape[0] == sum_seq_len @@ -1637,11 +1637,11 @@ def _context_parallel_compute_prefill_context( kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + dcp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ i ], origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, - cp_world_size=dcp_world_size, + dcp_world_size=dcp_world_size, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index cb5855548098..930dfc31881c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,6 +48,19 @@ def is_valid_kv_cache_layout(value: str) -> bool: return value in get_args(KVCacheLayoutType) +@dataclass +class PrefillContextParallelMetadata: + """ + Attention metadata for prefill context parallel + """ + q_head_indices: torch.Tensor + q_tail_indices: torch.Tensor + q_head_start_loc: torch.Tensor + kv_for_head_indices: torch.Tensor + kv_for_tail_indices : torch.Tensor + kv_for_head_indptr: torch.Tensor + kv_for_tail_indptr: torch.Tensor + q_full_indices: torch.Tensor @dataclass class CommonAttentionMetadata: @@ -94,6 +107,11 @@ class CommonAttentionMetadata: dcp_local_seq_lens: torch.Tensor | None = None """Sequence lengths of the local rank in decode context parallelism world""" + # Needed by custom mask calc for context parallelism + query_positions: np.ndarray | None = None + pcp_allgather_restore_idx: torch.Tensor | None = None + pcp_metadata: PrefillContextParallelMetadata | None = None + def 
slice_query_start_locs( query_start_loc: torch.Tensor, @@ -190,6 +208,19 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + # TODO(qcs): check if we can split query_positions and + # cp_kv_recover_idx as following approach + query_positions = ( + attn_metadata.query_positions[token_slice] + if attn_metadata.query_positions is not None + else None + ) + cp_allgather_restore_idx = ( + attn_metadata.pcp_allgather_restore_idx[token_slice] + if attn_metadata.pcp_allgather_restore_idx is not None + else None + ) + return CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -202,6 +233,8 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + query_positions=query_positions, + pcp_allgather_restore_idx=cp_allgather_restore_idx, ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 137e5e0cdb6d..c65db42bebd4 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -27,6 +27,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len @@ -44,6 +45,7 @@ def __init__( block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) @@ -210,6 +212,7 @@ def __init__( use_eagle: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -218,6 +221,7 @@ def __init__( False, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_single_type_manager = len(self.single_type_managers) @@ -250,6 +254,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -258,12 +263,16 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size + self.pcp_world_size = pcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size + if pcp_world_size > 1: + self.block_size *= pcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) @@ -281,6 +290,7 @@ def find_longest_cache_hit( kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) return hit_blocks, len(hit_blocks[0]) * self.block_size @@ -302,6 +312,7 @@ def __init__( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ): super().__init__( kv_cache_config, @@ -310,8 +321,10 @@ def __init__( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) assert dcp_world_size == 1, "DCP not support hybrid attn now." 
+ assert pcp_world_size == 1, "PCP not support hybrid attn now" self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -452,6 +465,7 @@ def get_kv_cache_coordinator( enable_caching: bool, enable_kv_cache_events: bool, dcp_world_size: int, + pcp_world_size: int, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -460,6 +474,7 @@ def get_kv_cache_coordinator( use_eagle, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -469,6 +484,7 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -477,4 +493,5 @@ def get_kv_cache_coordinator( enable_caching, enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 74176e4b2051..ef9028b61eb1 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -100,6 +100,7 @@ def __init__( log_stats: bool = False, enable_kv_cache_events: bool = False, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: self.max_model_len = max_model_len @@ -124,12 +125,12 @@ def __init__( 0 ].kv_cache_spec.block_size - if dcp_world_size > 1: + if dcp_world_size * pcp_world_size > 1: assert len(kv_cache_config.kv_cache_groups) == 1 # Note(hc): need revisit. When both DCP and any future # PCP are enabled, the block_size may need to be scaled # by a factor of dcp_size × pcp_size? - self.block_size *= dcp_world_size + self.block_size *= dcp_world_size * pcp_world_size self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, @@ -138,6 +139,7 @@ def __init__( enable_caching=self.enable_caching, enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, + pcp_world_size=pcp_world_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6c9a77ccb2b6..806f6ef34be8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1189,11 +1189,20 @@ def _report_kv_cache_config( // len(kv_cache_config.kv_cache_groups) * min_block_size ) - if vllm_config.parallel_config.decode_context_parallel_size > 1: - num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + if ( + vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size > 1 + ): + num_tokens *= (vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size) + cp_size = (vllm_config.parallel_config.prefill_context_parallel_size * + vllm_config.parallel_config.decode_context_parallel_size) logger.info( - "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size, + "Multiplying the GPU KV cache size by the cp_world_size %d " + "(pcp_world_size %d * dcp_world_size %d).", + cp_size, + vllm_config.parallel_config.prefill_context_parallel_size, + vllm_config.parallel_config.decode_context_parallel_size ) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) diff --git a/vllm/v1/core/sched/scheduler.py 
b/vllm/v1/core/sched/scheduler.py index 08368b7d99ef..64c8e8185ba1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -106,6 +106,7 @@ def __init__( self.block_size = block_size self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # req_id -> Request self.requests: dict[str, Request] = {} @@ -170,6 +171,7 @@ def __init__( log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, + pcp_world_size=self.pcp_world_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 586034182686..a1235581f813 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -32,6 +32,7 @@ def __init__( block_pool: BlockPool, kv_cache_group_id: int, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -42,8 +43,9 @@ def __init__( """ self.block_size = kv_cache_spec.block_size self.dcp_world_size = dcp_world_size - if self.dcp_world_size > 1: - self.block_size *= dcp_world_size + self.pcp_world_size = pcp_world_size + if self.dcp_world_size * self.pcp_world_size > 1: + self.block_size *= dcp_world_size * pcp_world_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool @@ -212,6 +214,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than @@ -268,6 +271,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) @@ -279,8 +283,8 @@ def find_longest_cache_hit( [] for _ in range(len(kv_cache_group_ids)) ) block_size = kv_cache_spec.block_size - if dcp_world_size > 1: - block_size *= dcp_world_size + if dcp_world_size * pcp_world_size > 1: + block_size *= dcp_world_size * pcp_world_size max_num_blocks = max_length // block_size for block_hash in itertools.islice(block_hashes, max_num_blocks): # block_hashes is a chain of block hashes. If a block hash is not @@ -331,11 +335,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups" ) assert dcp_world_size == 1, "DCP not support sliding window attn now." + assert pcp_world_size == 1, "PCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -434,6 +440,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ For chunked local attention, we need to find the longest cache hit @@ -474,6 +481,7 @@ def find_longest_cache_hit( "Hybrid KV cache is not supported for " + "eagle + chunked local attention." ) assert dcp_world_size == 1, "DCP not support chunked local attn now." 
+ assert pcp_world_size == 1, "PCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -558,11 +566,13 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, MambaSpec), ( "MambaManager can only be used for mamba groups" ) assert dcp_world_size == 1, "DCP not support mamba now." + assert pcp_world_size == 1, "PCP not support mamba now." computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids)) ) @@ -658,6 +668,7 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, dcp_world_size: int = 1, + pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, CrossAttentionSpec), ( "CrossAttentionManager can only be used for cross-attention groups" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0ca60ce5cf9a..8bf517d11f92 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -148,6 +148,7 @@ def __init__( scheduler_block_size = ( vllm_config.cache_config.block_size * vllm_config.parallel_config.decode_context_parallel_size + * vllm_config.parallel_config.prefill_context_parallel_size ) self.scheduler: SchedulerInterface = Scheduler( diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 38e8f4ab85d9..c4cbd3078a16 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import ( get_dp_group, get_ep_group, + get_pcp_group, get_pp_group, get_tp_group, ) @@ -67,10 +68,12 @@ def _init_executor(self) -> None: self.world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size - assert self.world_size == tensor_parallel_size * pp_parallel_size, ( + pcp_size = self.parallel_config.prefill_context_parallel_size + assert self.world_size == tensor_parallel_size * pp_parallel_size * pcp_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). " + f"_parallel_size ({pp_parallel_size}) x prefill_context" + f"_parallel_size ({pcp_size}). " ) # Set multiprocessing envs @@ -362,7 +365,11 @@ def _get_output_rank(self) -> int: # 16-23, PP rank 2 # 24-31, PP rank 3 # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 
3) - return self.world_size - self.parallel_config.tensor_parallel_size + return ( + self.world_size + - self.parallel_config.tensor_parallel_size + * self.parallel_config.prefill_context_parallel_size + ) @dataclass @@ -715,11 +722,15 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: pp_rank = get_pp_group().rank_in_group tp_size = get_tp_group().world_size tp_rank = get_tp_group().rank_in_group + pcp_size = get_pcp_group().world_size + pcp_rank = get_pcp_group().rank_in_group process_name = "Worker" if dp_size > 1: process_name += f"_DP{dp_rank}" if pp_size > 1: process_name += f"_PP{pp_rank}" + if pcp_size > 1: + process_name += f"_PCP{pcp_rank}" if tp_size > 1: process_name += f"_TP{tp_rank}" if enable_ep: diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a9ef1b92c243..c983bd21ee82 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -88,10 +88,13 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size + pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size # Note(hc): each dcp rank only need save # (max_model_len//dcp_world_size) tokens locally. if dcp_world_size > 1: max_model_len = cdiv(max_model_len, dcp_world_size) + if pcp_world_size > 1: + max_model_len = cdiv(max_model_len, pcp_world_size) return cdiv(max_model_len, self.block_size) * self.page_size_bytes @classmethod diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9bf06d51609f..813323755004 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -4,7 +4,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group +from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.utils import CpuGpuBuffer @@ -80,12 +80,16 @@ def __init__( self._kernel_block_arange = None try: + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.pcp_world_size = 1 + self.pcp_rank = 0 def append_row( self, @@ -127,14 +131,16 @@ def compute_slot_mapping( # NOTE(woosuk): We can't simply use `token_indices // block_size` # here because M (max_model_len) is not necessarily divisible by # block_size. - if self.dcp_world_size > 1: + if self.dcp_world_size * self.pcp_world_size > 1: # Note(hc): The DCP implement store kvcache with an interleave # style, the kvcache for the token whose token_idx is i is - # always stored on the GPU whose dcp_rank equals i % cp_world_size: + # always stored on the GPU whose dcp_rank equals i % pcp_world_size: # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size + virtual_block_size = ( + self.block_size * self.dcp_world_size * self.pcp_world_size + ) block_table_indices = ( req_indices * self.max_num_blocks_per_req + positions // virtual_block_size @@ -144,9 +150,15 @@ def compute_slot_mapping( # Use virtual_block_size for mask calculation, which marks local # tokens. 
virtual_block_offsets = positions % virtual_block_size - mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank + mask = ( + virtual_block_offsets % (self.dcp_world_size * self.pcp_world_size) + == self.current_rank + ) # Calculate local block_offsets - block_offsets = virtual_block_offsets // self.dcp_world_size + block_offsets = virtual_block_offsets // ( + self.dcp_world_size * self.pcp_world_size + ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7e72ce937be4..015b9e638e02 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( + get_pcp_group, get_pp_group, get_tp_group, graph_capture, @@ -92,6 +93,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, + PrefillContextParallelMetadata, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -252,6 +254,8 @@ def __init__( # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len + self.pcp_world_size = get_pcp_group().world_size + self.pcp_rank = get_pcp_group().rank_in_group self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -397,9 +401,13 @@ def __init__( # Cache the device properties. self._init_device_properties() + if self.pcp_world_size > 1: + max_num_padded_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size + else: + max_num_padded_tokens = self.max_num_tokens # Persistent buffers for CUDA graphs. 
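As an aside on the BlockTable.compute_slot_mapping change above: with context parallelism the KV for token i lives on rank i % (dcp_world_size * pcp_world_size), so slots are computed against a virtual block of block_size * dcp * pcp tokens. The toy below (numpy only; block_size 4, dcp 1, pcp 2, one request whose physical block id is 5, all values illustrative) shows the resulting interleaving and local slots.

```python
import numpy as np

block_size, dcp, pcp = 4, 1, 2
cp_world = dcp * pcp
virtual_block_size = block_size * cp_world       # 8 tokens per "virtual block"

positions = np.arange(8)                         # token positions 0..7
block_numbers = np.full(8, 5)                    # all map to physical block 5
virtual_offsets = positions % virtual_block_size

for rank in range(cp_world):
    mask = virtual_offsets % cp_world == rank    # token i is local iff i % cp_world == rank
    local_offsets = virtual_offsets // cp_world  # position inside the local block
    slots = np.where(mask, block_numbers * block_size + local_offsets, -1)
    print(rank, slots.tolist())
# 0 [20, -1, 21, -1, 22, -1, 23, -1]
# 1 [-1, 20, -1, 21, -1, 22, -1, 23]
```

With dcp_world_size == 1 the combined rank reduces to pcp_rank; in the general case it is dcp_world_size * pcp_rank + dcp_rank, as in the patch.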
         # Persistent buffers for CUDA graphs.
-        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
+        self.input_ids = self._make_buffer(max_num_padded_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_num_padded_tokens, dtype=torch.int64)
         self.query_start_loc = self._make_buffer(
             self.max_num_reqs + 1, dtype=torch.int32
         )
@@ -414,7 +422,7 @@ def __init__(
         self.inputs_embeds = self._make_buffer(
             self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
         )
-        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
+        self.is_token_ids = self._make_buffer(max_num_padded_tokens, dtype=torch.bool)
         self.discard_request_indices = self._make_buffer(
             self.max_num_reqs, dtype=torch.int64
         )
@@ -431,6 +439,26 @@ def __init__(
         if self.supports_mm_inputs:
             self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)

+        # Persistent buffers for Prefill Context Parallelism
+        if self.pcp_world_size > 1:
+            self.pcp_allgather_restore_idx = self._make_buffer(
+                max_num_padded_tokens,
+                dtype=torch.int64
+            )
+            self.pcp_padded_slot_mapping = torch.empty(
+                (max_num_padded_tokens,),
+                dtype=torch.int64,
+                device=self.device,
+            )
+            self.num_pcp_pads_cpu_tensor = torch.zeros(
+                (self.max_num_reqs,), device="cpu", dtype=torch.int64, pin_memory=True
+            )
+            self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy()
+            self.pcp_unpad_mask_cpu_tensor = torch.zeros(
+                (max_num_padded_tokens,), device="cpu", dtype=torch.bool, pin_memory=True
+            )
+            self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
+
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
             # NOTE: `mrope_positions` is implemented with one additional dummy
@@ -453,7 +481,7 @@ def __init__(
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         # Keep in int64 to avoid overflow with long context
         self.arange_np = np.arange(
-            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
+            max(self.max_num_reqs + 1, self.max_model_len, max_num_padded_tokens),
             dtype=np.int64,
         )
@@ -467,7 +495,7 @@ def __init__(
         self.kv_sharing_fast_prefill_logits_indices = None
         if self.cache_config.kv_sharing_fast_prefill:
             self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
-                self.max_num_tokens, dtype=torch.int32, device=self.device
+                max_num_padded_tokens, dtype=torch.int32, device=self.device
             )

         self.uniform_decode_query_len = (
@@ -919,6 +947,207 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
             dummy_modality = mm_budget.get_modality_with_max_tokens()
             return self._get_mm_dummy_batch(dummy_modality, num_seqs)

+    def _get_pcp_metadata(self, q_lens, kv_lens):
+        """
+        During the prefill phase, the attention computation is divided into
+        two parts: q_head and q_tail. Here, we calculate the kv indices
+        corresponding to q_head or q_tail. Meanwhile, the q and kv indptr are
+        also computed to build the attention wrapper.
+        If the pcp_size is 2, the variables are as follows:
+        >>> q_lens [4, 8] kv_lens [8, 16]
+        >>> pcp_chunk_sizes [2, 4]
+        >>> q_indptr [0, 2, 6]
+        >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11]
+        >>> kv_head_len r0 [2, 4] / r1 [4, 8]
+        >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12]
+        >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11]
+        >>>                     r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
+        >>> kv_tail_len r0 [8, 16] / r1 [6, 12]
+        >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18]
+        >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
+        >>>                     r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
+        """
+        pcp_chunk_sizes = q_lens // 2
+        q_indptr = np.zeros(len(pcp_chunk_sizes) + 1)
+        q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
+
+        # [4, 12] -> [12, 4]
+        q_head_start_loc = np.roll(np.cumsum(q_lens), 1)
+        q_head_start_loc[0] = 0  # [0, 4]
+        q_head_indices = q_chunk_arange + np.repeat(
+            q_head_start_loc,
+            pcp_chunk_sizes,
+        )
+
+        # [0, 4] + [2, 4] = [2, 8]
+        q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes
+        q_tail_indices = q_chunk_arange + np.repeat(
+            q_tail_start_loc,
+            pcp_chunk_sizes,
+        )
+
+        # [8, 24] -> [24, 8]
+        kv_start_loc = np.roll(np.cumsum(kv_lens), 1)
+        kv_start_loc[0] = 0  # [0, 8]
+        # kv_for_q_head
+        kv_head_len = pcp_chunk_sizes * (self.pcp_rank + 1)
+        kv_for_head_indptr = np.zeros(len(kv_head_len) + 1)
+        kv_for_head_indptr[1:], kv_nomask_head_arange = self._get_cumsum_and_arange(kv_head_len)
+        kv_for_head_indices = kv_nomask_head_arange + np.repeat(
+            kv_start_loc,
+            kv_head_len,
+        )
+        # kv_for_q_tail
+        kv_tail_len = pcp_chunk_sizes * (2 * self.pcp_world_size - self.pcp_rank)
+        kv_for_tail_indptr = np.zeros(len(kv_tail_len) + 1)
+        kv_for_tail_indptr[1:], kv_nomask_tail_arange = self._get_cumsum_and_arange(kv_tail_len)
+        kv_for_tail_indices = kv_nomask_tail_arange + np.repeat(
+            kv_start_loc,
+            kv_tail_len,
+        )
+
+        head_tail_indices = {
+            "q_head": q_head_indices,
+            "q_tail": q_tail_indices,
+            "kv_head": kv_for_head_indices,
+            "kv_tail": kv_for_tail_indices,
+        }
+        head_tail_indptr = {
+            "q": q_indptr,
+            "kv_head": kv_for_head_indptr,
+            "kv_tail": kv_for_tail_indptr
+        }
+        for key, value in head_tail_indices.items():
+            head_tail_indices[key] = torch.from_numpy(value).to(
+                device=self.device, dtype=torch.int64, non_blocking=True
+            )
+        for key, value in head_tail_indptr.items():
+            head_tail_indptr[key] = torch.from_numpy(value).to(
+                dtype=torch.int64
+            )
+
+        q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]])
+        q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32)
+
+        return PrefillContextParallelMetadata(
+            q_head_indices=head_tail_indices["q_head"],
+            q_tail_indices=head_tail_indices["q_tail"],
+            q_head_start_loc=head_tail_indptr["q"],
+            kv_for_head_indices=head_tail_indices["kv_head"],
+            kv_for_tail_indices=head_tail_indices["kv_tail"],
+            kv_for_head_indptr=head_tail_indptr["kv_head"],
+            kv_for_tail_indptr=head_tail_indptr["kv_tail"],
+            q_full_indices=q_full_indices,
+        )
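To make the head/tail split concrete, here is a standalone NumPy re-derivation of the docstring example (pcp_size 2, q_lens [4, 8], kv_lens [8, 16]). The `cumsum_and_arange` helper is a simplified stand-in for the runner's `_get_cumsum_and_arange`, so treat this as a sketch rather than the exact implementation.

import numpy as np

def cumsum_and_arange(lens):
    # Cumulative sum plus a per-segment arange, e.g. [2, 4] -> [2, 6], [0, 1, 0, 1, 2, 3]
    cu = np.cumsum(lens)
    arange = np.arange(cu[-1]) - np.repeat(cu - lens, lens)
    return cu, arange

q_lens, kv_lens, pcp_world_size = np.array([4, 8]), np.array([8, 16]), 2
chunks = q_lens // 2                                              # [2, 4]
_, chunk_arange = cumsum_and_arange(chunks)
q_head_start = np.roll(np.cumsum(q_lens), 1); q_head_start[0] = 0
q_head = chunk_arange + np.repeat(q_head_start, chunks)           # [0 1 4 5 6 7]
q_tail = chunk_arange + np.repeat(q_head_start + chunks, chunks)  # [2 3 8 9 10 11]
print("q_head", q_head, "q_tail", q_tail)

kv_start = np.roll(np.cumsum(kv_lens), 1); kv_start[0] = 0
for rank in range(pcp_world_size):
    kv_head_len = chunks * (rank + 1)
    kv_tail_len = chunks * (2 * pcp_world_size - rank)
    _, head_arange = cumsum_and_arange(kv_head_len)
    _, tail_arange = cumsum_and_arange(kv_tail_len)
    # Matches kv_for_head_indices / kv_for_tail_indices in the docstring example.
    print(rank, head_arange + np.repeat(kv_start, kv_head_len),
          tail_arange + np.repeat(kv_start, kv_tail_len))

Each rank's q_head attends to a short KV prefix while its q_tail attends to a longer one, which is what balances attention work across PCP ranks.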
+
+    def _update_tokens_for_pcp(self, tokens):
+        """
+        If prefill context parallelism is enabled, compute `tokens`, the
+        number of tokens each request contributes on this rank after its
+        sequence is split. We also compute:
+        `positions` the new token positions,
+        `self.num_pcp_pads_cpu` the number of padding tokens
+        per request used for alignment,
+        `self.pcp_unpad_mask_cpu` the mask marking non-padded tokens,
+        `self.pcp_allgather_restore_idx` the indices that restore the
+        original token order after the PCP allgather.
+
+        Example:
+        >>> tokens = [1, 5, 8]
+        >>> pcp_world_size = 2
+        >>> pcp_rank = 0
+        >>> _update_tokens_for_pcp(tokens)
+        ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7])
+        >>> pcp_rank = 1
+        >>> _update_tokens_for_pcp(tokens)
+        ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5])
+        >>> # the following results are the same on every pcp rank
+        >>> self.num_pcp_pads_cpu
+        [1, 3, 0]
+        >>> self.pcp_unpad_mask_cpu
+        [True, False, True, True, True, True, True, False, False,
+         False, True, True, True, True, True, True, True, True]
+        >>> self.pcp_allgather_restore_idx
+        [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
+        """
+        num_reqs = self.input_batch.num_reqs
+        self.num_pcp_pads_cpu[:num_reqs] = 0
+
+        num_decode_reqs = sum(
+            self.input_batch.num_computed_tokens_cpu[:num_reqs]
+            >= self.input_batch.num_prompt_tokens[:num_reqs]
+        )
+        num_decode_tokens = sum(tokens[:num_decode_reqs])
+
+        num_padded_scheduled_tokens = np.ceil(
+            tokens / (2 * self.pcp_world_size)
+        ).astype(np.int32) * (2 * self.pcp_world_size)
+        # Scheduled tokens of decode reqs are duplicated on every pcp rank.
+        num_padded_scheduled_tokens[:num_decode_reqs] = (
+            tokens[:num_decode_reqs] * self.pcp_world_size
+        )
+        self.num_pcp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens
+        cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
+            num_padded_scheduled_tokens
+        )
+        self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = (
+            pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)
+        )
+
+        pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size
+        pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
+        pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
+        _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
+        _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
+        pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens)
+
+        def get_current_rank_positions(
+            positions_start_loc: int | np.ndarray, rank: int
+        ):
+            positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
+            head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
+            tail_start_loc = (
+                positions_start_loc
+                + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes
+            )
+            positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(
+                head_start_loc, pcp_chunk_sizes
+            )
+            # Decode reqs do not have tail chunks.
+            positions[~pcp_head_chunk_mask] = (
+                pcp_chunk_arange[num_decode_tokens:]
+                + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
+            )
+            return positions
+
+        positions = get_current_rank_positions(0, self.pcp_rank)
+        # Decode tokens are only duplicated after the all-gather; their
+        # positions are the same as without prefill context parallelism.
+        if num_decode_reqs > 0:
+            positions[:num_decode_tokens] = self._get_cumsum_and_arange(
+                tokens[:num_decode_reqs]
+            )[1]
+
+        padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
+        padded_pos_start_loc[0] = 0
+        all_positions_lst = [
+            get_current_rank_positions(padded_pos_start_loc, rank_i)
+            for rank_i in range(self.pcp_world_size)
+        ]
+        all_positions = np.concatenate(all_positions_lst)
+        self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = (
+            all_positions.argsort()
+        )
+        self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
+        return (
+            pcp_tokens,
+            positions,
+            self._get_pcp_metadata(
+                pcp_tokens[num_decode_reqs:],
+                num_padded_scheduled_tokens[num_decode_reqs:],
+            ) if num_reqs > num_decode_reqs
+            else None,
+        )
+
     def _get_cumsum_and_arange(
         self,
         num_tokens: np.ndarray,
@@ -1083,6 +1312,30 @@ def _prepare_inputs(
             out=positions_np,
         )

+        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+
+        pcp_metadata = None
+        if self.pcp_world_size > 1:
+            num_scheduled_tokens, pcp_positions, pcp_metadata = \
+                self._update_tokens_for_pcp(
+                    num_scheduled_tokens
+                )
+
+            # Re-derive the batch-level values after PCP splits the sequences.
+            total_num_scheduled_tokens = sum(num_scheduled_tokens)
+            total_num_pcp_pads = sum(self.num_pcp_pads_cpu[:num_reqs])
+            max_num_scheduled_tokens = max(num_scheduled_tokens)
+
+            req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
+            cu_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+            positions_np = self.positions.np[:total_num_scheduled_tokens]
+            np.add(
+                self.input_batch.num_computed_tokens_cpu[req_indices],
+                pcp_positions[:total_num_scheduled_tokens],
+                out=positions_np,
+            )
+
         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -1153,9 +1406,6 @@ def _prepare_inputs(
                 output_idx += num_sched

-        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
-
         # Prepare the attention metadata.
         self.query_start_loc.np[0] = 0
         self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
@@ -1201,7 +1451,10 @@ def _prepare_inputs(

         # Record the index of requests that should not be sampled,
         # so that we could clear the sampled tokens before returning
-        discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
+        discard_requests_mask = (
+            self.input_batch.num_computed_tokens_cpu[:num_reqs]
+            + np.array(tokens, dtype=np.int32)
+        ) < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[: self.num_discarded_requests] = (
@@ -1230,10 +1483,18 @@ def _prepare_inputs(
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = query_start_loc[1:] - 1
+            if self.pcp_world_size > 1:
+                logits_indices = (
+                    torch.from_numpy(cu_num_tokens) * self.pcp_world_size
+                    - self.num_pcp_pads_cpu_tensor[:num_reqs]
+                    - 1
+                )
+            else:
+                logits_indices = query_start_loc[1:] - 1
             num_draft_tokens = None
             spec_decode_metadata = None
         else:
+            assert self.pcp_world_size == 1, "PCP does not support spec decode yet"
             # Get the number of draft tokens for each request.
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
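The docstring example of `_update_tokens_for_pcp` also explains the PCP branch of `logits_indices` above. A small NumPy sketch (names are illustrative; request 0 is assumed to be a decode) ties the two together: after the all-gather and restore, each request's padding sits at the tail of its padded region, so the last real token of request i lands at cumsum(per_rank_tokens)[i] * pcp - num_pcp_pads[i] - 1.

import numpy as np

pcp = 2
tokens = np.array([1, 5, 8])        # scheduled tokens per request (req 0 is a decode)
num_decodes = 1
padded = np.ceil(tokens / (2 * pcp)).astype(int) * (2 * pcp)
padded[:num_decodes] = tokens[:num_decodes] * pcp   # decodes are duplicated, not split
num_pcp_pads = padded - tokens                      # [1, 3, 0]
per_rank_tokens = padded // pcp                     # [1, 4, 4]

logits_indices = np.cumsum(per_rank_tokens) * pcp - num_pcp_pads - 1
print(per_rank_tokens, num_pcp_pads, logits_indices)  # [1 4 4] [1 3 0] [0 6 17]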
@@ -1299,6 +1560,12 @@ def _prepare_inputs(
                 scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs
             )

+            slot_mapping_size = (
+                total_num_scheduled_tokens
+                if self.pcp_world_size == 1
+                else total_num_scheduled_tokens * self.pcp_world_size
+                - total_num_pcp_pads
+            )
             if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
                 # Encoder-only layers do not have KV cache, so we need to
                 # create a dummy block table and slot mapping for them.
@@ -1308,7 +1575,7 @@
                     device=self.device,
                 )
                 slot_mapping = torch.zeros(
-                    (total_num_scheduled_tokens,),
+                    (slot_mapping_size,),
                     dtype=torch.int64,
                     device=self.device,
                 )
@@ -1316,15 +1583,29 @@
             else:
                 blk_table = self.input_batch.block_table[kv_cache_group_id]
                 blk_table_tensor = blk_table.get_device_tensor(num_reqs)
-                slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
+
+                slot_mapping = blk_table.slot_mapping.gpu[:slot_mapping_size]
                 # Fill unused with -1. Needed for reshape_and_cache in full cuda
                 # graph mode.
-                blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
+                blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(-1)

             num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[
                 kv_cache_group_id
             ]
+            if self.pcp_world_size > 1:
+                # After the pcp allgather and restore, the kv contains padded
+                # tokens, so the slot mapping must be padded to stay aligned.
+                pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[
+                    : total_num_scheduled_tokens * self.pcp_world_size
+                ]
+                cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[
+                    : total_num_scheduled_tokens * self.pcp_world_size
+                ]
+                pcp_padded_slot_mapping.fill_(-1)
+                pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping
+                slot_mapping = pcp_padded_slot_mapping
+
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
                 query_start_loc_cpu=query_start_loc_cpu,
@@ -1341,6 +1622,13 @@ def _prepare_inputs(
                 num_logits_indices=logits_indices.size(0),
                 causal=True,
                 encoder_seq_lens=encoder_seq_lens,
+                query_positions=positions_np,
+                pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[
+                    : total_num_scheduled_tokens * self.pcp_world_size
+                ]
+                if self.pcp_world_size > 1
+                else None,
+                pcp_metadata=pcp_metadata,
                 dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
                 if self.dcp_world_size > 1
                 else None,
@@ -2447,6 +2735,12 @@ def execute_model(
             self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
         elif num_tokens_across_dp is not None:
             num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
+        elif self.pcp_world_size > 1:
+            # NOTE(qcs): For PCP, we pad num_scheduled_tokens_np but
+            # do not update total_num_scheduled_tokens in scheduler_output.
+            num_input_tokens = self._get_num_input_tokens(
+                num_scheduled_tokens_np.sum()
+            )
         else:
             num_input_tokens = self._get_num_input_tokens(
                 scheduler_output.total_num_scheduled_tokens
@@ -2517,6 +2811,18 @@ def execute_model(
             hidden_states = model_output
             aux_hidden_states = None

+        if self.pcp_world_size > 1:
+            # NOTE: slice hidden_states first because pcp_allgather_restore_idx
+            # does not account for the padding added by CUDA graph capture.
+            hidden_states = get_pcp_group().all_gather(
+                hidden_states[:num_scheduled_tokens_np.sum()],
+                0,
+            )
+            hidden_states = torch.index_select(
+                hidden_states,
+                0,
+                self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]],
+            )
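A tiny sketch of the slot-mapping padding used above: start from a buffer of -1 sized for the all-gathered (padded) token stream and scatter the real slots through the unpad mask. The mask and slot values here are made up for illustration only.

import torch

unpad_mask = torch.tensor([True, False, True, True, True, True, True, False])
real_slots = torch.tensor([40, 12, 13, 14, 15, 16])  # one KV slot per real token
padded_slot_mapping = torch.full((unpad_mask.numel(),), -1, dtype=torch.int64)
padded_slot_mapping[unpad_mask] = real_slots         # pads keep -1 (skipped by reshape_and_cache)
print(padded_slot_mapping)                           # tensor([40, -1, 12, 13, 14, 15, 16, -1])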
         if not self.broadcast_pp_output:
             # Common case.
             if not get_pp_group().is_last_rank:
@@ -3307,10 +3613,19 @@ def _dummy_run(
         self.seq_lens.np[num_reqs:] = 0
         self.seq_lens.copy_to_gpu()

-        cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
+        cum_num_tokens, query_positions = self._get_cumsum_and_arange(
+            num_scheduled_tokens
+        )
         self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
         self.query_start_loc.copy_to_gpu()
-
+
+        pcp_metadata = None
+        if self.pcp_world_size > 1:
+            num_decode_reqs = sum(num_scheduled_tokens == 1)
+            pcp_metadata = self._get_pcp_metadata(
+                num_scheduled_tokens[num_decode_reqs:],
+                num_scheduled_tokens[num_decode_reqs:] * 2,
+            )
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
             self.kv_cache_config.kv_cache_groups
         ):
@@ -3333,6 +3648,12 @@ def _dummy_run(
                     kv_cache_group_id
                 ].slot_mapping.gpu[:num_tokens],
                 causal=True,
+                query_positions=query_positions,
+                pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None,
+                pcp_allgather_restore_idx=self.pcp_allgather_restore_idx.gpu[
+                    : total_num_scheduled_tokens * self.pcp_world_size
+                ] if self.pcp_world_size > 1
+                else None,
                 dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
                 if self.dcp_world_size > 1
                 else None,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 00dc7682c973..964ab774e6ec 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -776,6 +776,7 @@ def init_worker_distributed_environment(
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size,
+        parallel_config.prefill_context_parallel_size,
         parallel_config.decode_context_parallel_size,
     )