@@ -93,9 +93,9 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    PrefillContextParallelMetadata,
    create_fast_prefill_custom_backend,
     get_cp_local_seq_lens,
-    PrefillContextParallelMetadata,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )
@@ -461,7 +461,9 @@ def __init__(
         if self.pcp_world_size > 1:
             # Note(qcs): we will pad the tokens of each request
             # to a multiple of 2 * pcp_size.
-            max_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            max_num_tokens = (
+                self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            )
         else:
             max_num_tokens = self.max_num_tokens
         # Persistent buffers for CUDA graphs.
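A quick sanity check on that bound (an illustrative sketch with made-up values, not code from this PR): padding each request up to a multiple of `2 * pcp_world_size` adds at most `2 * pcp_world_size - 1` tokens per request, so reserving `max_num_reqs * 2 * pcp_world_size` extra slots always suffices.

```python
# Illustrative check of the padded-buffer bound; all values are assumed.
pcp_world_size = 2
max_num_reqs = 4
max_num_tokens = 32

pad_unit = 2 * pcp_world_size            # each request padded to a multiple of 4
worst_case_pad_per_req = pad_unit - 1    # at most 3 extra tokens per request
buffer_cap = max_num_tokens + max_num_reqs * pad_unit

assert max_num_reqs * worst_case_pad_per_req < max_num_reqs * pad_unit
print(buffer_cap)  # 48: room for 32 scheduled tokens plus worst-case padding
```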
@@ -501,24 +503,15 @@ def __init__(
         # Persistent buffers for Prefill Context Parallelism
         if self.pcp_world_size > 1:
             self.pcp_allgather_restore_idx = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
-            )
-            self.q_head_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
-            )
-            self.q_tail_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
+            self.q_head_indices = self._make_buffer(max_num_tokens, dtype=torch.int64)
+            self.q_tail_indices = self._make_buffer(max_num_tokens, dtype=torch.int64)
             self.kv_for_head_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
             self.kv_for_tail_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
             self.pcp_padded_slot_mapping = torch.empty(
                 (max_num_tokens,),
@@ -534,15 +527,24 @@ def __init__(
             )
             self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
             self.q_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.q_indptr_cpu = self.q_indptr_cpu_tensor.numpy()
             self.kv_for_head_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.kv_for_head_indptr_cpu = self.kv_for_head_indptr_cpu_tensor.numpy()
             self.kv_for_tail_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.kv_for_tail_indptr_cpu = self.kv_for_tail_indptr_cpu_tensor.numpy()

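`_make_buffer` itself is not shown in this diff. A minimal sketch of the pattern the surrounding code implies, assuming it mirrors vLLM's CPU/GPU buffer helper: a pinned host tensor with a zero-copy NumPy alias (`.np`), a device twin (`.gpu`), and an async `copy_to_gpu`. Names here are illustrative.

```python
import torch


class CpuGpuBuffer:
    """Sketch of the assumed _make_buffer pattern: a pinned host tensor
    with a zero-copy NumPy view, mirrored by a device tensor."""

    def __init__(self, size: int, dtype: torch.dtype, device: str = "cuda"):
        self.cpu = torch.zeros(size, dtype=dtype, pin_memory=True)
        self.np = self.cpu.numpy()  # fast host-side writes, no copies
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # non_blocking=True is safe because the source memory is pinned
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
```

Pinning matters here because these index buffers are rebuilt on the CPU every step and copied to the device on the hot path.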
@@ -1070,45 +1072,53 @@ def _get_pcp_metadata(
     ) -> PrefillContextParallelMetadata:
         """
         During the prefill phase, the attention computation is divided into
-        two parts: q_head and q_tail. Here, we calculate the kv indices
-        corresponding to q_head or q_tail. Meawhile, the q and kv indptr are
+        two parts: q_head and q_tail. Here, we calculate the kv indices
+        corresponding to q_head or q_tail. Meanwhile, the q and kv indptr are
         also computed to build the attention wrapper.
         If the pcp_size is 2, the variables are as follows:
         >>> q_lens [4, 8] kv_lens [8, 16]
         >>> pcp_chunk_sizes[2, 4]
-        >>> q_indptr [0, 2, 4]
+        >>> q_indptr[0, 2, 4]
         >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11]
         >>> kv_head_len r0 [2, 4] / r1 [4, 8]
         >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12]
         >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11]
-        >>> r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
+        >>> r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
         >>> kv_tail_len r0 [8, 16] / r1 [6, 12]
         >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18]
         >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
-        >>> r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
+        >>> r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
         """
         if len(q_lens) == 0:
             return PrefillContextParallelMetadata(
                 allgather_restore_idx=allgather_restore_idx,
             )

         def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_partial_indices):
-            kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len)
-            kv_partial_indices.np[: kv_partial_arange.shape[0]] = kv_partial_arange + np.repeat(
-                kv_start_loc,
-                kv_partial_len,
+            kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = (
+                self._get_cumsum_and_arange(kv_partial_len)
+            )
+            kv_partial_indices.np[: kv_partial_arange.shape[0]] = (
+                kv_partial_arange
+                + np.repeat(
+                    kv_start_loc,
+                    kv_partial_len,
+                )
             )
             return kv_partial_arange.shape[0]

         pcp_chunk_sizes = q_lens // 2
-        self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
+        self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = (
+            self._get_cumsum_and_arange(pcp_chunk_sizes)
+        )

         q_head_start_loc = np.roll(np.cumsum(q_lens), 1)
         q_head_start_loc[0] = 0
         self.q_head_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat(
             q_head_start_loc,
             pcp_chunk_sizes,
         )
+
         self.q_head_indices.copy_to_gpu(q_chunk_arange.shape[0])

         q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes
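The indices in the docstring example can be reproduced with a few lines of NumPy. This is an illustrative sketch of the cumsum-and-arange pattern used above, not the actual `_get_cumsum_and_arange` helper:

```python
import numpy as np

# Reproduces q_head_indices / q_tail_indices from the docstring (pcp_size = 2).
q_lens = np.array([4, 8])
pcp_chunk_sizes = q_lens // 2                # [2, 4]

q_head_start_loc = np.roll(np.cumsum(q_lens), 1)
q_head_start_loc[0] = 0                      # [0, 4]

# Per-request restarted arange: [0, 1, 0, 1, 2, 3].
cumsum = np.cumsum(pcp_chunk_sizes)
arange = np.arange(cumsum[-1]) - np.repeat(cumsum - pcp_chunk_sizes, pcp_chunk_sizes)

q_head_indices = arange + np.repeat(q_head_start_loc, pcp_chunk_sizes)
q_tail_indices = arange + np.repeat(q_head_start_loc + pcp_chunk_sizes, pcp_chunk_sizes)
print(q_head_indices.tolist())  # [0, 1, 4, 5, 6, 7]
print(q_tail_indices.tolist())  # [2, 3, 8, 9, 10, 11]
```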
@@ -1122,17 +1132,25 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_partial_indices):
         kv_start_loc[0] = 0
         # kv_for_q_head
         kv_for_head_len = (self.pcp_rank + 1) * pcp_chunk_sizes
-        kv_head_tokens_sum = _get_partial_kv_idx(kv_for_head_len, self.kv_for_head_indptr_cpu, self.kv_for_head_indices)
+        kv_head_tokens_sum = _get_partial_kv_idx(
+            kv_for_head_len,
+            self.kv_for_head_indptr_cpu,
+            self.kv_for_head_indices,
+        )
         self.kv_for_head_indices.copy_to_gpu(kv_head_tokens_sum)
         # kv_for_q_tail
         kv_for_tail_len = (2 * self.pcp_world_size - self.pcp_rank) * pcp_chunk_sizes
-        kv_tail_tokens_sum = _get_partial_kv_idx(kv_for_tail_len, self.kv_for_tail_indptr_cpu, self.kv_for_tail_indices)
+        kv_tail_tokens_sum = _get_partial_kv_idx(
+            kv_for_tail_len,
+            self.kv_for_tail_indptr_cpu,
+            self.kv_for_tail_indices,
+        )
         self.kv_for_tail_indices.copy_to_gpu(kv_tail_tokens_sum)

         q_full_indices = torch.cat(
             [
                 self.q_head_indices.gpu[: q_chunk_arange.shape[0]],
-                self.q_tail_indices.gpu[: q_chunk_arange.shape[0]]
+                self.q_tail_indices.gpu[: q_chunk_arange.shape[0]],
             ]
         ).argsort()

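Continuing the docstring example, the per-rank head/tail kv lengths computed above can be verified directly (illustrative NumPy only, not part of the PR):

```python
import numpy as np

# Verifies the per-rank kv lengths from the docstring example (pcp_size = 2).
q_lens = np.array([4, 8])
pcp_chunk_sizes = q_lens // 2  # [2, 4]
pcp_world_size = 2

for pcp_rank in range(pcp_world_size):
    kv_for_head_len = (pcp_rank + 1) * pcp_chunk_sizes
    kv_for_tail_len = (2 * pcp_world_size - pcp_rank) * pcp_chunk_sizes
    print(pcp_rank, kv_for_head_len.tolist(), kv_for_tail_len.tolist())
# 0 [2, 4] [8, 16]   -> matches kv_head_len r0 / kv_tail_len r0
# 1 [4, 8] [6, 12]   -> matches kv_head_len r1 / kv_tail_len r1
```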
@@ -1141,13 +1159,17 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_partial_indices):
             q_head_indices=self.q_head_indices.gpu[: q_chunk_arange.shape[0]],
             q_tail_indices=self.q_tail_indices.gpu[: q_chunk_arange.shape[0]],
             q_head_start_loc=self.q_indptr_cpu_tensor[: len(pcp_chunk_sizes) + 1],
-            kv_for_head_indices=self.kv_for_head_indices.gpu[: kv_head_tokens_sum],
-            kv_for_tail_indices=self.kv_for_tail_indices.gpu[: kv_tail_tokens_sum],
-            kv_for_head_indptr=self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1],
-            kv_for_tail_indptr=self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1],
+            kv_for_head_indices=self.kv_for_head_indices.gpu[:kv_head_tokens_sum],
+            kv_for_tail_indices=self.kv_for_tail_indices.gpu[:kv_tail_tokens_sum],
+            kv_for_head_indptr=(
+                self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1]
+            ),
+            kv_for_tail_indptr=(
+                self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1]
+            ),
             q_full_indices=q_full_indices,
         )
-
+
     def _update_tokens_for_pcp(
         self,
         tokens: np.ndarray,
@@ -1189,8 +1211,15 @@ def _update_tokens_for_pcp(
                 self.input_batch.num_computed_tokens_cpu[:num_reqs]
                 >= self.input_batch.num_prompt_tokens[:num_reqs]
             )
+        else:
+            if num_reqs is None or num_decode_reqs is None:
+                raise ValueError(
+                    "num_reqs and num_decode_reqs must be provided for dummy input"
+                )
+            assert num_reqs is not None
+            assert num_decode_reqs is not None
         self.num_pcp_pads_cpu[:num_reqs] = 0
-
+
         num_decode_tokens = sum(tokens[:num_decode_reqs])

         num_padded_scheduled_tokens = np.ceil(
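The `np.ceil(...)` expression is truncated in this diff. A plausible reading, consistent with the earlier comment that each request is padded to a multiple of `2 * pcp_size`, is sketched below; this is an assumption, not confirmed by the source:

```python
import numpy as np

# Hedged sketch of the per-request padding rule: round each request's
# scheduled token count up to a multiple of 2 * pcp_world_size.
pcp_world_size = 2
tokens = np.array([3, 7, 8])

pad_unit = 2 * pcp_world_size
num_padded = (np.ceil(tokens / pad_unit) * pad_unit).astype(np.int64)
print(num_padded.tolist())             # [4, 8, 8]
print((num_padded - tokens).tolist())  # per-request pad counts: [1, 1, 0]
```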
@@ -1259,8 +1288,8 @@ def get_current_rank_positions(
             self._get_pcp_metadata(
                 pcp_tokens[num_decode_reqs:],
                 num_padded_scheduled_tokens[num_decode_reqs:],
-                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]]
-            )
+                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]],
+            ),
         )

     def _get_cumsum_and_arange(
@@ -1471,10 +1500,9 @@ def _prepare_inputs(

         pcp_metadata = None
         if self.pcp_world_size > 1:
-            num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = \
-                self._update_tokens_for_pcp(
-                    num_scheduled_tokens[:num_reqs]
-                )
+            num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = (
+                self._update_tokens_for_pcp(num_scheduled_tokens[:num_reqs])
+            )

         # Re-update after PCP split sequences.
         total_num_scheduled_tokens = sum(num_scheduled_tokens)
@@ -1605,7 +1633,7 @@ def _prepare_inputs(
         if self.pcp_world_size > 1:
             discard_requests_mask = (
                 self.input_batch.num_computed_tokens_cpu[:num_reqs]
-                + num_scheduled_tokens * self.pcp_world_size
+                + num_scheduled_tokens * self.pcp_world_size
                 - self.num_pcp_pads_cpu[:num_reqs]
             ) < num_tokens_np
         else:
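Reading of the mask above: each rank sees only `1/pcp_world_size` of a prefill's tokens, so the globally processed count is the local scheduled count times the world size, minus padding; a request still short of its prompt length has no valid sample to keep. A small illustrative check with assumed values:

```python
import numpy as np

# Illustrative check of the discard mask; all values are assumed.
pcp_world_size = 2
num_computed = np.array([0, 8])         # tokens already processed
num_scheduled_local = np.array([4, 4])  # per-rank scheduled tokens
num_pcp_pads = np.array([0, 0])         # padding added for PCP splitting
num_tokens = np.array([16, 16])         # target prompt lengths

discard = (
    num_computed + num_scheduled_local * pcp_world_size - num_pcp_pads
) < num_tokens
print(discard.tolist())  # [True, False]: request 0 is still mid-prefill
```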
@@ -3167,7 +3195,7 @@ def execute_model(
             # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
             # ignores the padding from CUDA Graph.
             hidden_states = get_pcp_group().all_gather(
-                hidden_states[:num_scheduled_tokens_np.sum()],
+                hidden_states[: num_scheduled_tokens_np.sum()],
                 0,
             )
             hidden_states = torch.index_select(
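How the restore index is used after the all-gather, as a single-request sketch. It assumes the head/tail chunk layout implied by the docstring (with `pcp_world_size = 2`, rank 0 holds chunks 0 and 3, rank 1 holds chunks 1 and 2); the values are stand-ins:

```python
import torch

# One request of 4 tokens, chunk size 1, pcp_world_size = 2.
t = torch.tensor([100, 101, 102, 103])       # stand-in hidden states
rank0, rank1 = t[[0, 3]], t[[1, 2]]          # per-rank head/tail shards
gathered = torch.cat([rank0, rank1], dim=0)  # rank-major all_gather result

# Precomputed permutation mapping gathered order back to token order.
restore_idx = torch.tensor([0, 2, 3, 1])
restored = torch.index_select(gathered, 0, restore_idx)
assert torch.equal(restored, t)
```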
@@ -4077,13 +4105,14 @@ def _dummy_run(
         pcp_metadata = None
         if self.pcp_world_size > 1 and force_attention:
             num_decode_reqs = sum(num_scheduled_tokens == 1)
-            num_scheduled_tokens[:num_reqs], _, pcp_metadata = \
+            num_scheduled_tokens[:num_reqs], _, pcp_metadata = (
                 self._update_tokens_for_pcp(
                     num_scheduled_tokens[:num_reqs],
                     dummy_input=True,
                     num_reqs=num_reqs,
                     num_decode_reqs=num_decode_reqs,
                 )
+            )
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
