3 changes: 3 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -264,6 +264,9 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int

cp_world_size: int
cp_rank: int

def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
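A minimal sketch, not part of the diff, of how a concrete backend could fill in the new annotations; it assumes the get_cp_group() accessor used elsewhere in this PR and a hypothetical init_cp_attrs helper:

from vllm.distributed.parallel_state import get_cp_group

def init_cp_attrs(impl) -> None:
    # Mirrors how dcp_world_size/dcp_rank are usually resolved; the exact
    # wiring inside AttentionImpl.__new__ is collapsed in this diff.
    cp_group = get_cp_group()
    impl.cp_world_size = cp_group.world_size
    impl.cp_rank = cp_group.rank_in_group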
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -7,7 +7,8 @@

import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank
from vllm.distributed import (get_context_model_parallel_rank, get_dp_group,
get_tensor_model_parallel_rank)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
@@ -604,7 +605,8 @@ def make(tp_size_: int, dp_size_: int, cp_size_: int,
levels of parallelism to use in the fused moe layer.

Args:
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
tp_size_ (int): `tp_size` passed into
the FusedMoE constructor.
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vLLM's parallel config
object which contains the `enable_expert_parallel` flag.
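For context, a hedged sketch of the decision make() is documenting: whether the fused MoE layer should run with expert parallelism. The helper name and exact condition below are illustrative; the real check lives in FusedMoEParallelConfig.make and may differ.

def should_use_ep(tp_size_: int, dp_size_: int,
                  enable_expert_parallel: bool) -> bool:
    # Expert parallelism needs more than one rank to spread experts over
    # and an explicit opt-in via the parallel config.
    return dp_size_ * tp_size_ > 1 and enable_expert_parallel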
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -13,9 +13,9 @@
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (get_dp_group, get_ep_group,
from vllm.distributed import (get_context_model_parallel_world_size,
get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size,
get_context_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
8 changes: 8 additions & 0 deletions vllm/platforms/cuda.py
@@ -192,6 +192,14 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"compatible. Set the all_to_all backend to deepep_low_latency "
"to use those kernels instead.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and parallel_config.context_parallel_size > 1):
logger.info(
"Context parallel is enabled; disabling CUDA graphs, "
"which are not yet supported with context parallelism.")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE


@classmethod
def get_current_memory_usage(cls,
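A hedged usage sketch of the new guard: if context parallelism is exposed as an engine argument (the kwarg name below is an assumption, mirroring tensor_parallel_size), CUDA graphs are silently turned off by the check above.

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          tensor_parallel_size=2,
          context_parallel_size=2)  # assumed kwarg; cudagraph_mode becomes NONE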
147 changes: 63 additions & 84 deletions vllm/v1/attention/backends/flashinfer.py
@@ -21,8 +21,8 @@
AttentionType)
from vllm.attention.ops.common import cp_lse_ag_out_ar
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.distributed.parallel_state import get_cp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform
@@ -239,7 +239,7 @@ class FlashInferMetadata:
paged_kv_indptr_gpu: Optional[torch.Tensor] = None

# For context parallel
cp_kv_recover_idx: Optional[torch.Tensor] = None
cp_allgather_restore_idx: Optional[torch.Tensor] = None


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -262,9 +262,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
# NOTE(qcs): Context Parallel do not support graph mode now
self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1)
decode_mode() == CUDAGraphMode.FULL)
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
@@ -552,7 +551,7 @@ def build(self,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx,
cp_allgather_restore_idx=common_attn_metadata.cp_allgather_restore_idx,
)

qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -599,38 +598,30 @@ def build(self,
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:]
prefill_num_computed_tokens_cpu = \
num_computed_tokens_cpu[prefill_start:]
if not attn_metadata.prefill_use_trtllm:
if self.cp_world_size > 1:
# NOTE(qcs): no chunked prefill and prefix caching
assert common_attn_metadata.query_positions is not None
kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size
# init custom mask for head-tail query order
mask_arr = []
q_pos = common_attn_metadata.query_positions
for i in range(num_prefills):
# |----<C>-----|-<Q0>-|-<Q1>-|
# |---<C+Q*cp_world_size>----|
# cp_world_size = 2
# Q = 2
# C = 8
# cur_q_pos = [0,3]
# context_mask_i.shape = (2, 8)
# upper = [0,1,2,3]
# local_mask_i = [[True, False, False, False],
# [True, True, True, True]] # size=(2, 4)
# mask_i.shape = (2, 12)
cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]])
Q = len(cur_q_pos)
C = prefill_num_computed_tokens_cpu[i]
if Q <= 0:
mask_arr.append(torch.zeros(0, dtype=torch.bool))
continue
context_mask_i = torch.ones((Q, C), dtype=torch.bool)
upper = torch.arange(Q*self.cp_world_size)
local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1))
mask_i = torch.cat([context_mask_i, local_mask_i], dim=1)
mask_arr.append(mask_i.flatten())
custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
q_pos = torch.from_numpy(
common_attn_metadata.query_positions[
prefill_start:]).long()
kv_lens = prefill_num_computed_tokens_cpu + \
kv_indptr_cpu[1:] - kv_indptr_cpu[:-1]
max_q_lens = int(q_pos.max().item()) + 1
max_kv_lens = int(kv_lens.max().item())
mask = torch.ones(max_q_lens, max_kv_lens,
dtype=torch.bool).tril()
selected_rows = torch.index_select(mask, 0, q_pos)
col_indices = torch.arange(max_kv_lens).expand(q_pos.size(0), -1)
valid_mask = col_indices < torch.repeat_interleave(
kv_lens,
qo_indptr_cpu[1:] - \
qo_indptr_cpu[:-1]
).unsqueeze(1)
custom_mask = selected_rows[valid_mask].to(self.device)

attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu.to(self.device),
@@ -874,6 +865,28 @@ def forward(
# performance to make sure it does not introduce any overhead.

num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens

key_across_cp = get_cp_group().all_gather(
key.contiguous(), dim=0)
value_across_cp = get_cp_group().all_gather(
value.contiguous(), dim=0)
if (self.cp_world_size > 1
and attn_metadata.cp_allgather_restore_idx is not None):
# Reorder kv after cp allgather.
# Note that there are duplicate decoding tokens,
# but we only save the first one in kvcache.
key_across_cp = torch.index_select(
key_across_cp, 0,
attn_metadata.cp_allgather_restore_idx
)
value_across_cp = torch.index_select(
value_across_cp, 0,
attn_metadata.cp_allgather_restore_idx
)
key = key_across_cp
value = value_across_cp

if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
@@ -883,17 +896,16 @@ def forward(
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
if self.cp_world_size == 1:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
@@ -913,9 +925,6 @@ def forward(
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
return output

num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens

stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order)
# Regular attention (common case).
@@ -933,34 +942,15 @@ def forward(
self.logits_soft_cap or 0.0)
assert prefill_wrapper._sm_scale == self.scale
if self.cp_world_size > 1:
key_across_cp = get_cp_group().all_gather(
key[num_decode_tokens:].contiguous(), dim=0)
value_across_cp = get_cp_group().all_gather(
value[num_decode_tokens:].contiguous(), dim=0)
key_across_cp = torch.index_select(
key_across_cp, 0,
attn_metadata.cp_kv_recover_idx
)
value_across_cp = torch.index_select(
value_across_cp, 0,
attn_metadata.cp_kv_recover_idx
)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key_across_cp,
value_across_cp,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping[num_decode_tokens:],
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# TODO(qcs): handle KV cache fetching and concatenation
# when chunked prefill / prefix caching is enabled
# NOTE(qcs): Allgather causes duplicate decoding tokens.
prefill_key = key[
num_decode_tokens*self.cp_world_size:]
prefill_value = value[
num_decode_tokens*self.cp_world_size:]
prefill_wrapper.run(
prefill_query,
key_across_cp,
value_across_cp,
prefill_key,
prefill_value,
out=output[num_decode_tokens:],
)
else:
@@ -1047,17 +1037,6 @@ def forward(
or 0.0)
assert decode_wrapper._sm_scale == self.scale
if self.cp_world_size > 1:
torch.ops._C_cache_ops.reshape_and_cache_flash(
key[:num_decode_tokens],
value[:num_decode_tokens],
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping[:num_decode_tokens],
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
kv_cache_permute = kv_cache.permute(*stride_order)
out, lse = decode_wrapper.run(
decode_query,
kv_cache_permute,
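The deleted per-request mask loop above is replaced by a vectorized construction. A standalone sketch of the same mechanics, assuming query_positions holds each prefill token's absolute position within its context-plus-query KV sequence:

import torch

def build_cp_prefill_mask(q_pos: torch.Tensor,   # absolute position of each query token
                          q_lens: torch.Tensor,  # query tokens per prefill request
                          kv_lens: torch.Tensor  # total KV length per request
                          ) -> torch.Tensor:
    # One causal template, then gather the row belonging to every query token.
    max_q = int(q_pos.max().item()) + 1
    max_kv = int(kv_lens.max().item())
    causal = torch.ones(max_q, max_kv, dtype=torch.bool).tril()
    rows = torch.index_select(causal, 0, q_pos)
    # Drop columns beyond each request's own KV length and flatten into the
    # ragged layout FlashInfer expects for custom_mask.
    cols = torch.arange(max_kv).expand(q_pos.size(0), -1)
    valid = cols < torch.repeat_interleave(kv_lens, q_lens).unsqueeze(1)
    return rows[valid]

# Toy case from the removed comment: cp_world_size=2, C=8 cached tokens, and
# this rank holds head-tail query positions 0 and 3 -> absolute 8 and 11.
mask = build_cp_prefill_mask(torch.tensor([8, 11]),
                             torch.tensor([2]),
                             torch.tensor([12]))
print(int(mask.sum()))  # 9 + 12 = 21 visible KV entries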
11 changes: 7 additions & 4 deletions vllm/v1/attention/backends/utils.py
@@ -83,7 +83,7 @@ class CommonAttentionMetadata:

# Needed by custom mask calc for context parallelism
query_positions: Optional[np.ndarray] = None
cp_kv_recover_idx: Optional[torch.Tensor] = None
cp_allgather_restore_idx: Optional[torch.Tensor] = None

def slice_query_start_locs(
query_start_loc: torch.Tensor,
@@ -139,10 +139,13 @@ def _make_metadata_with_slice(
block_table_tensor = attn_metadata.block_table_tensor[request_slice]
slot_mapping = attn_metadata.slot_mapping[token_slice]

# TODO(qcs): check if we can split query_positions and
# cp_kv_recover_idx as following approach
query_positions = attn_metadata.query_positions[token_slice] \
if attn_metadata.query_positions is not None else None
cp_kv_recover_idx = attn_metadata.cp_kv_recover_idx[token_slice] \
if attn_metadata.cp_kv_recover_idx is not None else None
cp_allgather_restore_idx = attn_metadata.cp_allgather_restore_idx[
token_slice] if attn_metadata.cp_allgather_restore_idx is not None \
else None

return CommonAttentionMetadata(
query_start_loc=query_start_loc,
@@ -157,7 +160,7 @@ def _make_metadata_with_slice(
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
query_positions=query_positions,
cp_kv_recover_idx=cp_kv_recover_idx,
cp_allgather_restore_idx=cp_allgather_restore_idx,
)


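An illustration, not taken from this diff, of what cp_allgather_restore_idx accomplishes: after get_cp_group().all_gather(x, dim=0) the tokens arrive rank-major (all of rank 0's tokens, then rank 1's, and so on), and the restore index maps them back to the original per-request order. The concrete index is built by the model runner, which is not shown here.

import torch

cp_world_size = 2
# Rank 0 held tokens at positions [0, 3], rank 1 held [1, 2] (head-tail split).
gathered = torch.tensor([0, 3, 1, 2])      # token positions as seen after all_gather
restore_idx = torch.tensor([0, 2, 3, 1])   # hypothetical index from the runner
print(torch.index_select(gathered, 0, restore_idx))  # tensor([0, 1, 2, 3])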
12 changes: 7 additions & 5 deletions vllm/v1/executor/multiproc_executor.py
@@ -28,9 +28,9 @@
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_pp_group, get_tp_group,
get_cp_group)
from vllm.distributed.parallel_state import (get_cp_group, get_dp_group,
get_ep_group, get_pp_group,
get_tp_group)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
@@ -64,7 +64,8 @@ def _init_executor(self) -> None:
tensor_parallel_size = self.parallel_config.tensor_parallel_size
pp_parallel_size = self.parallel_config.pipeline_parallel_size
context_parallel_size = self.parallel_config.context_parallel_size
assert self.world_size == tensor_parallel_size * pp_parallel_size * context_parallel_size, (
assert self.world_size == tensor_parallel_size * pp_parallel_size * \
context_parallel_size, (
f"world_size ({self.world_size}) must be equal to the "
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
f"_parallel_size ({pp_parallel_size}) x context"
@@ -345,7 +346,8 @@ def _get_output_rank(self) -> int:
# 16-23, PP rank 2
# 24-31, PP rank 3
# so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
return self.world_size - self.parallel_config.tensor_parallel_size * self.parallel_config.context_parallel_size
return self.world_size - self.parallel_config.tensor_parallel_size * \
self.parallel_config.context_parallel_size


@dataclass
Expand Down
23 changes: 15 additions & 8 deletions vllm/v1/worker/block_table.py
@@ -5,7 +5,7 @@
import numpy as np
import torch

from vllm.distributed import get_dcp_group, get_cp_group
from vllm.distributed import get_cp_group, get_dcp_group
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
@@ -92,18 +92,21 @@ def compute_slot_mapping(self, req_indices: np.ndarray,

# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size * self.cp_world_size
virtual_block_size = self.block_size * self.dcp_world_size * \
self.cp_world_size
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions // virtual_block_size)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
mask = (virtual_block_offsets %
(self.dcp_world_size * self.cp_world_size) == self.current_rank)
self.current_rank = self.dcp_world_size * self.cp_rank + \
self.dcp_rank
mask = (virtual_block_offsets % (self.dcp_world_size * \
self.cp_world_size) == self.current_rank)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
block_offsets = virtual_block_offsets // \
(self.dcp_world_size * self.cp_world_size)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
@@ -147,8 +150,12 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
device=self.device,
pin_memory=self.pin_memory)

def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
"Splits computed token counts across dcp and sp dimensions for distributed allocation."
def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) \
-> list[list[list[int]]]:
"""
Splits computed token counts across dcp and cp dimensions for
distributed allocation.
"""
num_requests = len(num_computed_tokens)
num_computed_tokens_of_dcp_sp = [[
[0] * self.dcp_world_size for _ in range(self.cp_world_size)
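A standalone sketch of the "virtual block" slot mapping with toy numbers, mirroring the arithmetic above (block_size=4, cp_world_size=2, dcp_world_size=1, a single request whose block table points at physical block 0):

import numpy as np

block_size, cp_ws, dcp_ws = 4, 2, 1
cp_rank, dcp_rank = 1, 0
current_rank = dcp_ws * cp_rank + dcp_rank        # = 1
virtual_block_size = block_size * dcp_ws * cp_ws  # = 8

positions = np.arange(8)                          # token positions 0..7
virtual_offsets = positions % virtual_block_size
is_local = (virtual_offsets % (dcp_ws * cp_ws)) == current_rank
local_offsets = virtual_offsets // (dcp_ws * cp_ws)
block_number = 0                                  # from the (toy) block table
slots = np.where(is_local, block_number * block_size + local_offsets, -1)
print(slots)  # [-1  0 -1  1 -1  2 -1  3] -> this CP rank stores every other token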