@@ -24,6 +24,7 @@
     MultipleOf,
 )
 from vllm.attention.ops.common import cp_lse_ag_out_ar
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed.parallel_state import get_pcp_group
 from vllm.logger import init_logger
@@ -51,6 +52,7 @@
     get_per_layer_parameters,
     infer_global_hyperparameters,
     split_decodes_and_prefills,
+    PrefillContextParallelMetadata,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
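The newly imported `merge_attn_states` is not exercised in the hunks shown here; it combines partial attention outputs via their log-sum-exp (LSE) statistics. For reference, a generic sketch of that merge (hypothetical helper and signature, not vLLM's actual API):

```python
import torch

# Merge two partial attention results computed over disjoint kv shards for
# the same queries, weighting each by its LSE normalizer.
def merge_partial_attn(o1, lse1, o2, lse2):
    # o*: [tokens, heads, head_dim]; lse*: [tokens, heads]
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```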
@@ -274,6 +276,7 @@ class FlashInferMetadata:
 
     # For context parallel
     pcp_allgather_restore_idx: torch.Tensor | None = None
+    pcp_metadata: PrefillContextParallelMetadata | None = None
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
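The fields read off `pcp_metadata` throughout this diff imply roughly the following shape for `PrefillContextParallelMetadata`; this is an inferred sketch, not the actual dataclass from `vllm.v1.attention.backends.utils`:

```python
from dataclasses import dataclass

import torch

@dataclass
class PrefillContextParallelMetadata:  # inferred sketch of the real class
    q_head_start_loc: torch.Tensor     # qo_indptr shared by head/tail plans
    kv_for_head_indptr: torch.Tensor   # kv indptr for the "head" plan
    kv_for_tail_indptr: torch.Tensor   # kv indptr for the "tail" plan
    q_head_indices: torch.Tensor       # query rows belonging to head chunks
    q_tail_indices: torch.Tensor       # query rows belonging to tail chunks
    kv_for_head_indices: torch.Tensor  # gathered-kv rows visible to head queries
    kv_for_tail_indices: torch.Tensor  # gathered-kv rows visible to tail queries
    q_full_indices: torch.Tensor       # maps original rows into cat(head, tail)
```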
@@ -425,16 +428,18 @@ def _get_workspace_buffer(self):
         )
         return self._workspace_buffer
 
-    def _get_prefill_wrapper(self):
-        if self._prefill_wrapper is None:
-            if self.pcp_world_size > 1:
-                self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
-                    self._get_workspace_buffer(), get_kv_cache_layout()
-                )
-            else:
-                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+    def _get_prefill_wrapper(self, attn_metadata):
+        # NOTE: the `if self._prefill_wrapper is None` caching is disabled; wrappers are recreated on every call.
+        if self.pcp_world_size > 1:
+            self._prefill_wrapper = {}
+            for key in ["head", "tail"]:
+                self._prefill_wrapper[key] = BatchPrefillWithRaggedKVCacheWrapper(
                     self._get_workspace_buffer(), get_kv_cache_layout()
                 )
+        else:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), get_kv_cache_layout()
+            )
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
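Why a dict of two wrappers: under the head-tail query order that the custom-mask code removed further down was built for, each PCP rank holds two query chunks per request with different kv visibility, so each half gets its own causal plan. A toy sketch of the assumed zigzag layout (chunk size and rank-to-chunk mapping are assumptions, not taken from this diff):

```python
# Assumed zigzag layout: a request of length L is cut into 2 * P chunks of
# size C = L // (2 * P); rank r keeps chunk r ("head") and chunk 2P-1-r
# ("tail"). Causally, head queries see kv[: (r + 1) * C] and tail queries
# see kv[: (2 * P - r) * C], hence two plans over different kv prefixes.
def head_tail_positions(seq_len: int, rank: int, pcp_world_size: int):
    P = pcp_world_size
    C = seq_len // (2 * P)
    head = list(range(rank * C, (rank + 1) * C))
    tail = list(range((2 * P - 1 - rank) * C, (2 * P - rank) * C))
    return head, tail

# e.g. seq_len=16, P=2, rank=0 -> head positions 0..3, tail positions 12..15
```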
@@ -667,6 +672,7 @@ def build(
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
             pcp_allgather_restore_idx=common_attn_metadata.pcp_allgather_restore_idx,
+            pcp_metadata=common_attn_metadata.pcp_metadata,
         )
 
         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -699,7 +705,7 @@ def build(
         if num_prefills > 0:
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
-            attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
+            attn_metadata.prefill_wrapper = self._get_prefill_wrapper(common_attn_metadata)
             assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
             assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1
             assert (
@@ -721,38 +727,40 @@ def build(
 
             if not attn_metadata.prefill_use_trtllm:
                 if self.pcp_world_size > 1:
+                    assert common_attn_metadata.pcp_metadata is not None
                     assert common_attn_metadata.query_positions is not None
-                    prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[
-                        prefill_start:
-                    ]
-                    kv_indptr_cpu = qo_indptr_cpu * self.pcp_world_size
-                    # init custom mask for head-tail query order
-                    custom_mask = self._get_pcp_custom_mask(
-                        qo_indptr_cpu=qo_indptr_cpu,
-                        q_pos=torch.from_numpy(
-                            common_attn_metadata.query_positions[prefill_start:]
-                        ).long().to(self.device),
-                        kv_lens=(
-                            prefill_num_computed_tokens_cpu
-                            + kv_indptr_cpu[1:]
-                            - kv_indptr_cpu[:-1]
-                        ).to(self.device),
-                    )
 
-                    attn_metadata.prefill_wrapper.plan(
+                    pcp_metadata = common_attn_metadata.pcp_metadata
+                    qo_indptr_cpu = pcp_metadata.q_head_start_loc
+                    kv_for_head_indptr = pcp_metadata.kv_for_head_indptr
+                    kv_for_tail_indptr = pcp_metadata.kv_for_tail_indptr
+
+                    attn_metadata.prefill_wrapper["head"].plan(
                         qo_indptr_cpu.to(self.device),
-                        kv_indptr_cpu.to(self.device),
+                        kv_for_head_indptr.to(self.device),
                         self.num_qo_heads,
                         self.num_kv_heads,
                         self.head_dim,
-                        custom_mask=custom_mask,
+                        causal=True,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                    # tail
+                    attn_metadata.prefill_wrapper["tail"].plan(
+                        qo_indptr_cpu.to(self.device),
+                        kv_for_tail_indptr.to(self.device),
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        causal=True,
                         sm_scale=self.sm_scale,
                         window_left=self.window_left,
                         logits_soft_cap=self.logits_soft_cap,
                         q_data_type=self.q_data_type,
                         kv_data_type=self.kv_cache_dtype,
-                        fixed_split_size=self.prefill_fixed_split_size,
-                        disable_split_kv=self.disable_split_kv,
                     )
                 else:
                     attn_metadata.prefill_wrapper.plan(
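`causal=True` can replace the old custom mask because flashinfer's causal mode aligns each request's last query row with its last kv row; giving the head plan a kv prefix that ends where the head chunk ends (and likewise for the tail) reproduces the masked computation. A hedged sketch of how the two kv indptrs could be derived under the zigzag assumption above (the real values arrive precomputed in `pcp_metadata`):

```python
import torch

def pcp_kv_indptrs(seq_lens: list[int], rank: int, pcp_world_size: int):
    # Per-request kv lengths for the head/tail plans under the assumed layout.
    P = pcp_world_size
    head_lens, tail_lens = [], []
    for L in seq_lens:
        C = L // (2 * P)
        head_lens.append((rank + 1) * C)      # head queries see chunks 0..r
        tail_lens.append((2 * P - rank) * C)  # tail queries see chunks 0..2P-1-r
    zero = torch.zeros(1, dtype=torch.int32)
    def to_indptr(lens):
        lens_t = torch.tensor(lens, dtype=torch.int32)
        return torch.cat([zero, lens_t.cumsum(0, dtype=torch.int32)])
    return to_indptr(head_lens), to_indptr(tail_lens)
```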
@@ -926,6 +934,32 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
         if self.sinks is not None and self.sinks.dtype != torch.float32:
             self.sinks = self.sinks.to(torch.float32)
 
+    def _attention_with_head_and_tail(
+        self,
+        q_head: torch.Tensor,
+        q_tail: torch.Tensor,
+        k_head: torch.Tensor,
+        v_head: torch.Tensor,
+        k_tail: torch.Tensor,
+        v_tail: torch.Tensor,
+        prefill_wrapper: dict[str, BatchPrefillWithRaggedKVCacheWrapper],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Run the planned "head" and "tail" attention separately; the two
+        # query sets are disjoint, so no LSE merge is needed afterwards.
+        output_head = torch.empty_like(q_head)
+        prefill_wrapper["head"].run(
+            q_head,
+            k_head,
+            v_head,
+            out=output_head,
+        )
+
+        output_tail = torch.empty_like(q_tail)
+        prefill_wrapper["tail"].run(
+            q_tail,
+            k_tail,
+            v_tail,
+            out=output_tail,
+        )
+        return output_head, output_tail
+
     def forward(
         self,
         layer: torch.nn.Module,
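Because head and tail cover disjoint query rows, recombining them needs only a concat plus `index_select` with `q_full_indices`, as the forward hunk below does; no LSE merge is required. A toy check of that permutation semantics (the indices are made up):

```python
import torch

# Four prefill rows; rows 0 and 2 are "head" queries, rows 1 and 3 "tail".
out_head = torch.tensor([[0.0], [2.0]])  # results for original rows 0 and 2
out_tail = torch.tensor([[1.0], [3.0]])  # results for original rows 1 and 3

# q_full_indices[i] = row of cat(out_head, out_tail) holding original row i.
q_full_indices = torch.tensor([0, 2, 1, 3])
restored = torch.index_select(torch.cat([out_head, out_tail]), 0, q_full_indices)
assert torch.equal(restored, torch.tensor([[0.0], [1.0], [2.0], [3.0]]))
```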
@@ -1088,20 +1122,44 @@ def forward(
             assert prefill_wrapper is not None
 
             if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._window_left == self.window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
-                assert prefill_wrapper._sm_scale == self.scale
                 if self.pcp_world_size > 1:
+                    assert isinstance(prefill_wrapper, dict)
+                    for prefill_wrapper_i in prefill_wrapper.values():
+                        assert prefill_wrapper_i._window_left == self.window_left
+                        assert prefill_wrapper_i._logits_soft_cap == (
+                            self.logits_soft_cap or 0.0
+                        )
+                        assert prefill_wrapper_i._sm_scale == self.scale
+                    assert attn_metadata.pcp_metadata is not None
+                    pcp_metadata = attn_metadata.pcp_metadata
+                    q_head_indices = pcp_metadata.q_head_indices
+                    q_tail_indices = pcp_metadata.q_tail_indices
+                    kv_for_head_indices = pcp_metadata.kv_for_head_indices
+                    kv_for_tail_indices = pcp_metadata.kv_for_tail_indices
+                    q_full_indices = pcp_metadata.q_full_indices
+
                     # NOTE(qcs): Allgather causes duplicate decoding tokens.
                     prefill_key = key[num_decode_tokens * self.pcp_world_size :]
                     prefill_value = value[num_decode_tokens * self.pcp_world_size :]
-                    prefill_wrapper.run(
-                        prefill_query,
-                        prefill_key,
-                        prefill_value,
-                        out=output[num_decode_tokens:],
+
+                    output_head, output_tail = self._attention_with_head_and_tail(
+                        torch.index_select(prefill_query, 0, q_head_indices),
+                        torch.index_select(prefill_query, 0, q_tail_indices),
+                        torch.index_select(prefill_key, 0, kv_for_head_indices),
+                        torch.index_select(prefill_value, 0, kv_for_head_indices),
+                        torch.index_select(prefill_key, 0, kv_for_tail_indices),
+                        torch.index_select(prefill_value, 0, kv_for_tail_indices),
+                        prefill_wrapper,
+                    )
+
+                    # Scatter the concatenated head/tail outputs back into the
+                    # original query order.
+                    output_full = torch.index_select(
+                        torch.cat([output_head, output_tail], dim=0),
+                        0,
+                        q_full_indices,
                     )
+                    output[num_decode_tokens:] = output_full
                 else:
+                    assert prefill_wrapper._window_left == self.window_left
+                    assert prefill_wrapper._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._sm_scale == self.scale
                     assert prefill_wrapper._causal
                     prefill_wrapper.run(
                         prefill_query,