Commit 68e3ea6

Author: Jingchun Gao (committed)

[Fix] modelrunner support

Signed-off-by: Jingchun Gao <[email protected]>

1 parent b3fcd1b commit 68e3ea6
File tree: 1 file changed (+63, -45 lines)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 63 additions & 45 deletions
@@ -1001,6 +1001,24 @@ def _get_pcp_metadata(
         >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
         >>>                     r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
         """
+        if len(q_lens) == 0:
+            return PrefillContextParallelMetadata(
+                allgather_restore_idx=allgather_restore_idx,
+            )
+
+        def _get_partial_kv_idx(kv_len_per_pcp_chunk):
+            kv_partial_len = pcp_chunk_sizes * kv_len_per_pcp_chunk
+            kv_partial_indptr = np.zeros(len(kv_partial_len) + 1)
+            kv_partial_indptr[1:], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len)
+            kv_partial_indices = kv_partial_arange + np.repeat(
+                kv_start_loc,
+                kv_partial_len,
+            )
+            return kv_partial_indptr, kv_partial_indices
+
+        def _to_tensor(data, **kwargs):
+            return {k: torch.from_numpy(v).to(**kwargs) for k, v in data.items()}
+
         pcp_chunk_sizes = q_lens // 2
         q_indptr = np.zeros(len(pcp_chunk_sizes) + 1)
         q_indptr[1:], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
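As an aside for readers of this hunk: below is a minimal, self-contained numpy sketch of what the new _get_partial_kv_idx closure computes, using invented toy inputs and an inline stand-in for self._get_cumsum_and_arange (both are assumptions for illustration, not part of the commit).

import numpy as np

# Toy stand-ins (assumed values, not taken from the commit):
pcp_chunk_sizes = np.array([3, 2])   # per-request query chunk sizes
kv_start_loc = np.array([0, 12])     # start offset of each request's KV block
kv_len_per_pcp_chunk = 2             # e.g. self.pcp_rank + 1

kv_partial_len = pcp_chunk_sizes * kv_len_per_pcp_chunk   # [6, 4]

# Inline equivalent of the cumsum/arange helper: a cumulative indptr plus a
# per-request arange flattened into a single vector.
cumsum = np.cumsum(kv_partial_len)                         # [6, 10]
arange = np.arange(cumsum[-1]) - np.repeat(cumsum - kv_partial_len, kv_partial_len)

kv_partial_indptr = np.zeros(len(kv_partial_len) + 1)
kv_partial_indptr[1:] = cumsum                             # [0, 6, 10]
kv_partial_indices = arange + np.repeat(kv_start_loc, kv_partial_len)
print(kv_partial_indices)  # [ 0  1  2  3  4  5 12 13 14 15]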
@@ -1021,41 +1039,23 @@ def _get_pcp_metadata(
         kv_start_loc = np.roll(np.cumsum(kv_lens), 1)
         kv_start_loc[0] = 0
         # kv_for_q_head
-        kv_head_len = pcp_chunk_sizes * (self.pcp_rank + 1)
-        kv_for_head_indptr = np.zeros(len(kv_head_len) + 1)
-        kv_for_head_indptr[1:], kv_nomask_head_arange = self._get_cumsum_and_arange(kv_head_len)
-        kv_for_head_indices = kv_nomask_head_arange + np.repeat(
-            kv_start_loc,
-            kv_head_len,
-        )
+        kv_for_head_indptr, kv_for_head_indices = _get_partial_kv_idx(self.pcp_rank + 1)
         # kv_for_q_tail
-        kv_tail_len = pcp_chunk_sizes * (2 * self.pcp_world_size - self.pcp_rank)
-        kv_for_tail_indptr = np.zeros(len(kv_tail_len) + 1)
-        kv_for_tail_indptr[1:], kv_nomask_tail_arange = self._get_cumsum_and_arange(kv_tail_len)
-        kv_for_tail_indices = kv_nomask_tail_arange + np.repeat(
-            kv_start_loc,
-            kv_tail_len,
+        kv_for_tail_indptr, kv_for_tail_indices = _get_partial_kv_idx(
+            2 * self.pcp_world_size - self.pcp_rank
         )
-
-        head_tail_indices = {
+
+        head_tail_indices = _to_tensor({
             "q_head": q_head_indices,
             "q_tail": q_tail_indices,
-            "kv_head": kv_for_head_indices,
+            "kv_head": kv_for_head_indices,
             "kv_tail": kv_for_tail_indices,
-        }
-        head_tail_indptr = {
+        }, device=self.device, dtype=torch.int64, non_blocking=True)
+        head_tail_indptr = _to_tensor({
             "q": q_indptr,
             "kv_head": kv_for_head_indptr,
             "kv_tail": kv_for_tail_indptr
-        }
-        for key, value in head_tail_indices.items():
-            head_tail_indices[key] = torch.from_numpy(value).to(
-                device=self.device, dtype=torch.int64, non_blocking=True
-            )
-        for key, value in head_tail_indptr.items():
-            head_tail_indptr[key] = torch.from_numpy(value).to(
-                dtype=torch.int64
-            )
+        }, dtype=torch.int64)
 
         q_full_indices = torch.cat([head_tail_indices["q_head"], head_tail_indices["q_tail"]])
         q_full_indices = q_full_indices.to(torch.float32).argsort().to(torch.int32)
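The two manual torch.from_numpy(...).to(...) loops removed in this hunk are folded into the _to_tensor closure. A rough standalone illustration of its behaviour follows, with toy arrays and an assumed dtype rather than the commit's actual tensors.

import numpy as np
import torch

def _to_tensor(data, **kwargs):
    # Convert every numpy array in the dict to a torch tensor, forwarding
    # any Tensor.to() keyword (device, dtype, non_blocking, ...) unchanged.
    return {k: torch.from_numpy(v).to(**kwargs) for k, v in data.items()}

indices = _to_tensor(
    {"q_head": np.array([0, 1, 2]), "q_tail": np.array([3, 4])},
    dtype=torch.int64,
)
print(indices["q_head"].dtype)  # torch.int64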
@@ -1074,7 +1074,10 @@ def _get_pcp_metadata(
 
     def _update_tokens_for_pcp(
         self,
-        tokens: np.ndarray
+        tokens: np.ndarray,
+        dummy_input: bool = False,
+        num_reqs: int | None = None,
+        num_decode_reqs: int | None = None,
     ) -> tuple[np.ndarray, np.ndarray, PrefillContextParallelMetadata]:
         """
         If prefill context parallelism is enabled, we will update
@@ -1104,13 +1107,14 @@ def _update_tokens_for_pcp(
         >>> self.pcp_allgather_restore_idx
         [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
         """
-        num_reqs = self.input_batch.num_reqs
+        if not dummy_input:
+            num_reqs = self.input_batch.num_reqs
+            num_decode_reqs = sum(
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                >= self.input_batch.num_prompt_tokens[:num_reqs]
+            )
         self.num_pcp_pads_cpu[:num_reqs] = 0
-
-        num_decode_reqs = sum(
-            self.input_batch.num_computed_tokens_cpu[:num_reqs]
-            >= self.input_batch.num_prompt_tokens[:num_reqs]
-        )
+
         num_decode_tokens = sum(tokens[:num_decode_reqs])
 
         num_padded_scheduled_tokens = np.ceil(
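For context on the branch added here: with real inputs, the decode-request count is still derived from the input batch by comparing computed tokens against prompt lengths, while the dummy path now passes num_reqs and num_decode_reqs in directly. A toy numpy sketch of that comparison (all values invented):

import numpy as np

# Assumed toy batch of three requests:
num_computed_tokens = np.array([10, 4, 0])   # tokens already processed
num_prompt_tokens = np.array([10, 8, 6])     # prompt lengths

# A request whose prompt is fully processed is in the decode phase.
num_decode_reqs = int(np.sum(num_computed_tokens >= num_prompt_tokens))
print(num_decode_reqs)  # 1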
@@ -1175,11 +1179,11 @@ def get_current_rank_positions(
         self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
         return (
             pcp_tokens[:num_reqs],
-            positions,
+            positions,
             self._get_pcp_metadata(
                 pcp_tokens[num_decode_reqs:],
                 num_padded_scheduled_tokens[num_decode_reqs:],
-                self.pcp_allgather_restore_idx.gpu[:all_positions.shape[0]]
+                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]]
             )
         )
 
@@ -1474,11 +1478,14 @@ def _prepare_inputs(
 
         # Record the index of requests that should not be sampled,
         # so that we could clear the sampled tokens before returning
-        discard_requests_mask = (
-            self.input_batch.num_computed_tokens_cpu[:num_reqs]
-            + num_scheduled_tokens * self.pcp_world_size
-            - self.num_pcp_pads_cpu[:num_reqs]
-        ) < num_tokens_np
+        if self.pcp_world_size > 1:
+            discard_requests_mask = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs]
+                + num_scheduled_tokens * self.pcp_world_size
+                - self.num_pcp_pads_cpu[:num_reqs]
+            ) < num_tokens_np
+        else:
+            discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[: self.num_discarded_requests] = (
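A small numpy sketch of the two discard-mask branches above, with all values invented: under prefill context parallelism the per-rank scheduled token count is scaled back up by pcp_world_size and the PCP padding subtracted before comparing against num_tokens_np, while the non-PCP branch compares the plain sequence length (approximated here as computed plus scheduled tokens, an assumption about what self.seq_lens holds).

import numpy as np

# Assumed toy values for a two-request batch:
num_computed_tokens = np.array([0, 16])
num_scheduled_tokens = np.array([8, 4])   # per-rank scheduled tokens
num_pcp_pads = np.array([1, 0])           # padding introduced by PCP chunking
num_tokens_np = np.array([30, 20])        # tokens each request ultimately needs
pcp_world_size = 2

if pcp_world_size > 1:
    # Undo the per-rank split and the padding before comparing.
    discard_requests_mask = (
        num_computed_tokens + num_scheduled_tokens * pcp_world_size - num_pcp_pads
    ) < num_tokens_np
else:
    seq_lens = num_computed_tokens + num_scheduled_tokens
    discard_requests_mask = seq_lens < num_tokens_np

print(discard_requests_mask)  # [ True False]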
@@ -1702,7 +1709,7 @@ def _build_attention_metadata(
             num_logits_indices=num_logits_indices,
             causal=True,
             encoder_seq_lens=encoder_seq_lens,
-            cp_local_seq_lens=self.cp_local_seq_lens.gpu[:num_reqs]
+            dcp_local_seq_lens=self.cp_local_seq_lens.gpu[:num_reqs]
             if self.total_cp_world_size > 1
             else None,
             pcp_metadata=pcp_metadata,
@@ -2872,7 +2879,7 @@ def execute_model(
         use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
         attn_metadata, spec_decode_common_attn_metadata = (
             self._build_attention_metadata(
-                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens if self.pcp_world_size == 1 else num_scheduled_tokens_np.sum(),
                 max_num_scheduled_tokens=max_num_scheduled_tokens,
                 num_reqs=num_reqs,
                 ubatch_slices=ubatch_slices,
@@ -3811,6 +3818,16 @@ def _dummy_run(
         assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
+        pcp_metadata = None
+        if self.pcp_world_size > 1 and force_attention:
+            num_decode_reqs = sum(num_scheduled_tokens == 1)
+            num_scheduled_tokens[:num_reqs], _, pcp_metadata = \
+                self._update_tokens_for_pcp(
+                    num_scheduled_tokens[:num_reqs],
+                    dummy_input=True,
+                    num_reqs=num_reqs,
+                    num_decode_reqs=num_decode_reqs,
+                )
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 
@@ -3828,7 +3845,7 @@ def _dummy_run(
             uniform_decode=uniform_decode,
             num_scheduled_tokens_per_request=num_scheduled_tokens,
         )
-        num_tokens_after_padding = num_tokens
+        num_tokens_after_padding = num_tokens if self.pcp_world_size == 1 else total_num_scheduled_tokens
         if num_tokens_across_dp is not None:
             dp_rank = self.parallel_config.data_parallel_rank
             num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
@@ -3854,11 +3871,12 @@ def _dummy_run(
         self.query_start_loc.copy_to_gpu()
 
         attn_metadata, _ = self._build_attention_metadata(
-            total_num_scheduled_tokens=num_tokens,
+            total_num_scheduled_tokens=total_num_scheduled_tokens,
             max_num_scheduled_tokens=max_query_len,
             num_reqs=num_reqs,
             ubatch_slices=ubatch_slices,
             for_cudagraph_capture=True,
+            pcp_metadata=pcp_metadata if self.pcp_world_size > 1 else None,
         )
 
         with self.maybe_dummy_run_with_lora(
