Commit 3f73536

Merge pull request #2 from pisceskkk/long_seq_dev
Support CP (context parallelism) for FlashInfer GQA
2 parents f8afd97 + 4c4878c commit 3f73536

7 files changed: +388 -59 lines changed

vllm/attention/backends/abstract.py (9 additions, 0 deletions)

@@ -275,6 +275,15 @@ def __new__(cls, *args, **kwargs):
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        try:
+            from vllm.distributed.parallel_state import get_cp_group
+            self.cp_world_size = get_cp_group().world_size
+            self.cp_rank = get_cp_group().rank_in_group
+        except AssertionError:
+            # CP might not be initialized in testing
+            self.cp_world_size = 1
+            self.cp_rank = 0
+
         self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
             and self.can_return_lse_for_decode
         return self

vllm/attention/ops/common.py (25 additions, 0 deletions)

@@ -137,3 +137,28 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
     assert out.is_contiguous()
     out = cp_group.reduce_scatter(out, dim=1)
     return out
+
+
+def cp_lse_ag_out_ar(cp_attn_out: torch.Tensor,
+                     cp_attn_lse: torch.Tensor,
+                     cp_group: GroupCoordinator,
+                     ctx: CPTritonContext = None):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    if cp_group.world_size == 1:
+        return cp_attn_out
+
+    if ctx is None:
+        ctx = CPTritonContext()
+
+    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
+                       dtype=cp_attn_lse.dtype,
+                       device=cp_attn_lse.device)
+
+    cp_attn_lse = cp_attn_lse.contiguous()
+    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
+    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
+    out = cp_group.all_reduce(out)
+    return out
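For intuition, the rescaling that `correct_attn_out` is expected to apply before the all-reduce can be written in plain PyTorch: each rank's partial attention output is weighted by the softmax of the per-rank LSEs, so that summing across ranks reproduces attention over the full KV. This is only a sketch of the math under that assumption (the real implementation is a Triton kernel), not code from this commit; shapes follow the docstring above.

    import torch

    def merge_partial_attn(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
        # outs: [world_size, B, H, D] partial attention outputs from each rank
        # lses: [world_size, B, H]    log-sum-exp of the attention scores per rank
        weights = torch.softmax(lses, dim=0)               # each rank's contribution
        return (outs * weights.unsqueeze(-1)).sum(dim=0)   # [B, H, D], equals full attention

The per-rank variant presumably computed by `correct_attn_out(cp_attn_out, lses, rank, ctx)` is just `cp_attn_out * weights[rank].unsqueeze(-1)`; the following `all_reduce` then performs the sum across ranks.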

vllm/v1/attention/backends/flashinfer.py (179 additions, 44 deletions)
@@ -10,6 +10,7 @@
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
+                        BatchPrefillWithRaggedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
@@ -18,8 +19,10 @@
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
+from vllm.attention.ops.common import cp_lse_ag_out_ar
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
+from vllm.distributed.parallel_state import get_cp_group
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -234,6 +237,9 @@ class FlashInferMetadata:
 
     qo_indptr_gpu: Optional[torch.Tensor] = None
     paged_kv_indptr_gpu: Optional[torch.Tensor] = None
+
+    # For context parallel
+    cp_kv_recover_idx: Optional[torch.Tensor] = None
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
@@ -256,8 +262,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
+        # NOTE(qcs): context parallel does not support CUDA graph mode yet
         self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL)
+            decode_mode() == CUDAGraphMode.FULL and self.cp_world_size == 1)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -266,6 +273,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             self._decode_cudagraph_max_bs = min(
                 max_num_reqs, self.compilation_config.max_capture_size)
 
+        try:
+            self.cp_world_size = get_cp_group().world_size
+            self.cp_rank = get_cp_group().rank_in_group
+        except AssertionError:
+            # CP might not be initialized in testing
+            self.cp_world_size = 1
+            self.cp_rank = 0
+
         self.num_qo_heads = self.model_config.get_num_attention_heads(
             self.vllm_config.parallel_config)
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
@@ -348,8 +363,12 @@ def _get_workspace_buffer(self):
 
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
-            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), get_kv_cache_layout())
+            if self.cp_world_size > 1:
+                self._prefill_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout())
+            else:
+                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self,
@@ -413,7 +432,12 @@ def build(self,
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        if self.cp_world_size > 1:
+            seq_lens_cpu = seq_lens_cpu // \
+                self.cp_world_size + (self.cp_rank < seq_lens_cpu \
+                % self.cp_world_size)
         seq_lens_np = seq_lens_cpu.numpy()
+        num_computed_tokens_cpu = common_attn_metadata.num_computed_tokens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
 
         num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size
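The per-rank sequence-length split above gives the first `seq_len % cp_world_size` ranks one extra token. A quick standalone check with illustrative values (not part of the commit):

    import torch

    seq_lens_cpu = torch.tensor([10, 7])   # full sequence lengths
    cp_world_size = 4
    for cp_rank in range(cp_world_size):
        local = seq_lens_cpu // cp_world_size + (cp_rank < seq_lens_cpu % cp_world_size)
        print(cp_rank, local.tolist())
    # 0 [3, 2]
    # 1 [3, 2]
    # 2 [2, 2]
    # 3 [2, 1]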
@@ -495,6 +519,13 @@ def build(self,
            self.cache_dtype,
            self.q_data_type,
            has_sinks=self.has_sinks)
+
+        if self.cp_world_size > 1 and (prefill_use_trtllm
+                                       or decode_use_trtllm):
+            raise NotImplementedError(
+                "TRT-LLM attention does not support LSE output; please use "
+                "flash attention or disable attention sinks.")
+
         if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
             raise NotImplementedError(
                 "FlashInfer backend currently does not support attention "
@@ -521,6 +552,7 @@ def build(self,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            use_cascade=use_cascade,
+           cp_kv_recover_idx=common_attn_metadata.cp_kv_recover_idx,
        )
 
        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
@@ -567,23 +599,69 @@ def build(self,
             qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
                 prefill_start]
             paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
+            prefill_num_computed_tokens_cpu = num_computed_tokens_cpu[prefill_start:]
             if not attn_metadata.prefill_use_trtllm:
-                attn_metadata.prefill_wrapper.plan(
-                    qo_indptr_cpu,
-                    paged_kv_indptr_cpu,
-                    paged_kv_indices,
-                    paged_kv_last_page_len_cpu[prefill_start:],
-                    self.num_qo_heads,
-                    self.num_kv_heads,
-                    self.head_dim,
-                    self.page_size,
-                    causal=True,
-                    sm_scale=self.sm_scale,
-                    window_left=self.window_left,
-                    logits_soft_cap=self.logits_soft_cap,
-                    q_data_type=self.q_data_type,
-                    kv_data_type=self.kv_cache_dtype,
-                )
+                if self.cp_world_size > 1:
+                    # NOTE(qcs): chunked prefill and prefix caching are not
+                    # supported here.
+                    kv_indptr_cpu = qo_indptr_cpu * self.cp_world_size
+                    # Build the custom mask for the head-tail query order.
+                    mask_arr = []
+                    q_pos = common_attn_metadata.query_positions
+                    for i in range(num_prefills):
+                        # |----<C>-----|-<Q0>-|-<Q1>-|
+                        # |---<C+Q*cp_world_size>----|
+                        # cp_world_size = 2
+                        # Q = 2
+                        # C = 8
+                        # cur_q_pos = [0,3]
+                        # context_mask_i.shape = (2, 8)
+                        # upper = [0,1,2,3]
+                        # local_mask_i = [[True, False, False, False],
+                        #                 [True, True, True, True]] # size=(2, 4)
+                        # mask_i.shape = (2, 12)
+                        cur_q_pos = torch.from_numpy(q_pos[qo_indptr_cpu[i]:qo_indptr_cpu[i+1]])
+                        Q = len(cur_q_pos)
+                        C = prefill_num_computed_tokens_cpu[i]
+                        if Q <= 0:
+                            mask_arr.append(torch.zeros(0, dtype=torch.bool))
+                            continue
+                        context_mask_i = torch.ones((Q, C), dtype=torch.bool)
+                        upper = torch.arange(Q * self.cp_world_size)
+                        local_mask_i = (upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1))
+                        mask_i = torch.cat([context_mask_i, local_mask_i], dim=1)
+                        mask_arr.append(mask_i.flatten())
+                    custom_mask = torch.cat(mask_arr, dim=0).to(self.device)
+
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu.to(self.device),
+                        kv_indptr_cpu.to(self.device),
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        custom_mask=custom_mask,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
+                else:
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu,
+                        paged_kv_indptr_cpu,
+                        paged_kv_indices,
+                        paged_kv_last_page_len_cpu[prefill_start:],
+                        self.num_qo_heads,
+                        self.num_kv_heads,
+                        self.head_dim,
+                        self.page_size,
+                        causal=True,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_data_type=self.kv_cache_dtype,
+                    )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
                     self.device, non_blocking=True)
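The custom-mask construction in the CP branch can be exercised in isolation. The snippet below reproduces the worked example from the NOTE comments (cp_world_size = 2, two local queries at head-tail positions [0, 3], eight already-computed context tokens); the values are illustrative, not taken from a running model:

    import torch

    cp_world_size = 2
    cur_q_pos = torch.tensor([0, 3])     # head-tail positions of this rank's queries
    Q, C = len(cur_q_pos), 8             # local query count, computed context length

    context_mask = torch.ones((Q, C), dtype=torch.bool)          # context fully visible
    upper = torch.arange(Q * cp_world_size)                      # gathered query positions
    local_mask = upper.unsqueeze(0) <= cur_q_pos.unsqueeze(1)    # causal within the new chunk
    mask = torch.cat([context_mask, local_mask], dim=1)          # shape (2, 12)
    print(local_mask.int().tolist())   # [[1, 0, 0, 0], [1, 1, 1, 1]]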
@@ -644,6 +722,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
            # TODO: The cascade wrapper currently does not support setting
            # kv cache dtype to something different from query dtype.
            return False
+        if self.cp_world_size > 1:
+            return False
        return use_cascade_attention(*args, **kwargs)
 
 
@@ -803,16 +883,17 @@ def forward(
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            kv_cache[:, 0],
-            kv_cache[:, 1],
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if self.cp_world_size == 1:
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
        # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
        # to process the cache when the kv_cache_dtype is fp8
@@ -847,18 +928,50 @@ def forward(
            assert prefill_wrapper is not None
 
            if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._causal
                assert prefill_wrapper._window_left == self.window_left
                assert prefill_wrapper._logits_soft_cap == (
                    self.logits_soft_cap or 0.0)
                assert prefill_wrapper._sm_scale == self.scale
-                prefill_wrapper.run(
-                    prefill_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[num_decode_tokens:],
-                )
+                if self.cp_world_size > 1:
+                    key_across_cp = get_cp_group().all_gather(
+                        key[num_decode_tokens:].contiguous(), dim=0)
+                    value_across_cp = get_cp_group().all_gather(
+                        value[num_decode_tokens:].contiguous(), dim=0)
+                    key_across_cp = torch.index_select(
+                        key_across_cp, 0,
+                        attn_metadata.cp_kv_recover_idx
+                    )
+                    value_across_cp = torch.index_select(
+                        value_across_cp, 0,
+                        attn_metadata.cp_kv_recover_idx
+                    )
+                    torch.ops._C_cache_ops.reshape_and_cache_flash(
+                        key_across_cp,
+                        value_across_cp,
+                        kv_cache[:, 0],
+                        kv_cache[:, 1],
+                        attn_metadata.slot_mapping[num_decode_tokens:],
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
+                    # TODO(qcs): handle fetching and concatenating the KV
+                    # cache under chunked prefill / prefix caching.
+                    prefill_wrapper.run(
+                        prefill_query,
+                        key_across_cp,
+                        value_across_cp,
+                        out=output[num_decode_tokens:],
+                    )
+                else:
+                    assert prefill_wrapper._causal
+                    prefill_wrapper.run(
+                        prefill_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[num_decode_tokens:],
+                    )
            else:
                # prefill_query may be non-contiguous
                prefill_query = prefill_query.contiguous()
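For context on `cp_kv_recover_idx`: `all_gather(dim=0)` concatenates each rank's prefill K/V shard back to back, while the tokens themselves were distributed across ranks in head-tail order, so an index permutation is needed to restore the original token order before caching and attention. A toy illustration with a hypothetical sharding (the real index is produced by the metadata/scheduler code, which is not part of this diff):

    import torch

    # 8 prefill tokens split head-tail across 2 CP ranks (hypothetical layout):
    rank_shards = [torch.tensor([0, 1, 6, 7]),   # rank 0: head + tail chunks
                   torch.tensor([2, 3, 4, 5])]   # rank 1: middle chunks
    gathered = torch.cat(rank_shards)            # what all_gather(dim=0) yields
    recover_idx = torch.argsort(gathered)        # plays the role of cp_kv_recover_idx
    print(torch.index_select(gathered, 0, recover_idx).tolist())
    # [0, 1, 2, 3, 4, 5, 6, 7]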
@@ -933,13 +1046,35 @@ def forward(
                assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                           or 0.0)
                assert decode_wrapper._sm_scale == self.scale
-                decode_wrapper.run(
-                    decode_query,
-                    kv_cache_permute,
-                    k_scale=layer._k_scale_float,
-                    v_scale=layer._v_scale_float,
-                    out=output[:num_decode_tokens],
-                )
+                if self.cp_world_size > 1:
+                    torch.ops._C_cache_ops.reshape_and_cache_flash(
+                        key[:num_decode_tokens],
+                        value[:num_decode_tokens],
+                        kv_cache[:, 0],
+                        kv_cache[:, 1],
+                        attn_metadata.slot_mapping[:num_decode_tokens],
+                        self.kv_cache_dtype,
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
+                    kv_cache_permute = kv_cache.permute(*stride_order)
+                    out, lse = decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        return_lse=True,
+                    )
+                    output[:num_decode_tokens] = \
+                        cp_lse_ag_out_ar(out, lse, get_cp_group())
+                else:
+                    decode_wrapper.run(
+                        decode_query,
+                        kv_cache_permute,
+                        k_scale=layer._k_scale_float,
+                        v_scale=layer._v_scale_float,
+                        out=output[:num_decode_tokens],
+                    )
            else:
                # decode_query may be non-contiguous
                decode_query = decode_query.contiguous()
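The decode path runs the wrapper with `return_lse=True` because merging per-rank partial results is only exact when each rank also reports its log-sum-exp. A small self-contained check of that identity, in plain PyTorch and independent of FlashInfer and of this commit:

    import torch

    def attn_with_lse(q, k, v):
        # Single-query attention that also returns the log-sum-exp of the scores.
        scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5   # [B, H, 1, S]
        lse = torch.logsumexp(scores, dim=-1)                   # [B, H, 1]
        return torch.softmax(scores, dim=-1) @ v, lse

    torch.manual_seed(0)
    B, H, S, D = 1, 2, 6, 4
    q = torch.randn(B, H, 1, D)
    k, v = torch.randn(B, H, S, D), torch.randn(B, H, S, D)

    # Pretend the KV cache is split between two CP ranks, then merge with LSEs.
    out0, lse0 = attn_with_lse(q, k[..., :3, :], v[..., :3, :])
    out1, lse1 = attn_with_lse(q, k[..., 3:, :], v[..., 3:, :])
    w = torch.softmax(torch.stack([lse0, lse1]), dim=0).unsqueeze(-1)
    merged = w[0] * out0 + w[1] * out1

    full, _ = attn_with_lse(q, k, v)
    assert torch.allclose(merged, full, atol=1e-5)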
