Commit df36e76

Author: Jingchun Gao (committed)

[Lint]

Signed-off-by: Jingchun Gao <[email protected]>

1 parent 3d65330 · commit df36e76

4 files changed: +94 −63 lines

vllm/v1/attention/backends/flashinfer.py

Lines changed: 11 additions & 13 deletions

@@ -1285,17 +1285,17 @@ def forward(
             assert pcp_allgather_restore_idx is not None
             # NOTE(yyj): we must `slice` key and value because pcp_allgather_restore_idx
             # ignores the padding from CUDA Graph. To be optimized for performance!
-            key_across_pcp = get_pcp_group().all_gather(key[:num_actual_tokens].contiguous(), dim=0)
-            value_across_pcp = get_pcp_group().all_gather(value[:num_actual_tokens].contiguous(), dim=0)
+            key_across_pcp = get_pcp_group().all_gather(
+                key[:num_actual_tokens].contiguous(), dim=0
+            )
+            value_across_pcp = get_pcp_group().all_gather(
+                value[:num_actual_tokens].contiguous(), dim=0
+            )
             # Reorder kv after pcp allgather.
             # Note that there are duplicate decoding tokens,
             # but we only save the first one in kvcache.
-            key = torch.index_select(
-                key_across_pcp, 0, pcp_allgather_restore_idx
-            )
-            value = torch.index_select(
-                value_across_pcp, 0, pcp_allgather_restore_idx
-            )
+            key = torch.index_select(key_across_pcp, 0, pcp_allgather_restore_idx)
+            value = torch.index_select(value_across_pcp, 0, pcp_allgather_restore_idx)
         if self.kv_sharing_target_layer_name is None:
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.

@@ -1356,7 +1356,7 @@ def forward(
             if self.total_cp_world_size > 1:
                 assert isinstance(prefill_wrapper, BatchCPPrefillWrapper)
                 expected_logits_soft_cap = self.logits_soft_cap or 0.0
-
+
                 wrappers_to_check = [(prefill_wrapper._context, False)]
                 if self.pcp_world_size > 1:
                     wrappers_to_check.extend(

@@ -1367,7 +1367,7 @@ def forward(
                     )
                 else:
                     wrappers_to_check.append((prefill_wrapper._new_tokens, True))
-
+
                 for wrapper, expected_causal in wrappers_to_check:
                     assert wrapper._window_left == self.window_left
                     assert wrapper._logits_soft_cap == expected_logits_soft_cap

@@ -1497,9 +1497,7 @@ def forward(
                 return_lse=True,
             )
             if self.pcp_world_size > 1:
-                out = cp_lse_ag_out_ar(
-                    out, lse, get_pcp_group()
-                )
+                out = cp_lse_ag_out_ar(out, lse, get_pcp_group())
                 output[:num_decode_tokens] = out
             else:
                 decode_wrapper.run(

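Note: the flashinfer.py hunks above only reflow the prefill-context-parallel (PCP) gather path. Each rank slices key/value down to `num_actual_tokens` (dropping CUDA Graph padding), all-gathers the slices along dim 0, and uses `pcp_allgather_restore_idx` with `torch.index_select` to put the gathered rows back in the original token order. Below is a minimal single-process sketch of that gather-then-restore pattern, with `torch.cat` standing in for the distributed all_gather and a made-up restore index; the shapes and index are illustrative, not vLLM's actual metadata.

import torch

# Toy setup: 2 "ranks", each holding 3 actual token rows plus 1 pad row
# (mimicking CUDA Graph padding that must be sliced away before gathering).
pcp_world_size = 2
num_actual_tokens = 3
per_rank_key = [
    torch.arange(4).float().unsqueeze(1) + 10 * r for r in range(pcp_world_size)
]

# Step 1: slice off padding, then "all-gather" along dim 0.
# In the diff this is get_pcp_group().all_gather(key[:num_actual_tokens].contiguous(), dim=0);
# torch.cat stands in for the collective here.
gathered = torch.cat([k[:num_actual_tokens] for k in per_rank_key], dim=0)  # shape [6, 1]

# Step 2: restore the original interleaved token order.
# This index is invented for the example; the real one comes from the metadata builder.
restore_idx = torch.tensor([0, 3, 1, 4, 2, 5])
restored = torch.index_select(gathered, 0, restore_idx)

print(gathered.squeeze(1).tolist())  # [0.0, 1.0, 2.0, 10.0, 11.0, 12.0]
print(restored.squeeze(1).tolist())  # [0.0, 10.0, 1.0, 11.0, 2.0, 12.0]
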
vllm/v1/attention/backends/mla/common.py

Lines changed: 2 additions & 2 deletions

@@ -754,8 +754,8 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
-        dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
-        dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
+        dcp_local_seq_lens = common_attn_metadata.cp_local_seq_lens
+        dcp_local_seq_lens_cpu = common_attn_metadata.cp_local_seq_lens_cpu
 
         query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
 

vllm/v1/attention/backends/utils.py

Lines changed: 5 additions & 1 deletion

@@ -48,11 +48,13 @@
 def is_valid_kv_cache_layout(value: str) -> bool:
     return value in get_args(KVCacheLayoutType)
 
+
 @dataclass
 class PrefillContextParallelMetadata:
     """
     Attention metadata for prefill context parallel
     """
+
     allgather_restore_idx: torch.Tensor
     """
     We split and concatenate the sequence in a head-tail style,

@@ -62,11 +64,12 @@ class PrefillContextParallelMetadata:
     q_tail_indices: torch.Tensor | None = None
     q_head_start_loc: torch.Tensor | None = None
     kv_for_head_indices: torch.Tensor | None = None
-    kv_for_tail_indices : torch.Tensor | None = None
+    kv_for_tail_indices: torch.Tensor | None = None
     kv_for_head_indptr: torch.Tensor | None = None
     kv_for_tail_indptr: torch.Tensor | None = None
     q_full_indices: torch.Tensor | None = None
 
+
 @dataclass
 class CommonAttentionMetadata:
     """

@@ -115,6 +118,7 @@ class CommonAttentionMetadata:
 
     pcp_metadata: PrefillContextParallelMetadata | None = None
 
+
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
     request_slice: slice,

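Note: `PrefillContextParallelMetadata.allgather_restore_idx` exists because the sequence is split across PCP ranks in a head-tail style, so an all-gather returns tokens grouped by rank rather than in their original order. The sketch below is a toy illustration of that idea, assuming the chunk-to-rank assignment implied by the kv-length formulas in `_get_pcp_metadata` (rank r takes chunk r and chunk 2P-1-r) and deriving a restore index with `argsort`; it is not the exact vLLM construction.

import numpy as np

def head_tail_assignment(seq_len: int, pcp_size: int) -> list[np.ndarray]:
    """Token indices owned by each rank under a head-tail split (illustrative)."""
    assert seq_len % (2 * pcp_size) == 0
    chunk = seq_len // (2 * pcp_size)
    chunks = [np.arange(i * chunk, (i + 1) * chunk) for i in range(2 * pcp_size)]
    # Rank r takes chunk r (head) and chunk 2P-1-r (tail), balancing causal work.
    return [
        np.concatenate([chunks[r], chunks[2 * pcp_size - 1 - r]])
        for r in range(pcp_size)
    ]

pcp_size, seq_len = 2, 8
owned = head_tail_assignment(seq_len, pcp_size)
# rank 0 owns tokens [0 1 6 7], rank 1 owns [2 3 4 5]

# After all_gather(dim=0) rows appear in rank order; argsort of the concatenated
# ownership gives an index that restores the original token order.
gather_order = np.concatenate(owned)    # [0 1 6 7 2 3 4 5]
restore_idx = np.argsort(gather_order)  # [0 1 4 5 6 7 2 3]
assert (gather_order[restore_idx] == np.arange(seq_len)).all()
print(owned, restore_idx)
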
vllm/v1/worker/gpu_model_runner.py

Lines changed: 76 additions & 47 deletions

@@ -93,9 +93,9 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    PrefillContextParallelMetadata,
     create_fast_prefill_custom_backend,
     get_cp_local_seq_lens,
-    PrefillContextParallelMetadata,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )

@@ -461,7 +461,9 @@ def __init__(
         if self.pcp_world_size > 1:
             # Note(qcs): we will pad the tokens of each request
             # to a multiple of 2 * pcp_size.
-            max_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            max_num_tokens = (
+                self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_world_size
+            )
         else:
             max_num_tokens = self.max_num_tokens
         # Persistent buffers for CUDA graphs.

@@ -501,24 +503,15 @@ def __init__(
         # Persistent buffers for Prefill Context Parallism
         if self.pcp_world_size > 1:
             self.pcp_allgather_restore_idx = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
-            )
-            self.q_head_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
-            )
-            self.q_tail_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
+            self.q_head_indices = self._make_buffer(max_num_tokens, dtype=torch.int64)
+            self.q_tail_indices = self._make_buffer(max_num_tokens, dtype=torch.int64)
             self.kv_for_head_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
             self.kv_for_tail_indices = self._make_buffer(
-                max_num_tokens,
-                dtype=torch.int64
+                max_num_tokens, dtype=torch.int64
             )
             self.pcp_padded_slot_mapping = torch.empty(
                 (max_num_tokens,),

@@ -534,15 +527,24 @@ def __init__(
             )
             self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
             self.q_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.q_indptr_cpu = self.q_indptr_cpu_tensor.numpy()
             self.kv_for_head_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.kv_for_head_indptr_cpu = self.kv_for_head_indptr_cpu_tensor.numpy()
             self.kv_for_tail_indptr_cpu_tensor = torch.zeros(
-                (self.max_num_reqs + 1,), device="cpu", dtype=torch.int64, pin_memory=True
+                (self.max_num_reqs + 1,),
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=True,
             )
             self.kv_for_tail_indptr_cpu = self.kv_for_tail_indptr_cpu_tensor.numpy()
 

@@ -1070,45 +1072,53 @@ def _get_pcp_metadata(
     ) -> PrefillContextParallelMetadata:
         """
         During the prefill phrase, the attention computation is divided into
-        two parts: q_head and q_tail. Here, we calculate the kv indices 
-        corresponding to q_head or q_tail. Meawhile, the q and kv indptr are 
+        two parts: q_head and q_tail. Here, we calculate the kv indices
+        corresponding to q_head or q_tail. Meawhile, the q and kv indptr are
         also computed to build the attention wrapper.
         If the pcp_size is 2, the variables are following:
         >>> q_lens [4, 8] kv_lens [8, 16]
        >>> pcp_chunk_sizes[2, 4]
-        >>> q_indptr [0, 2, 4]
+        >>> q_indptr[0, 2, 4]
         >>> q_head_indices [0, 1, 4, 5, 6, 7] q_tail_indices [2, 3, 8, 9, 10, 11]
         >>> kv_head_len r0 [2, 4] / r1 [4, 8]
         >>> kv_for_head_indptr r0 [0, 2, 6] / r1 [0, 4, 12]
         >>> kv_for_head_indices r0 [0, 1, 8, 9, 10, 11]
-        >>>                     r1 [0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
+        >>>                     r1[0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15]
         >>> kv_tail_len r0 [8, 16] / r1 [6, 12]
         >>> kv_for_tail_indptr r0 [0, 8, 24] / r1 [0, 6, 18]
         >>> kv_for_tail_indices r0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 23]
-        >>>                     r1 [0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
+        >>>                     r1[0, 1, 2, 3, 4, 5, 8, 9, ..., 19]
         """
         if len(q_lens) == 0:
             return PrefillContextParallelMetadata(
                 allgather_restore_idx=allgather_restore_idx,
             )
 
         def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices):
-            kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = self._get_cumsum_and_arange(kv_partial_len)
-            kv_parial_indices.np[: kv_partial_arange.shape[0]] = kv_partial_arange + np.repeat(
-                kv_start_loc,
-                kv_partial_len,
+            kv_partial_indptr[1 : len(kv_partial_len) + 1], kv_partial_arange = (
+                self._get_cumsum_and_arange(kv_partial_len)
+            )
+            kv_parial_indices.np[: kv_partial_arange.shape[0]] = (
+                kv_partial_arange
+                + np.repeat(
+                    kv_start_loc,
+                    kv_partial_len,
+                )
             )
             return kv_partial_arange.shape[0]
 
         pcp_chunk_sizes = q_lens // 2
-        self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
+        self.q_indptr_cpu[1 : len(pcp_chunk_sizes) + 1], q_chunk_arange = (
+            self._get_cumsum_and_arange(pcp_chunk_sizes)
+        )
 
         q_head_start_loc = np.roll(np.cumsum(q_lens), 1)
         q_head_start_loc[0] = 0
         self.q_head_indices.np[: q_chunk_arange.shape[0]] = q_chunk_arange + np.repeat(
             q_head_start_loc,
             pcp_chunk_sizes,
         )
+
         self.q_head_indices.copy_to_gpu(q_chunk_arange.shape[0])
 
         q_tail_start_loc = q_head_start_loc + pcp_chunk_sizes

@@ -1122,17 +1132,25 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices):
         kv_start_loc[0] = 0
         # kv_for_q_head
         kv_for_head_len = (self.pcp_rank + 1) * pcp_chunk_sizes
-        kv_head_tokens_sum = _get_partial_kv_idx(kv_for_head_len, self.kv_for_head_indptr_cpu, self.kv_for_head_indices)
+        kv_head_tokens_sum = _get_partial_kv_idx(
+            kv_for_head_len,
+            self.kv_for_head_indptr_cpu,
+            self.kv_for_head_indices,
+        )
         self.kv_for_head_indices.copy_to_gpu(kv_head_tokens_sum)
         # kv_for_q_tail
         kv_for_tail_len = (2 * self.pcp_world_size - self.pcp_rank) * pcp_chunk_sizes
-        kv_tail_tokens_sum = _get_partial_kv_idx(kv_for_tail_len, self.kv_for_tail_indptr_cpu, self.kv_for_tail_indices)
+        kv_tail_tokens_sum = _get_partial_kv_idx(
+            kv_for_tail_len,
+            self.kv_for_tail_indptr_cpu,
+            self.kv_for_tail_indices,
+        )
         self.kv_for_tail_indices.copy_to_gpu(kv_tail_tokens_sum)
 
         q_full_indices = torch.cat(
             [
                 self.q_head_indices.gpu[: q_chunk_arange.shape[0]],
-                self.q_tail_indices.gpu[: q_chunk_arange.shape[0]]
+                self.q_tail_indices.gpu[: q_chunk_arange.shape[0]],
             ]
         ).argsort()
 

@@ -1141,13 +1159,17 @@ def _get_partial_kv_idx(kv_partial_len, kv_partial_indptr, kv_parial_indices):
             q_head_indices=self.q_head_indices.gpu[: q_chunk_arange.shape[0]],
             q_tail_indices=self.q_tail_indices.gpu[: q_chunk_arange.shape[0]],
             q_head_start_loc=self.q_indptr_cpu_tensor[: len(pcp_chunk_sizes) + 1],
-            kv_for_head_indices=self.kv_for_head_indices.gpu[: kv_head_tokens_sum],
-            kv_for_tail_indices=self.kv_for_tail_indices.gpu[: kv_tail_tokens_sum],
-            kv_for_head_indptr=self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1],
-            kv_for_tail_indptr=self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1],
+            kv_for_head_indices=self.kv_for_head_indices.gpu[:kv_head_tokens_sum],
+            kv_for_tail_indices=self.kv_for_tail_indices.gpu[:kv_tail_tokens_sum],
+            kv_for_head_indptr=(
+                self.kv_for_head_indptr_cpu_tensor[: len(kv_for_head_len) + 1]
+            ),
+            kv_for_tail_indptr=(
+                self.kv_for_tail_indptr_cpu_tensor[: len(kv_for_tail_len) + 1]
+            ),
             q_full_indices=q_full_indices,
         )
-
+
     def _update_tokens_for_pcp(
         self,
         tokens: np.ndarray,

@@ -1189,8 +1211,15 @@ def _update_tokens_for_pcp(
                 self.input_batch.num_computed_tokens_cpu[:num_reqs]
                 >= self.input_batch.num_prompt_tokens[:num_reqs]
             )
+        else:
+            if num_reqs is None or num_decode_reqs is None:
+                raise ValueError(
+                    "num_reqs and num_decode_reqs must be provided for dummy input"
+                )
+            assert num_reqs is not None
+            assert num_decode_reqs is not None
         self.num_pcp_pads_cpu[:num_reqs] = 0
-
+
         num_decode_tokens = sum(tokens[:num_decode_reqs])
 
         num_padded_scheduled_tokens = np.ceil(

@@ -1259,8 +1288,8 @@ get_current_rank_positions(
             self._get_pcp_metadata(
                 pcp_tokens[num_decode_reqs:],
                 num_padded_scheduled_tokens[num_decode_reqs:],
-                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]]
-            )
+                self.pcp_allgather_restore_idx.gpu[: all_positions.shape[0]],
+            ),
         )
 
     def _get_cumsum_and_arange(

@@ -1471,10 +1500,9 @@ def _prepare_inputs(
 
         pcp_metadata = None
         if self.pcp_world_size > 1:
-            num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = \
-                self._update_tokens_for_pcp(
-                    num_scheduled_tokens[:num_reqs]
-                )
+            num_scheduled_tokens[:num_reqs], pcp_positions, pcp_metadata = (
+                self._update_tokens_for_pcp(num_scheduled_tokens[:num_reqs])
+            )
 
         # Re-update after PCP split sequences.
         total_num_scheduled_tokens = sum(num_scheduled_tokens)

@@ -1605,7 +1633,7 @@ def _prepare_inputs(
         if self.pcp_world_size > 1:
             discard_requests_mask = (
                 self.input_batch.num_computed_tokens_cpu[:num_reqs]
-                    + num_scheduled_tokens * self.pcp_world_size
+                + num_scheduled_tokens * self.pcp_world_size
                 - self.num_pcp_pads_cpu[:num_reqs]
             ) < num_tokens_np
         else:

@@ -3167,7 +3195,7 @@ def execute_model(
             # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
             # ignores the padding from CUDA Graph.
             hidden_states = get_pcp_group().all_gather(
-                hidden_states[:num_scheduled_tokens_np.sum()],
+                hidden_states[: num_scheduled_tokens_np.sum()],
                 0,
             )
             hidden_states = torch.index_select(

@@ -4077,13 +4105,14 @@ def _dummy_run(
         pcp_metadata = None
         if self.pcp_world_size > 1 and force_attention:
             num_decode_reqs = sum(num_scheduled_tokens == 1)
-            num_scheduled_tokens[:num_reqs], _, pcp_metadata = \
+            num_scheduled_tokens[:num_reqs], _, pcp_metadata = (
                 self._update_tokens_for_pcp(
                     num_scheduled_tokens[:num_reqs],
                     dummy_input=True,
                     num_reqs=num_reqs,
                    num_decode_reqs=num_decode_reqs,
                 )
+            )
         total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 

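Note: the `_get_pcp_metadata` docstring in the hunk above walks through a worked example (pcp_size 2, per-rank `q_lens` [4, 8], `kv_lens` [8, 16]). The snippet below simply replays that arithmetic with plain numpy cumsums in place of the runner's `_get_cumsum_and_arange` helper, as a sanity check against the values quoted in the docstring.

import numpy as np

# Values from the _get_pcp_metadata docstring example.
pcp_world_size = 2
q_lens = np.array([4, 8])

pcp_chunk_sizes = q_lens // 2                                  # [2, 4]
q_indptr = np.concatenate([[0], np.cumsum(pcp_chunk_sizes)])   # [0, 2, 4]

for pcp_rank in range(pcp_world_size):
    # KV needed by the head chunk (chunk index == rank) and by the tail chunk
    # (chunk index == 2 * pcp_size - 1 - rank) under causal attention.
    kv_for_head_len = (pcp_rank + 1) * pcp_chunk_sizes
    kv_for_tail_len = (2 * pcp_world_size - pcp_rank) * pcp_chunk_sizes
    kv_for_head_indptr = np.concatenate([[0], np.cumsum(kv_for_head_len)])
    kv_for_tail_indptr = np.concatenate([[0], np.cumsum(kv_for_tail_len)])
    print(pcp_rank, kv_for_head_len, kv_for_head_indptr, kv_for_tail_len, kv_for_tail_indptr)
    # rank 0: head [2 4], indptr [0 2 6];  tail [8 16], indptr [0 8 24]
    # rank 1: head [4 8], indptr [0 4 12]; tail [6 12], indptr [0 6 18]
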