 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
+                        BatchPrefillWithRaggedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
+from vllm.attention.ops.common import cp_lse_ag_out_ar
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.distributed.parallel_state import get_cp_group
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -234,6 +237,9 @@ class FlashInferMetadata:

     qo_indptr_gpu: Optional[torch.Tensor] = None
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None
+
+    # For context parallel
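+    # Permutation that restores the original token order of the K/V
+    # all-gathered across CP ranks (applied via index_select in forward()).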
+    cp_kv_recover_idx: Optional[torch.Tensor] = None


 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -256,8 +262,17 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                      self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        try:
+            self.cp_world_size = get_cp_group().world_size
+            self.cp_rank = get_cp_group().rank_in_group
+        except AssertionError:
+            # CP might not be initialized in testing
+            self.cp_world_size = 1
+            self.cp_rank = 0
+
+        # NOTE(qcs): Context parallel does not support CUDA graph mode yet.
         self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL)
+            decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -266,6 +281,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             self._decode_cudagraph_max_bs = min(
                 max_num_reqs, self.compilation_config.max_capture_size)

         self.num_qo_heads = self.model_config.get_num_attention_heads(
             self.vllm_config.parallel_config)
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
@@ -348,8 +363,12 @@ def _get_workspace_buffer(self):

     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
-            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), get_kv_cache_layout())
+            if self.cp_world_size > 1:
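+                # With CP, prefill attends over the all-gathered (ragged) K/V
+                # tensors instead of the paged KV cache, so use the ragged
+                # wrapper (see the prefill path in forward()).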
+                self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout())
+            else:
+                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper

     def _get_decode_wrapper(self,
@@ -413,7 +432,12 @@ def build(self,
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        if self.cp_world_size > 1:
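+            # Each CP rank holds seq_len // cp_world_size KV tokens; the
+            # remainder goes to the lowest-numbered ranks.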
+            seq_lens_cpu = (
+                seq_lens_cpu // self.cp_world_size +
+                (self.cp_rank < seq_lens_cpu % self.cp_world_size))
         seq_lens_np = seq_lens_cpu.numpy()
+        num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor

         num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
@@ -495,6 +519,13 @@ def build(self,
             self.cache_dtype,
             self.q_data_type,
             has_sinks=self.has_sinks)
+
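+        # CP merges per-rank partial attention via the log-sum-exp (LSE),
+        # which the TRT-LLM attention kernels do not return.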
+        if self.cp_world_size > 1 and (prefill_use_trtllm
+                                       or decode_use_trtllm):
+            raise NotImplementedError(
+                "TRT-LLM attention does not return LSE, which context "
+                "parallel requires; use flash attention or disable sinks.")
+
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
@@ -521,6 +552,7 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
+            cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx,
         )

         qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -567,23 +599,69 @@ def build(self,
             qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+            prefill_num_computed_tokens_cpu = \
+                num_computed_tokens_cpu[prefill_start:]
             if not attn_metadata.prefill_use_trtllm:
-                attn_metadata.prefill_wrapper.plan(
-                    qo_indptr_cpu,
-                    paged_kv_indptr_cpu,
-                    paged_kv_indices,
-                    paged_kv_last_page_len_cpu[prefill_start:],
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    self.page_size,
-                    causal=True,
-                    sm_scale=self.sm_scale,
-                    window_left=self.window_left,
-                    logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
-                    kv_data_type=self.kv_cache_dtype,
-                )
+                if self.cp_world_size > 1:
+                    # NOTE(qcs): assumes no chunked prefill or prefix caching.
+                    kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size
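+                    # Without cached context, each request's gathered KV
+                    # length is cp_world_size times its local query length.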
+                    # init custom mask for head-tail query order
+                    mask_arr = []
+                    q_pos = common_attn_metadata.query_positions
+                    for i in range(num_prefills):
+                        # |----<C>-----|-<Q0>-|-<Q1>-|
+                        # |---<C+Q*cp_world_size>----|
+                        # cp_world_size = 2
+                        # Q = 2
+                        # C = 8
+                        # cur_q_pos = [0, 3]
+                        # context_mask_i.shape = (2, 8)
+                        # upper = [0, 1, 2, 3]
+                        # local_mask_i = [[True, False, False, False],
+                        #                 [True, True, True, True]]  # size=(2, 4)
+                        # mask_i.shape = (2, 12)
+                        cur_q_pos = torch.from_numpy(
+                            q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i + 1]])
+                        Q = len(cur_q_pos)
+                        C = prefill_num_computed_tokens_cpu[i]
+                        if Q <= 0:
+                            mask_arr.append(torch.zeros(0, dtype=torch.bool))
+                            continue
+                        context_mask_i = torch.ones((Q, C), dtype=torch.bool)
+                        upper = torch.arange(Q * self.cp_world_size)
+                        local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1))
+                        mask_i = torch.cat([context_mask_i, local_mask_i], dim=1)
+                        mask_arr.append(mask_i.flatten())
+                    custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
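+                    # FlashInfer expects the custom mask flattened and
+                    # concatenated across requests.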
+
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu.to(self.device),
+                        kv_indptr_cpu.to(self.device),
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        custom_mask=custom_mask,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                else:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        paged_kv_indices,
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        causal=True,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                     self.device, non_blocking=True)
@@ -644,6 +722,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
             # TODO: The cascade wrapper currently does not support setting
             # kv cache dtype to something different from query dtype.
             return False
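+        # Cascade attention is not supported together with context parallel.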
+        if self.cp_world_size > 1:
+            return False
         return use_cascade_attention(*args, **kwargs)


@@ -803,16 +883,17 @@ def forward(
             # and value[:num_actual_tokens] because the reshape_and_cache_flash
             # op uses the slot_mapping's shape to determine the number of
             # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                kv_cache[:, 0],
-                kv_cache[:, 1],
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
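+            # With CP, writing K/V to the paged cache is deferred: the
+            # all-gathered K/V for prefill tokens and the local K/V for
+            # decode tokens are written later in this function.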
+            if self.cp_world_size == 1:
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    kv_cache[:, 0],
+                    kv_cache[:, 1],
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )

         # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
         # to process the cache when the kv_cache_dtype is fp8
@@ -847,18 +928,50 @@ def forward(
             assert prefill_wrapper is not None

             if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._causal
                 assert prefill_wrapper._window_left == self.window_left
                 assert prefill_wrapper._logits_soft_cap == (
                     self.logits_soft_cap or 0.0)
                 assert prefill_wrapper._sm_scale == self.scale
-                prefill_wrapper.run(
-                    prefill_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[num_decode_tokens:],
-                )
+                if self.cp_world_size > 1:
+                    key_across_cp = get_cp_group().all_gather(
+                        key[num_decode_tokens:].contiguous(), dim=0)
+                    value_across_cp = get_cp_group().all_gather(
+                        value[num_decode_tokens:].contiguous(), dim=0)
+                    key_across_cp = torch.index_select(
+                        key_across_cp, 0,
+                        attn_metadata.cp_kv_recover_idx
+                    )
+                    value_across_cp = torch.index_select(
+                        value_across_cp, 0,
+                        attn_metadata.cp_kv_recover_idx
+                    )
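+                    # all_gather returns K/V in rank-major order;
+                    # cp_kv_recover_idx permutes them back to the original
+                    # token order before caching and attention.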
+                    torch.ops._C_cache_ops.reshape_and_cache_flash(
+                        key_across_cp,
+                        value_across_cp,
+                        kv_cache[:, 0],
+                        kv_cache[:, 1],
+                        attn_metadata.slot_mapping[num_decode_tokens:],
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
+                    # TODO(qcs): handle fetching and concatenating the KV
+                    # cache under chunked prefill / prefix caching.
+                    prefill_wrapper.run(
+                        prefill_query,
+                        key_across_cp,
+                        value_across_cp,
+                        out=output[num_decode_tokens:],
+                    )
+                else:
+                    assert prefill_wrapper._causal
+                    prefill_wrapper.run(
+                        prefill_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[num_decode_tokens:],
+                    )
             else:
                 # prefill_query may be non-contiguous
                 prefill_query = prefill_query.contiguous()
@@ -933,13 +1046,35 @@ def forward(
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
-                decode_wrapper.run(
-                    decode_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[:num_decode_tokens],
-                )
+                if self.cp_world_size > 1:
+                    torch.ops._C_cache_ops.reshape_and_cache_flash(
+                        key[:num_decode_tokens],
+                        value[:num_decode_tokens],
+                        kv_cache[:, 0],
+                        kv_cache[:, 1],
+                        attn_metadata.slot_mapping[:num_decode_tokens],
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
+                    kv_cache_permute = kv_cache.permute(*stride_order)
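+                    # Decode attends only to this rank's KV shard, so request
+                    # the LSE and merge the partial results across the CP group.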
+                    out, lse = decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        return_lse=True,
+                    )
+                    output[:num_decode_tokens] = \
+                        cp_lse_ag_out_ar(out, lse, get_cp_group())
+                else:
+                    decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[:num_decode_tokens],
+                    )
             else:
                 # decode_query may be non-contiguous
                 decode_query = decode_query.contiguous()