@@ -1001,6 +1001,24 @@ def _get_pcp_metadata(
         >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
         >>> r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
         """
+        if len(q_lens) == 0:
+            return PrefillContextParallelMetadata(
+                allgather_restore_idx=allgather_restore_idx,
+            )
+
+        def _get_partial_kv_idx(kv_len_per_pcp_chunk):
+            kv_partial_len = pcp_chunk_sizes * kv_len_per_pcp_chunk
+            kv_partial_indptr = np.zeros(len(kv_partial_len) + 1)
+            kv_partial_indptr[1:], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len)
+            kv_partial_indices = kv_partial_arange + np.repeat(
+                kv_start_loc,
+                kv_partial_len,
+            )
+            return kv_partial_indptr, kv_partial_indices
+
+        def _to_tensor(data, **kwargs):
+            return {k: torch.from_numpy(v).to(**kwargs) for k, v in data.items()}
+
         pcp_chunk_sizes = q_lens // 2
         q_indptr = np.zeros(len(pcp_chunk_sizes) + 1)
         q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
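The new `_get_partial_kv_idx` helper leans on `self._get_cumsum_and_arange` to turn per-request lengths into an `indptr` array plus flat indices. A minimal, self-contained sketch of that pattern, using a hypothetical stand-in for the method (names and numbers are illustrative only):

```python
import numpy as np

# Hypothetical stand-in for self._get_cumsum_and_arange: given per-request
# lengths, return the cumulative sum (used to fill indptr[1:]) and, for each
# request, a local arange(0, length) concatenated into one flat array.
def get_cumsum_and_arange(lens: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    cumsum = np.cumsum(lens)
    arange = np.arange(cumsum[-1]) - np.repeat(cumsum - lens, lens)
    return cumsum, arange

lens = np.array([3, 2])
cumsum, arange = get_cumsum_and_arange(lens)
print(cumsum)   # [3 5]        -> indptr becomes [0, 3, 5]
print(arange)   # [0 1 2 0 1]  -> offset by np.repeat(kv_start_loc, lens) for indices
```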
@@ -1021,41 +1039,23 @@ def _get_pcp_metadata(
         kv_start_loc = np.roll(np.cumsum(kv_lens), 1)
         kv_start_loc[0] = 0
         # kv_for_q_head
-        kv_head_len = pcp_chunk_sizes * (self.pcp_rank + 1)
-        kv_for_head_indptr = np.zeros(len(kv_head_len) + 1)
-        kv_for_head_indptr[1:], kv_nomask_head_arange = self._get_cumsum_and_arange(kv_head_len)
-        kv_for_head_indices = kv_nomask_head_arange + np.repeat(
-            kv_start_loc,
-            kv_head_len,
-        )
+        kv_for_head_indptr, kv_for_head_indices = _get_partial_kv_idx(self.pcp_rank + 1)
         # kv_for_q_tail
-        kv_tail_len = pcp_chunk_sizes * (2 * self.pcp_world_size - self.pcp_rank)
-        kv_for_tail_indptr = np.zeros(len(kv_tail_len) + 1)
-        kv_for_tail_indptr[1:], kv_nomask_tail_arange = self._get_cumsum_and_arange(kv_tail_len)
-        kv_for_tail_indices = kv_nomask_tail_arange + np.repeat(
-            kv_start_loc,
-            kv_tail_len,
+        kv_for_tail_indptr, kv_for_tail_indices = _get_partial_kv_idx(
+            2 * self.pcp_world_size - self.pcp_rank
         )
-
-        head_tail_indices = {
+
+        head_tail_indices = _to_tensor({
             "q_head": q_head_indices,
             "q_tail": q_tail_indices,
-            "kv_head": kv_for_head_indices,
+            "kv_head": kv_for_head_indices,
             "kv_tail": kv_for_tail_indices,
-        }
-        head_tail_indptr = {
+        }, device=self.device, dtype=torch.int64, non_blocking=True)
+        head_tail_indptr = _to_tensor({
             "q": q_indptr,
             "kv_head": kv_for_head_indptr,
             "kv_tail": kv_for_tail_indptr
-        }
-        for key, value in head_tail_indices.items():
-            head_tail_indices[key] = torch.from_numpy(value).to(
-                device=self.device, dtype=torch.int64, non_blocking=True
-            )
-        for key, value in head_tail_indptr.items():
-            head_tail_indptr[key] = torch.from_numpy(value).to(
-                dtype=torch.int64
-            )
+        }, dtype=torch.int64)
 
         q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]])
         q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32)
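For context on the values fed to `_get_partial_kv_idx`: each prefill request is split into 2 * pcp_world_size chunks, and under causal masking rank r's head chunk needs (r + 1) chunks of KV while its tail chunk needs (2 * pcp_world_size - r). A toy sketch with assumed sizes (not the runner's code):

```python
import numpy as np

# Toy numbers, assuming 2 PCP ranks and one prefill request that contributes
# 8 query tokens to each rank (so 4 tokens per head/tail chunk).
pcp_world_size = 2
q_lens = np.array([8])              # query tokens per request on this rank
pcp_chunk_sizes = q_lens // 2       # head/tail chunk size, as in the diff

for pcp_rank in range(pcp_world_size):
    kv_head_len = pcp_chunk_sizes * (pcp_rank + 1)
    kv_tail_len = pcp_chunk_sizes * (2 * pcp_world_size - pcp_rank)
    print(pcp_rank, kv_head_len, kv_tail_len)
# rank 0: head sees 1 KV chunk (4 tokens), tail sees 4 chunks (16 tokens)
# rank 1: head sees 2 KV chunks (8 tokens), tail sees 3 chunks (12 tokens)
```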
@@ -1074,7 +1074,10 @@ def _get_pcp_metadata(
 
     def _update_tokens_for_pcp(
         self,
-        tokens: np.ndarray
+        tokens: np.ndarray,
+        dummy_input: bool = False,
+        num_reqs: int | None = None,
+        num_decode_reqs: int | None = None,
     ) -> tuple[np.ndarray, np.ndarray, PrefillContextParallelMetadata]:
         """
         If prefill context parallelism is enabled, we will update
@@ -1104,13 +1107,14 @@ def _update_tokens_for_pcp(
         >>> self.pcp_allgather_restore_idx
         [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
         """
-        num_reqs = self.input_batch.num_reqs
+        if not dummy_input:
+            num_reqs = self.input_batch.num_reqs
+            num_decode_reqs = sum(
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                >= self.input_batch.num_prompt_tokens[:num_reqs]
+            )
         self.num_pcp_pads_cpu[:num_reqs] = 0
-
-        num_decode_reqs = sum(
-            self.input_batch.num_computed_tokens_cpu[:num_reqs]
-            >= self.input_batch.num_prompt_tokens[:num_reqs]
-        )
+
         num_decode_tokens = sum(tokens[:num_decode_reqs])
 
         num_padded_scheduled_tokens = np.ceil(
@@ -1175,11 +1179,11 @@ def get_current_rank_positions(
         self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
         return (
             pcp_tokens[:num_reqs],
-            positions,
+            positions,
             self._get_pcp_metadata(
                 pcp_tokens[num_decode_reqs:],
                 num_padded_scheduled_tokens[num_decode_reqs:],
-                self.pcp_allgather_restore_idx.gpu[:all_positions.shape[0]]
+                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]]
             )
         )
 
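The `num_padded_scheduled_tokens` passed into `_get_pcp_metadata` above comes from rounding each prefill request's scheduled tokens up to a multiple of 2 * pcp_world_size so the head/tail chunks divide evenly across ranks. A sketch of that rounding under assumed inputs (illustrative only, not the runner's code):

```python
import numpy as np

# Assumed inputs: 2 PCP ranks and two prefill requests scheduling 9 and 5
# tokens. Each count is rounded up to a multiple of 2 * pcp_world_size so the
# head/tail chunks are equal-sized on every rank; the difference is the
# per-request padding (analogous to num_pcp_pads_cpu).
pcp_world_size = 2
tokens = np.array([9, 5])
pad_multiple = 2 * pcp_world_size
num_padded_scheduled_tokens = (
    np.ceil(tokens / pad_multiple) * pad_multiple
).astype(np.int32)
num_pcp_pads = num_padded_scheduled_tokens - tokens
print(num_padded_scheduled_tokens)  # [12  8]
print(num_pcp_pads)                 # [3 3]
```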
@@ -1474,11 +1478,14 @@ def _prepare_inputs(
 
         # Record the index of requests that should not be sampled,
         # so that we could clear the sampled tokens before returning
-        discard_requests_mask = (
-            self.input_batch.num_computed_tokens_cpu[:num_reqs]
-            + num_scheduled_tokens * self.pcp_world_size
-            - self.num_pcp_pads_cpu[:num_reqs]
-        ) < num_tokens_np
+        if self.pcp_world_size > 1:
+            discard_requests_mask = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                + num_scheduled_tokens * self.pcp_world_size
+                - self.num_pcp_pads_cpu[:num_reqs]
+            ) < num_tokens_np
+        else:
+            discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[: self.num_discarded_requests] = (
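Both branches compute the same mask conceptually: requests whose prompt will still be unfinished after this step should not be sampled. A toy evaluation of the PCP branch with made-up per-request arrays (not the runner's actual state):

```python
import numpy as np

# Made-up per-request state for two requests, assuming pcp_world_size == 2.
# A request is discarded when the tokens it will have seen after this step
# (computed + scheduled across all PCP ranks, minus PCP padding) still fall
# short of the target count for sampling (num_tokens_np in the diff).
pcp_world_size = 2
num_computed_tokens = np.array([0, 16])
num_scheduled_tokens = np.array([6, 4])   # per-rank scheduled tokens
num_pcp_pads = np.array([2, 0])
num_tokens_np = np.array([20, 20])

discard_requests_mask = (
    num_computed_tokens
    + num_scheduled_tokens * pcp_world_size
    - num_pcp_pads
) < num_tokens_np
print(discard_requests_mask)              # [ True False]
```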
@@ -1702,7 +1709,7 @@ def _build_attention_metadata(
             num_logits_indices=num_logits_indices,
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
-            cp_local_seq_lens=self.cp_local_seq_lens.gpu[:num_reqs]
+            dcp_local_seq_lens=self.cp_local_seq_lens.gpu[:num_reqs]
             if self.total_cp_world_size > 1
             else None,
             pcp_metadata=pcp_metadata,
@@ -2872,7 +2879,7 @@ def execute_model(
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
         attn_metadata, spec_decode_common_attn_metadata = (
             self._build_attention_metadata(
-                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens if self.pcp_world_size == 1 else num_scheduled_tokens_np.sum(),
                 max_num_scheduled_tokens=max_num_scheduled_tokens,
                 num_reqs=num_reqs,
                 ubatch_slices=ubatch_slices,
@@ -3811,6 +3818,16 @@ def _dummy_run(
         assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
+        pcp_metadata = None
+        if self.pcp_world_size > 1 and force_attention:
+            num_decode_reqs = sum(num_scheduled_tokens == 1)
+            num_scheduled_tokens[:num_reqs], _, pcp_metadata = \
+                self._update_tokens_for_pcp(
+                    num_scheduled_tokens[:num_reqs],
+                    dummy_input=True,
+                    num_reqs=num_reqs,
+                    num_decode_reqs=num_decode_reqs,
+                )
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 
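In the dummy-run path there is no populated `input_batch`, so decode requests are approximated as those scheduling exactly one token and the totals are recomputed after the PCP update. A small illustration with an assumed dummy batch:

```python
import numpy as np

# Assumed dummy batch: requests scheduling exactly one token are treated as
# decodes; the remaining prefills get the PCP head/tail treatment. The total
# is then recomputed from the (possibly rewritten) per-request counts instead
# of the original num_tokens.
num_scheduled_tokens = np.array([1, 1, 8, 6], dtype=np.int32)
num_decode_reqs = int(sum(num_scheduled_tokens == 1))          # -> 2
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())   # -> 16
print(num_decode_reqs, total_num_scheduled_tokens)
```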
@@ -3828,7 +3845,7 @@ def _dummy_run(
             uniform_decode=uniform_decode,
             num_scheduled_tokens_per_request=num_scheduled_tokens,
         )
-        num_tokens_after_padding = num_tokens
+        num_tokens_after_padding = num_tokens if self.pcp_world_size == 1 else total_num_scheduled_tokens
         if num_tokens_across_dp is not None:
             dp_rank = self.parallel_config.data_parallel_rank
             num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
@@ -3854,11 +3871,12 @@ def _dummy_run(
             self.query_start_loc.copy_to_gpu()
 
             attn_metadata, _ = self._build_attention_metadata(
-                total_num_scheduled_tokens=num_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
                 max_num_scheduled_tokens=max_query_len,
                 num_reqs=num_reqs,
                 ubatch_slices=ubatch_slices,
                 for_cudagraph_capture=True,
+                pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None,
             )
 
         with self.maybe_dummy_run_with_lora(