
Commit a583833

pisceskkkgjc0824 committed
[Perf] remove custom_mask
Co-authored-by: QiuChunshuo <[email protected]>
Co-authored-by: gaojc <[email protected]>
Signed-off-by: QiuChunshuo <[email protected]>
Signed-off-by: gaojc <[email protected]>
1 parent d09bbc6 commit a583833

File tree

3 files changed: 256 additions, 69 deletions


vllm/v1/attention/backends/flashinfer.py

Lines changed: 96 additions & 38 deletions
@@ -24,6 +24,7 @@
     MultipleOf,
 )
 from vllm.attention.ops.common import cp_lse_ag_out_ar
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed.parallel_state import get_pcp_group
 from vllm.logger import init_logger
@@ -51,6 +52,7 @@
     get_per_layer_parameters,
     infer_global_hyperparameters,
     split_decodes_and_prefills,
+    PrefillContextParallelMetadata,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 

@@ -274,6 +276,7 @@ class FlashInferMetadata:
 
     # For context parallel
     pcp_allgather_restore_idx: torch.Tensor | None = None
+    pcp_metadata: PrefillContextParallelMetadata | None = None
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -425,16 +428,18 @@ def _get_workspace_buffer(self):
         )
         return self._workspace_buffer
 
-    def _get_prefill_wrapper(self):
-        if self._prefill_wrapper is None:
-            if self.pcp_world_size > 1:
-                self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
-                    self._get_workspace_buffer(), get_kv_cache_layout()
-                )
-            else:
-                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+    def _get_prefill_wrapper(self, attn_metadata):
+        # if self._prefill_wrapper is None:
+        if self.pcp_world_size > 1:
+            self._prefill_wrapper = {}
+            for key in ["head", "tail"]:
+                self._prefill_wrapper[key] = BatchPrefillWithRaggedKVCacheWrapper(
                     self._get_workspace_buffer(), get_kv_cache_layout()
                 )
+        else:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), get_kv_cache_layout()
+            )
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
@@ -667,6 +672,7 @@ def build(
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
             pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx,
+            pcp_metadata=common_attn_metadata.pcp_metadata,
         )
 
         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -699,7 +705,7 @@
         if num_prefills > 0:
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
-            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
+            attn_metadata.prefill_wrapper = self._get_prefill_wrapper(common_attn_metadata)
             assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
             assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
             assert (
@@ -721,38 +727,40 @@
 
             if not attn_metadata.prefill_use_trtllm:
                 if self.pcp_world_size > 1:
+                    assert common_attn_metadata.pcp_metadata is not None
                     assert common_attn_metadata.query_positions is not None
-                    prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[
-                        prefill_start:
-                    ]
-                    kv_indptr_cpu = qo_indptr_cpu * self.pcp_world_size
-                    # init custom mask for head-tail query order
-                    custom_mask = self._get_pcp_custom_mask(
-                        qo_indptr_cpu=qo_indptr_cpu,
-                        q_pos=torch.from_numpy(
-                            common_attn_metadata.query_positions[prefill_start:]
-                        ).long().to(self.device),
-                        kv_lens=(
-                            prefill_num_computed_tokens_cpu
-                            + kv_indptr_cpu[1:]
-                            - kv_indptr_cpu[:-1]
-                        ).to(self.device),
-                    )
 
-                    attn_metadata.prefill_wrapper.plan(
+                    pcp_metadata = common_attn_metadata.pcp_metadata
+                    qo_indptr_cpu = pcp_metadata.q_head_start_loc
+                    kv_for_head_indptr = pcp_metadata.kv_for_head_indptr
+                    kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr
+
+                    attn_metadata.prefill_wrapper["head"].plan(
                         qo_indptr_cpu.to(self.device),
-                        kv_indptr_cpu.to(self.device),
+                        kv_for_head_indptr.to(self.device),
                         self.num_qo_heads,
                         self.num_kv_heads,
                         self.head_dim,
-                        custom_mask=custom_mask,
+                        causal=True,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                    # tail
+                    attn_metadata.prefill_wrapper["tail"].plan(
+                        qo_indptr_cpu.to(self.device),
+                        kv_for_tail_indptr.to(self.device),
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        causal=True,
                         sm_scale=self.sm_scale,
                         window_left=self.window_left,
                         logits_soft_cap=self.logits_soft_cap,
                         q_data_type=self.q_data_type,
                         kv_data_type=self.kv_cache_dtype,
-                        fixed_split_size=self.prefill_fixed_split_size,
-                        disable_split_kv=self.disable_split_kv,
                     )
                 else:
                     attn_metadata.prefill_wrapper.plan(
@@ -926,6 +934,32 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if self.sinks is not None and self.sinks.dtype != torch.float32:
             self.sinks = self.sinks.to(torch.float32)
 
+    def _attention_with_head_and_tail(self,
+                                      q_head: torch.Tensor,
+                                      q_tail: torch.Tensor,
+                                      k_head: torch.Tensor,
+                                      v_head: torch.Tensor,
+                                      k_tail: torch.Tensor,
+                                      v_tail: torch.Tensor,
+                                      prefill_wrapper: BatchPrefillWithRaggedKVCacheWrapper,
+    ):
+        output_head = torch.empty_like(q_head)
+        prefill_wrapper["head"].run(
+            q_head,
+            k_head,
+            v_head,
+            out=output_head,
+        )
+
+        output_tail = torch.empty_like(q_tail)
+        prefill_wrapper["tail"].run(
+            q_tail,
+            k_tail,
+            v_tail,
+            out=output_tail,
+        )
+        return output_head, output_tail
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -1088,20 +1122,44 @@ def forward(
             assert prefill_wrapper is not None
 
             if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._window_left == self.window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
-                assert prefill_wrapper._sm_scale == self.scale
                 if self.pcp_world_size > 1:
+                    assert type(prefill_wrapper) == dict
+                    for _, prefill_wrapper_i in prefill_wrapper.items():
+                        assert prefill_wrapper_i._window_left == self.window_left
+                        assert prefill_wrapper_i._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                        assert prefill_wrapper_i._sm_scale == self.scale
+                    assert attn_metadata.pcp_metadata is not None
+                    pcp_metadata = attn_metadata.pcp_metadata
+                    q_head_indices = pcp_metadata.q_head_indices
+                    q_tail_indices = pcp_metadata.q_tail_indices
+                    kv_for_head_indices = pcp_metadata.kv_for_head_indices
+                    kv_for_tail_indices = pcp_metadata.kv_for_tail_indices
+                    q_full_indices = pcp_metadata.q_full_indices
+
                     # NOTE(qcs): Allgather causes duplicate decoding tokens.
                     prefill_key = key[num_decode_tokens * self.pcp_world_size :]
                     prefill_value = value[num_decode_tokens * self.pcp_world_size :]
-                    prefill_wrapper.run(
-                        prefill_query,
-                        prefill_key,
-                        prefill_value,
-                        out=output[num_decode_tokens:],
+
+                    output_head, output_tail = self._attention_with_head_and_tail(
+                        torch.index_select(prefill_query, 0, q_head_indices),
+                        torch.index_select(prefill_query, 0, q_tail_indices),
+                        torch.index_select(prefill_key, 0, kv_for_head_indices),
+                        torch.index_select(prefill_value, 0, kv_for_head_indices),
+                        torch.index_select(prefill_key, 0, kv_for_tail_indices),
+                        torch.index_select(prefill_value, 0, kv_for_tail_indices),
+                        prefill_wrapper,
+                    )
+
+                    output_full = torch.index_select(
+                        torch.cat([output_head, output_tail], dim=0),
+                        0,
+                        q_full_indices
                     )
+                    output[num_decode_tokens:] = output_full
                 else:
+                    assert prefill_wrapper._window_left == self.window_left
+                    assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                    assert prefill_wrapper._sm_scale == self.scale
                     assert prefill_wrapper._causal
                     prefill_wrapper.run(
                         prefill_query,

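Net effect of the flashinfer.py changes: the prefill context-parallel path no longer builds a per-token custom_mask over the head/tail-ordered queries and runs a single ragged-attention call; it now plans two ordinary causal ragged-attention calls, one over the "head" query block and one over the "tail" block, each against its own KV prefix (kv_for_head_indptr / kv_for_tail_indptr), and reassembles the outputs with q_full_indices. The sketch below illustrates, in plain PyTorch rather than FlashInfer wrappers, why the two formulations agree; head_len, head_kv_len and the other sizes are toy assumptions, not values taken from the commit.

# Illustrative sketch (not vLLM code): two plain causal attention passes over the
# "head" and "tail" query blocks reproduce what the removed custom_mask expressed
# in a single masked pass.
import torch

def naive_attention(q, k, v, allowed):
    # q: [m, d], k/v: [n, d], allowed: [m, n] bool mask of KV positions visible to each query.
    scores = (q @ k.t()) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def bottom_right_causal(m, n):
    # Query i may see kv[0 .. n - m + i]: the alignment a causal ragged-attention
    # plan uses when the KV sequence is longer than the query block.
    i = torch.arange(m).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    return j <= (n - m) + i

torch.manual_seed(0)
d = 8
head_len, tail_len = 4, 4          # this rank's prefill queries, split into head and tail blocks
head_kv_len, tail_kv_len = 6, 14   # KV prefix visible to each block (the tail block sees more)

q_head = torch.randn(head_len, d)
q_tail = torch.randn(tail_len, d)
k_full = torch.randn(tail_kv_len, d)
v_full = torch.randn(tail_kv_len, d)

# Two causal passes, mirroring prefill_wrapper["head"].run / prefill_wrapper["tail"].run above.
out_head = naive_attention(q_head, k_full[:head_kv_len], v_full[:head_kv_len],
                           bottom_right_causal(head_len, head_kv_len))
out_tail = naive_attention(q_tail, k_full, v_full,
                           bottom_right_causal(tail_len, tail_kv_len))

# One pass over all queries with an explicit per-token mask, i.e. what the
# removed custom_mask path computed.
mask = torch.zeros(head_len + tail_len, tail_kv_len, dtype=torch.bool)
mask[:head_len, :head_kv_len] = bottom_right_causal(head_len, head_kv_len)
mask[head_len:] = bottom_right_causal(tail_len, tail_kv_len)
out_ref = naive_attention(torch.cat([q_head, q_tail]), k_full, v_full, mask)

assert torch.allclose(torch.cat([out_head, out_tail]), out_ref, atol=1e-5)

Because each pass is plain causal attention, the kernels can take their fast causal path instead of reading a materialized mask, which is presumably where the [Perf] in the commit title comes from.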
vllm/v1/attention/backends/utils.py

Lines changed: 14 additions & 0 deletions
@@ -48,6 +48,19 @@
 def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)
 
+@dataclass
+class PrefillContextParallelMetadata:
+    """
+    Attention metadata for prefill context parallel
+    """
+    q_head_indices: torch.Tensor
+    q_tail_indices: torch.Tensor
+    q_head_start_loc: torch.Tensor
+    kv_for_head_indices: torch.Tensor
+    kv_for_tail_indices: torch.Tensor
+    kv_for_head_indptr: torch.Tensor
+    kv_for_tail_indptr: torch.Tensor
+    q_full_indices: torch.Tensor
 
 @dataclass
 class CommonAttentionMetadata:
@@ -97,6 +110,7 @@ class CommonAttentionMetadata:
     # Needed by custom mask calc for context parallelism
     query_positions: np.ndarray | None = None
     pcp_allgather_restore_idx: torch.Tensor | None = None
+    pcp_metadata: PrefillContextParallelMetadata | None = None
 
 
 def slice_query_start_locs(

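The new PrefillContextParallelMetadata bundles the precomputed index tensors that the FlashInfer backend needs for the head/tail prefill passes, so forward() is reduced to torch.index_select and torch.cat. How these tensors are populated is not part of this diff; the toy instance below only shows how the fields relate to one another, with made-up sizes and placeholder KV row ids.

# Toy instance (layout assumed for illustration) for two prefill requests on one
# rank, each split into a 3-token head block and a 3-token tail block, assuming
# prefill_query is laid out as [r0 head, r0 tail, r1 head, r1 tail].
import torch

from vllm.v1.attention.backends.utils import PrefillContextParallelMetadata

meta = PrefillContextParallelMetadata(
    # Rows of prefill_query consumed by the "head" and "tail" passes.
    q_head_indices=torch.tensor([0, 1, 2, 6, 7, 8]),
    q_tail_indices=torch.tensor([3, 4, 5, 9, 10, 11]),
    # qo indptr shared by both plan() calls: 3 head (and 3 tail) queries per request.
    q_head_start_loc=torch.tensor([0, 3, 6]),
    # Placeholder row ids into the allgathered prefill K/V; each head block sees a
    # shorter KV prefix (4 and 5 rows) than the corresponding tail block (10 and 12).
    kv_for_head_indices=torch.arange(9),
    kv_for_tail_indices=torch.arange(22),
    kv_for_head_indptr=torch.tensor([0, 4, 9]),
    kv_for_tail_indptr=torch.tensor([0, 10, 22]),
    # Restores the original token order from torch.cat([out_head, out_tail]).
    q_full_indices=torch.tensor([0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11]),
)

With this layout, output_full = torch.index_select(torch.cat([out_head, out_tail]), 0, meta.q_full_indices) in forward() puts the per-block results back into the original query order; the bookkeeping is done once wherever CommonAttentionMetadata is built rather than inside the attention kernel.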