diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 0bd7e80f544c..5b8c9fe5c32e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -264,6 +264,9 @@ class AttentionImpl(ABC, Generic[T]): dcp_world_size: int dcp_rank: int + cp_world_size: int + cp_rank: int + def __new__(cls, *args, **kwargs): # use __new__ so that all subclasses will call this self = super().__new__(cls) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 8fe8f3053e35..1a1b1fea1933 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm.config import ParallelConfig -from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank +from vllm.distributed import (get_context_model_parallel_rank, get_dp_group, + get_tensor_model_parallel_rank) from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) @@ -604,7 +605,8 @@ def make(tp_size_: int, dp_size_: int, cp_size_: int, level's of parallelism to use in the fused moe layer. Args: - tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into the FusedMoE constructor. + tp_size_ (int): `tp_size` pa use_ep = (dp_size_ * tp_size_ssed into + the FusedMoE constructor. dp_size_ (int): `dp_size` passed into the FusedMoE constructor. vllm_parallel_config (ParallelConfig): vLLM's parallel config object which contains the `enable_expert_parallel` flag. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 22aefc8e1617..217a2e2c2f8a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,9 +13,9 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy -from vllm.distributed import (get_dp_group, get_ep_group, +from vllm.distributed import (get_context_model_parallel_world_size, + get_dp_group, get_ep_group, get_tensor_model_parallel_world_size, - get_context_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 05f129f513a0..39f7a13786d7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -192,6 +192,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and parallel_config.context_parallel_size > 1): + logger.info( + "Context Parallel: disabling cudagraphs since CP." 
+ ) + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + @classmethod def get_current_memory_usage(cls, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4bf47c8f3e08..f00829c69867 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -21,8 +21,8 @@ AttentionType) from vllm.attention.ops.common import cp_lse_ag_out_ar from vllm.config import CUDAGraphMode, VllmConfig -from vllm.logger import init_logger from vllm.distributed.parallel_state import get_cp_group +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.platforms import current_platform @@ -239,7 +239,7 @@ class FlashInferMetadata: paged_kv_indptr_gpu: Optional[torch.Tensor] = None # For context parallel - cp_kv_recover_idx: Optional[torch.Tensor] = None + cp_allgather_restore_idx: Optional[torch.Tensor] = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -262,9 +262,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.kv_cache_spec.block_size) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - # NOTE(qcs): Context Parallel do not support graph mode now self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1) + decode_mode() == CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. @@ -552,7 +551,7 @@ def build(self, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx, + cp_allgather_restore_idx=common_attn_metadata.cp_allgather_restore_idx, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu @@ -599,38 +598,30 @@ def build(self, qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ prefill_start] paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] - prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:] + prefill_num_computed_tokens_cpu = \ + num_computed_tokens_cpu[prefill_start:] if not attn_metadata.prefill_use_trtllm: if self.cp_world_size > 1: - # NOTE(qcs): no chunked prefill and prefix caching + assert common_attn_metadata.query_positions is not None kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size # init custom mask for head-tail query order - mask_arr = [] - q_pos = common_attn_metadata.query_positions - for i in range(num_prefills): - # |---------|--|--| - # |-------| - # cp_world_size = 2 - # Q = 2 - # C = 8 - # cur_q_pos = [0,3] - # context_mask_i.shape = (2, 8) - # upper = [0,1,2,3] - # local_mask_i = [[True, False, False, False], - # [True, True, True, True]] # size=(2, 4) - # mask_i.shape = (2, 12) - cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]]) - Q = len(cur_q_pos) - C = prefill_num_computed_tokens_cpu[i] - if Q <= 0: - mask_arr.append(torch.zeros(0, dtype=torch.bool)) - continue - context_mask_i = torch.ones((Q, C), dtype=torch.bool) - upper = torch.arange(Q*self.cp_world_size) - local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1)) - mask_i = torch.cat([context_mask_i, local_mask_i], dim=1) - mask_arr.append(mask_i.flatten()) - custom_mask = torch.cat(mask_arr, dim=0).to(self.device) + q_pos = torch.from_numpy( + 
common_attn_metadata.query_positions[ + prefill_start:]).long() + kv_lens = prefill_num_computed_tokens_cpu + \ + kv_indptr_cpu[1:] - kv_indptr_cpu[:-1] + max_q_lens = int(q_pos.max().item()) + 1 + max_kv_lens = int(kv_lens.max().item()) + mask = torch.ones(max_q_lens, max_kv_lens, + dtype=torch.bool).tril() + selected_rows = torch.index_select(mask, 0, q_pos) + col_indices = torch.arange(max_kv_lens).expand(q_pos.size(0), -1) + valid_mask = col_indices < torch.repeat_interleave( + kv_lens, + qo_indptr_cpu[1:] - \ + qo_indptr_cpu[:-1] + ).unsqueeze(1) + custom_mask = selected_rows[valid_mask].to(self.device) attn_metadata.prefill_wrapper.plan( qo_indptr_cpu.to(self.device), @@ -874,6 +865,28 @@ def forward( # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + num_prefill_tokens = attn_metadata.num_prefill_tokens + + key_across_cp = get_cp_group().all_gather( + key.contiguous(), dim=0) + value_across_cp = get_cp_group().all_gather( + value.contiguous(), dim=0) + if (self.cp_world_size > 1 + and attn_metadata.cp_allgather_restore_idx is not None): + # Reorder kv after cp allgather. + # Note that there are duplicate decoding tokens, + # but we only save the first one in kvcache. + key_across_cp = torch.index_select( + key_across_cp, 0, + attn_metadata.cp_allgather_restore_idx + ) + value_across_cp = torch.index_select( + value_across_cp, 0, + attn_metadata.cp_allgather_restore_idx + ) + key = key_across_cp + value = value_across_cp if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. @@ -883,17 +896,16 @@ def forward( # and value[:num_actual_tokens] because the reshape_and_cache_flash # op uses the slot_mapping's shape to determine the number of # actual tokens. - if self.cp_world_size == 1: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # to process the cache when the kv_cache_dtype is fp8 @@ -913,9 +925,6 @@ def forward( output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) return output - num_decode_tokens = attn_metadata.num_decode_tokens - num_prefill_tokens = attn_metadata.num_prefill_tokens - stride_order = FlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) # Regular attention (common case). 
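
Note on the custom prefill mask built in the hunk above: it replaces the old per-request Python loop with a vectorized construction. Each query token carries its global position in the head/tail-ordered padded sequence, so a single lower-triangular table indexed by those positions yields the causal rows, and a per-request column mask trims every row to that request's KV length before flattening into the ragged mask layout that FlashInfer's prefill plan() consumes. A minimal standalone sketch of that computation (build_cp_prefill_custom_mask is an illustrative name, not a helper in this patch):

import torch

def build_cp_prefill_custom_mask(q_pos: torch.Tensor,
                                 query_lens: torch.Tensor,
                                 kv_lens: torch.Tensor) -> torch.Tensor:
    # q_pos      - per-token positions in the padded sequence (head/tail order)
    # query_lens - number of query tokens per prefill request on this CP rank
    # kv_lens    - total KV length per request after the CP all-gather
    max_q = int(q_pos.max().item()) + 1
    max_kv = int(kv_lens.max().item())
    # Dense causal mask over (position, kv index).
    causal = torch.ones(max_q, max_kv, dtype=torch.bool).tril()
    # One mask row per query token, selected by its global position.
    rows = torch.index_select(causal, 0, q_pos)
    # Keep only the columns that exist for each token's request, then
    # flatten to the ragged layout FlashInfer expects.
    cols = torch.arange(max_kv).expand(q_pos.size(0), -1)
    valid = cols < torch.repeat_interleave(kv_lens, query_lens).unsqueeze(1)
    return rows[valid]

# Example: one request, cp_world_size = 2, rank 0 holds query positions
# [0, 3] of a padded length-4 query; the request has 4 KV tokens after
# the all-gather.
mask = build_cp_prefill_custom_mask(
    q_pos=torch.tensor([0, 3]),
    query_lens=torch.tensor([2]),
    kv_lens=torch.tensor([4]))
# -> tensor([True, False, False, False, True, True, True, True])
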
@@ -933,34 +942,15 @@ def forward( self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale if self.cp_world_size > 1: - key_across_cp = get_cp_group().all_gather( - key[num_decode_tokens:].contiguous(), dim=0) - value_across_cp = get_cp_group().all_gather( - value[num_decode_tokens:].contiguous(), dim=0) - key_across_cp = torch.index_select( - key_across_cp, 0, - attn_metadata.cp_kv_recover_idx - ) - value_across_cp = torch.index_select( - value_across_cp, 0, - attn_metadata.cp_kv_recover_idx - ) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key_across_cp, - value_across_cp, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping[num_decode_tokens:], - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - # TODO(qcs): 考虑 chunked prefill/ prefix cache 情况下 - # kvcache的获取与拼接 + # NOTE(qcs): Allgather causes duplicate decoding tokens. + prefill_key = key[ + num_decode_tokens*self.cp_world_size:] + prefill_value = value[ + num_decode_tokens*self.cp_world_size:] prefill_wrapper.run( prefill_query, - key_across_cp, - value_across_cp, + prefill_key, + prefill_value, out=output[num_decode_tokens:], ) else: @@ -1047,17 +1037,6 @@ def forward( or 0.0) assert decode_wrapper._sm_scale == self.scale if self.cp_world_size > 1: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_decode_tokens], - value[:num_decode_tokens], - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping[:num_decode_tokens], - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - kv_cache_permute = kv_cache.permute(*stride_order) out, lse = decode_wrapper.run( decode_query, kv_cache_permute, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index ff4e10e82edd..df7e1ab9792d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -83,7 +83,7 @@ class CommonAttentionMetadata: # Needed by custom mask calc for context parallelism query_positions: Optional[np.ndarray] = None - cp_kv_recover_idx: Optional[torch.Tensor] = None + cp_allgather_restore_idx: Optional[torch.Tensor] = None def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -139,10 +139,13 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + # TODO(qcs): check if we can split query_positions and + # cp_kv_recover_idx as following approach query_positions = attn_metadata.query_positions[token_slice] \ if attn_metadata.query_positions is not None else None - cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx[token_slice] \ - if attn_metadata.cp_kv_recover_idx is not None else None + cp_allgather_restore_idx = attn_metadata.cp_allgather_restore_idx[ + token_slice] if attn_metadata.cp_allgather_restore_idx is not None \ + else None return CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -157,7 +160,7 @@ def _make_metadata_with_slice( block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, query_positions=query_positions, - cp_kv_recover_idx=cp_kv_recover_idx, + cp_allgather_restore_idx=cp_allgather_restore_idx, ) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e728dbb96272..c3c4b4b9be1b 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -28,9 +28,9 @@ destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) -from vllm.distributed.parallel_state import 
(get_dp_group, get_ep_group, - get_pp_group, get_tp_group, - get_cp_group) +from vllm.distributed.parallel_state import (get_cp_group, get_dp_group, + get_ep_group, get_pp_group, + get_tp_group) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config @@ -64,7 +64,8 @@ def _init_executor(self) -> None: tensor_parallel_size = self.parallel_config.tensor_parallel_size pp_parallel_size = self.parallel_config.pipeline_parallel_size context_parallel_size = self.parallel_config.context_parallel_size - assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, ( + assert self.world_size == tensor_parallel_size * pp_parallel_size * \ + context_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" f"_parallel_size ({pp_parallel_size}) x context" @@ -345,7 +346,8 @@ def _get_output_rank(self) -> int: # 16-23, PP rank 2 # 24-31, PP rank 3 # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3) - return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size + return self.world_size - self.parallel_config.tensor_parallel_size * \ + self.parallel_config.context_parallel_size @dataclass diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index d25cb699d346..ac722a332503 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -5,7 +5,7 @@ import numpy as np import torch -from vllm.distributed import get_dcp_group, get_cp_group +from vllm.distributed import get_cp_group, get_dcp_group from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.utils import CpuGpuBuffer @@ -92,18 +92,21 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. - virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size + virtual_block_size = self.block_size * self.dcp_world_size * \ + self.cp_world_size block_table_indices = (req_indices * self.max_num_blocks_per_req + positions // virtual_block_size) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. virtual_block_offsets = positions % virtual_block_size - self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank - mask = (virtual_block_offsets % - (self.dcp_world_size * self.cp_world_size) == self.current_rank) + self.current_rank = self.dcp_world_size * self.cp_rank + \ + self.dcp_rank + mask = (virtual_block_offsets % (self.dcp_world_size * \ + self.cp_world_size) == self.current_rank) # Calculate local block_offsets - block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size) + block_offsets = virtual_block_offsets // \ + (self.dcp_world_size * self.cp_world_size) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local @@ -147,8 +150,12 @@ def _make_buffer(self, *size: Union[int, torch.SymInt], device=self.device, pin_memory=self.pin_memory) - def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]: - "Splits computed token counts across dcp and sp dimensions for distributed allocation." 
+ def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) \ + -> list[list[list[int]]]: + """ + Splits computed token counts across dcp and sp dimensions for + distributed allocation. + """ num_requests = len(num_computed_tokens) num_computed_tokens_of_dcp_sp = [[ [0] * self.dcp_world_size for _ in range(self.cp_world_size) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 757cc5e7fccc..887f87263aba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast, List +from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Union, cast import numpy as np import torch @@ -31,8 +31,8 @@ has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, get_dcp_group, get_cp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) + get_cp_group, get_pp_group, get_tp_group, graph_capture, + is_global_first_rank, prepare_communication_buffer_for_model) from vllm.forward_context import (BatchDescriptor, DPMetadata, set_forward_context) from vllm.logger import init_logger @@ -55,10 +55,10 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, check_use_alibi, get_dtype_size, + GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo, cdiv) + supports_dynamo) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -359,6 +359,24 @@ def __init__( self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + # Persistent buffers for Context Parallism + self.cp_allgather_restore_idx = self._make_buffer(self.max_num_tokens, + dtype=torch.int64) + self.cp_padded_slot_mapping = torch.empty((self.max_num_tokens, ), + dtype=torch.int64, + device=self.device,) + self.num_cp_pads_cpu_tensor = torch.zeros((self.max_num_reqs, ), + device="cpu", + dtype=torch.int64, + pin_memory=True) + self.num_cp_pads_cpu = self.num_cp_pads_cpu_tensor.numpy() + self.cp_unpad_mask_cpu_tensor = torch.zeros((self.max_num_tokens, ), + device="cpu", + dtype=torch.bool, + pin_memory=True) + self.cp_unpad_mask_cpu = self.cp_unpad_mask_cpu_tensor.numpy() + + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: # NOTE: `mrope_positions` is implemented with one additional dummy @@ -797,90 +815,91 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) - def _num_scheduled_tokens_prefill_cp(self, num_tokens, - num_computed_tokens, - cp_kv_recover_idx): - num_scheduled_tokens = num_tokens - num_computed_tokens - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, 2 * self.cp_world_size) * (2 * self.cp_world_size - ) # pad to 2*cp_world_size - cp_pad = num_cp_padded_scheduled_tokens - num_scheduled_tokens - full_indices = list( - range(self.max_num_tokens * 
self.cp_world_size * self.dcp_world_size + - self.cp_world_size * self.dcp_world_size * self.max_num_reqs)) - chunk_size = num_cp_padded_scheduled_tokens // (2 * self.cp_world_size) - - # split position_ids (and use split position_ids to split input_ids afterwards) - req_position_cp = [] - req_position_cp.extend( - full_indices[self.cp_rank * chunk_size:(self.cp_rank + 1) * - chunk_size]) - req_position_cp.extend( - full_indices[num_cp_padded_scheduled_tokens - (self.cp_rank + 1) * - chunk_size:num_cp_padded_scheduled_tokens - - self.cp_rank * chunk_size]) - - # used to recover kv order in cp prefill (after all-gather kv and before storing kv_cache) - num_added_recover_tokens = len(cp_kv_recover_idx[0]) * self.cp_world_size - for rank in range(self.cp_world_size): - cp_kv_recover_idx[rank].extend( - full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * chunk_size + - num_added_recover_tokens]) - cp_kv_recover_idx[rank].extend(full_indices[ - num_cp_padded_scheduled_tokens - (rank + 1) * chunk_size + - num_added_recover_tokens:num_cp_padded_scheduled_tokens - - rank * chunk_size + num_added_recover_tokens]) - - return req_position_cp, num_cp_padded_scheduled_tokens, cp_pad - - def _update_tokens_for_cp(self, tokens, scheduler_output: "SchedulerOutput"): - if not self.cp_world_size > 1: - return tokens - num_reqs = self.input_batch.num_reqs - self.num_cp_pads = np.empty(num_reqs, dtype=np.int32) - self.cp_kv_recover_idx: List[List[int]] = [[] - for _ in range(self.cp_world_size) - ] - self.position_cp = np.zeros(self.max_num_tokens, dtype=np.int32) - start_index = 0 - - for i, req_id in enumerate(self.input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - is_prefill = num_tokens > 1 # todo: compare num prompt tokens and num sch tokens + computed tokens - if is_prefill: - # when cp > 1 & prefill, need to pad & split sequence here - req_position_cp, num_cp_padded_scheduled_tokens, self.num_cp_pads[ - i] = self._num_scheduled_tokens_prefill_cp( - num_tokens, - self.input_batch.num_computed_tokens_cpu[i], - self.cp_kv_recover_idx) - num_tokens = len(req_position_cp) - self.position_cp[start_index:start_index + - num_tokens] = req_position_cp - start_index += num_tokens - tokens[i] = num_tokens - else: - self.num_cp_pads[i] = 0 - self.position_cp[start_index:start_index + - num_tokens] = [idx for idx in range(num_tokens)] - start_index += num_tokens - for rank in range(len(self.cp_kv_recover_idx)): - self.cp_kv_recover_idx[rank].append(rank) - return tokens - - def _update_logits_indices_for_cp(self, cu_num_tokens, scheduler_output: "SchedulerOutput"): - # todo: find a better way to get is_prefill - is_prefill = list( - scheduler_output.num_scheduled_tokens.values())[0] > 1 + def _update_tokens_for_cp(self, tokens): + """ + If context parallelism is enabled, we will calculate + the number of tokens `tokens` after sequence splitting. + Meanwhile, we will compute: + `positions` the new token positions, + `num_cp_pads` the number of padding tokens per request for alignment, + `unpad_mask` the mask for non-padded tokens, + `cp_allgather_restore_idx` indices to restore the original vector + order after CP allgather. 
+ Example: + >>> tokens = [1, 5, 8] + >>> cp_world_size = 2 + >>> cp_rank = 0 + >>> _update_tokens_for_cp(tokens) + ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5], [1, 3, 0], [True, False, + True, True, True, True, True, False, False, False, True, True, + True, True, True, True, True, True], [0, 9, 1, 2, 10, 11, 12, 13, + 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]) + >>> cp_rank = 1 + >>> _update_tokens_for_cp(tokens) + ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7], [1, 3, 0], [True, False, + True, True, True, True, True, False, False, False, True, True, + True, True, True, True, True, True], [0, 9, 1, 2, 10, 11, 12, 13, + 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]) + """ num_reqs = self.input_batch.num_reqs - if self.cp_world_size > 1 and is_prefill: - # logits_indices = cu_num_tokens - num_cp_pads[:num_reqs] - 1 # if without all-gather and only sample on cp0 - logits_indices = cu_num_tokens * self.cp_world_size \ - - torch.tensor(self.num_cp_pads[:num_reqs]).to(cu_num_tokens) - 1 - else: - logits_indices = cu_num_tokens - 1 - return logits_indices + self.num_cp_pads_cpu[:num_reqs] = 0 + if not self.cp_world_size > 1: + return tokens, None + + num_decode_reqs = sum(self.input_batch.num_computed_tokens_cpu[ + :num_reqs] >= self.input_batch.num_prompt_tokens[:num_reqs]) + + num_padded_scheduled_tokens = np.ceil( + tokens / (2 * self.cp_world_size) + ).astype(np.int32) * (2 * self.cp_world_size) + # we align scheduled tokens of decode reqs to cp_world_size instead + # of 2*cp_world_size + num_padded_scheduled_tokens[:num_decode_reqs] = self.cp_world_size + self.num_cp_pads_cpu[:num_reqs] = num_padded_scheduled_tokens - tokens + cu_padded_tokens, cp_padded_arange = \ + self._get_cumsum_and_arange(num_padded_scheduled_tokens) + self.cp_unpad_mask_cpu[:cp_padded_arange.shape[0]] = \ + cp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens) + + cp_tokens = num_padded_scheduled_tokens // self.cp_world_size + cp_chunk_sizes = (cp_tokens // 2).clip(min=1) + _, cp_arange = self._get_cumsum_and_arange(cp_tokens) + _, cp_chunk_arange = self._get_cumsum_and_arange(cp_chunk_sizes) + cp_head_chunk_mask = cp_arange < np.repeat(cp_chunk_sizes, + cp_tokens) + + + def get_current_rank_positions( + positions_start_loc: Union[int, np.ndarray], + rank: int + ): + positions = np.zeros(len(cp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * cp_chunk_sizes + tail_start_loc = positions_start_loc + \ + (2 * self.cp_world_size - rank - 1) * cp_chunk_sizes + positions[cp_head_chunk_mask] = cp_chunk_arange + \ + np.repeat(head_start_loc, cp_chunk_sizes) + # Decode reqs do not have tail chunks. + positions[~cp_head_chunk_mask] = \ + cp_chunk_arange[num_decode_reqs:] + \ + np.repeat(tail_start_loc, cp_chunk_sizes)[num_decode_reqs:] + return positions + + positions = get_current_rank_positions(0, self.cp_rank) + # Decode tokens are duplicate and their positions always be 0. + positions[:num_decode_reqs] = 0 + + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [get_current_rank_positions(padded_pos_start_loc, + rank_i) + for rank_i in range(self.cp_world_size)] + all_positions = np.concatenate(all_positions_lst) + self.cp_allgather_restore_idx.np[:all_positions.shape[0]] = \ + all_positions.argsort() + self.cp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + return cp_tokens, positions + def _get_cumsum_and_arange( self, @@ -1016,10 +1035,12 @@ def _prepare_inputs( # Get the number of scheduled tokens for each request. 
req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - total_num_scheduled_tokens_for_slotmapping = total_num_scheduled_tokens + # NOTE(qcs): we need compute slotmapping for all kv + # instead of sliced sequences + total_num_scheduled_tokens4sltmap = total_num_scheduled_tokens original_num_scheduled_tokens = np.array(tokens, dtype=np.int32) - tokens = self._update_tokens_for_cp(tokens, scheduler_output) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) + num_scheduled_tokens, positions_cp = self._update_tokens_for_cp( + original_num_scheduled_tokens) # update total_num_scheduled_tokens total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) max_num_scheduled_tokens = max(tokens) @@ -1037,17 +1058,17 @@ def _prepare_inputs( # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] if self.cp_world_size > 1: + assert positions_cp is not None req_indices_for_slotmapping = np.repeat(self.arange_np[:num_reqs], original_num_scheduled_tokens) _, original_arange = self._get_cumsum_and_arange( original_num_scheduled_tokens) - positions_np_for_slotmapping = self.positions.np[ - :total_num_scheduled_tokens_for_slotmapping].copy() + positions_np_for_slotmapping = \ np.add(self.input_batch.num_computed_tokens_cpu[req_indices_for_slotmapping], original_arange, - out=positions_np_for_slotmapping) + ) np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - self.position_cp[:total_num_scheduled_tokens], + positions_cp[:total_num_scheduled_tokens], out=positions_np) else: np.add(self.input_batch.num_computed_tokens_cpu[req_indices], @@ -1125,7 +1146,7 @@ def _prepare_inputs( self.input_batch.block_table.compute_slot_mapping( req_indices_for_slotmapping, positions_np_for_slotmapping) self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens_for_slotmapping) + total_num_scheduled_tokens4sltmap) # Prepare the attention metadata. self.query_start_loc.np[0] = 0 @@ -1191,10 +1212,8 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = self._update_logits_indices_for_cp( - query_start_loc[1:], - scheduler_output - ) + logits_indices = torch.from_numpy(cu_num_tokens) * \ + self.cp_world_size - self.num_cp_pads_cpu_tensor[:num_reqs] - 1 num_draft_tokens = None spec_decode_metadata = None else: @@ -1236,21 +1255,6 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.cp_world_size > 1: - # Prepare the metadata for Context Parallel - total_num_scheduled_tokens_for_slotmapping = sum(original_num_scheduled_tokens[:num_reqs]) - - total_prefill_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - cp_kv_recover_idx = torch.zeros(total_prefill_num_scheduled_tokens * self.cp_world_size, - dtype=torch.int32, - device=self.device) - cp_kv_recover_idx.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx = cp_kv_recover_idx.to( - torch.float32).argsort().to(torch.int32) - else: - self.cp_kv_recover_idx = None # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -1268,7 +1272,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens_for_slotmapping, ), + (total_num_scheduled_tokens4sltmap, ), dtype=torch.int64, device=self.device, ) @@ -1277,16 +1281,26 @@ def _prepare_inputs( blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens_for_slotmapping] + total_num_scheduled_tokens4sltmap] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens_for_slotmapping:].fill_( - -1) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens4sltmap: + ].fill_(-1) num_common_prefix_blocks = ( scheduler_output. num_common_prefix_blocks[kv_cache_group_id]) + if self.cp_world_size > 1: + # After cp allgather and restore, there are padded tokens in + # kv, so we need pad slotmapping for alignment. + cp_padded_slot_mapping = self.cp_padded_slot_mapping[ + :total_num_scheduled_tokens*self.cp_world_size] + cp_unpad_mask = self.cp_unpad_mask_cpu_tensor[ + :total_num_scheduled_tokens*self.cp_world_size] + cp_padded_slot_mapping.fill_(-1) + cp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + slot_mapping = cp_padded_slot_mapping common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1304,7 +1318,8 @@ def _prepare_inputs( causal=True, encoder_seq_lens=encoder_seq_lens, query_positions=positions_np, - cp_kv_recover_idx=self.cp_kv_recover_idx, + cp_allgather_restore_idx=self.cp_allgather_restore_idx.gpu[ + :total_num_scheduled_tokens*self.cp_world_size], ) if self.speculative_config and \ @@ -1970,11 +1985,6 @@ def _pool( ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - cp_size = self.vllm_config.parallel_config.context_parallel_size - if cp_size > 1: - # TODO(qcs): When ContextParallel is adapted to GraphMode, - # revise this length alignment strategy again. 
- return cdiv(num_scheduled_tokens, self.cp_world_size * 2) * 2 if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH and hasattr(self, "cudagraph_batch_sizes") @@ -1999,6 +2009,7 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, + num_scheduled_tokens_after_cp: Optional[int] = None, ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, Optional[IntermediateTensors], dict[str, Any]]: @@ -2009,10 +2020,17 @@ def _preprocess( num_input_tokens = int(num_tokens_after_padding[0].item() * 2) self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: - num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + if self.cp_world_size == 1: + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) + else: + assert num_scheduled_tokens_after_cp is not None + num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens_after_cp) num_pad, num_tokens_after_padding = self.get_dp_padding( num_input_tokens) num_input_tokens += num_pad + else: + raise RuntimeError(f"Unreachable branch, please check the value " + f"of ubatch_slikces({ubatch_slices}).") # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -2316,7 +2334,8 @@ def execute_model( intermediate_tensors, model_kwargs, ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) + ubatch_slices, num_tokens_after_padding, + sum(num_scheduled_tokens_np)) if ubatch_slices is not None: num_input_tokens = num_input_tokens // 2 @@ -2361,13 +2380,11 @@ def execute_model( aux_hidden_states = None if self.cp_world_size > 1: - if isinstance(attn_metadata, dict): - cp_kv_recover_idx = list(attn_metadata.values())[0].cp_kv_recover_idx - else: - cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx hidden_states = get_cp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, cp_kv_recover_idx) + hidden_states, 0, self.cp_allgather_restore_idx.gpu[ + :hidden_states.shape[0] + ]) if not self.broadcast_pp_output: # Common case. if not get_pp_group().is_last_rank:
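
End-to-end, the runner-side changes above implement context parallelism as follows: each prefill is padded to a multiple of 2 * cp_world_size and split into head/tail chunks so every CP rank gets a balanced mix of early and late positions, and cp_allgather_restore_idx is the argsort of the concatenated per-rank positions, which puts the all-gathered KV (and, at the end of execute_model, the all-gathered hidden states) back into padded sequence order. Below is a minimal NumPy sketch of the split and the restore index under that padding rule; cp_head_tail_split and cp_restore_idx are illustrative names rather than functions from this patch, and decode requests (which the patch pads to cp_world_size and duplicates across ranks) are not modeled here.

import numpy as np

def cp_head_tail_split(num_tokens: int, cp_world_size: int, cp_rank: int):
    # Pad the request to a multiple of 2 * cp_world_size, cut it into
    # 2 * cp_world_size equal chunks, and give rank r chunk r (head) plus
    # chunk (2 * cp_world_size - 1 - r) (tail) so early and late positions
    # are balanced across ranks.
    padded = -(-num_tokens // (2 * cp_world_size)) * (2 * cp_world_size)
    chunk = padded // (2 * cp_world_size)
    head = np.arange(cp_rank * chunk, (cp_rank + 1) * chunk)
    tail = np.arange(padded - (cp_rank + 1) * chunk, padded - cp_rank * chunk)
    return np.concatenate([head, tail])  # local token positions on this rank

def cp_restore_idx(num_tokens: int, cp_world_size: int):
    # Concatenating every rank's local positions mimics the layout produced
    # by the CP all-gather; argsorting those positions recovers the original
    # (padded) order, which is what cp_allgather_restore_idx stores.
    all_pos = np.concatenate([
        cp_head_tail_split(num_tokens, cp_world_size, r)
        for r in range(cp_world_size)
    ])
    return np.argsort(all_pos)

# A 5-token prefill with cp_world_size = 2 pads to 8 tokens (chunk size 2):
print(cp_head_tail_split(5, 2, 0))  # [0 1 6 7]
print(cp_head_tail_split(5, 2, 1))  # [2 3 4 5]
print(cp_restore_idx(5, 2))         # [0 1 4 5 6 7 2 3]
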