@@ -17,11 +17,16 @@
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
 
 logger = init_logger(__name__)
 
+# NOTE(matt): This is an arbitrary number, copied from
+# woosuk's implementation in the standard FlashAttention backend.
+_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
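+# A bound is needed at all because the split-KV (flash-decoding) path writes
+# partial results into buffers of shape [num_splits, num_heads, num_tokens,
+# head_size] (see _build_decode), which must be sized ahead of graph capture.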
+
 
 class FlashAttnMLABackend(MLACommonBackend):
 
@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
     max_query_len: int
     max_seq_len: int
     scheduler_metadata: Optional[torch.Tensor] = None
+    max_num_splits: int = 0
 
 
 @dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
 
 class FlashAttnMLAMetadataBuilder(
         MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
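+    # UNIFORM_BATCH: full-graph capture is used only for batches whose
+    # requests all share the same query length (e.g. uniform decode, or
+    # spec-decode with a fixed number of speculative tokens).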
+
     reorder_batch_threshold: ClassVar[int] = 512
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                          FlashAttnMLAMetadata)
+        self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = (get_flash_attn_version() == 3)
 
+        self.use_full_cuda_graph = \
+            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+
+        if self.use_full_cuda_graph and self.fa_aot_schedule:
+            self.max_cudagraph_size = self.compilation_config.max_capture_size
+
+            if self.max_cudagraph_size > 992:
+                # This condition derives from FA3's internal heuristic.
+                # TODO(woosuk): Support larger cudagraph sizes.
+                raise ValueError(
+                    "Capture size larger than 992 is not supported for "
+                    "full cuda graph.")
+
+            self.scheduler_metadata = torch.zeros(
+                vllm_config.scheduler_config.max_num_seqs + 1,
+                dtype=torch.int32,
+                device=self.device,
+            )
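+            # This buffer is persistent across steps: the captured graph
+            # reads from its fixed address, and _build_decode copies fresh
+            # scheduler metadata into it before each replay.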
+            # When using cuda graph, we need to set the upper bound of the
+            # number of splits so that large enough intermediate buffers are
+            # pre-allocated during capture.
+            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -81,14 +114,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
+                num_splits=self.max_num_splits,
             )
         return None
 
-    def _build_decode(
-            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
-            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
-            query_start_loc_device: torch.Tensor
-    ) -> FlashAttnMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens_cpu: torch.Tensor,
+                      seq_lens_device: torch.Tensor,
+                      query_start_loc_cpu: torch.Tensor,
+                      query_start_loc_device: torch.Tensor,
+                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
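+        # num_decode_tokens is threaded in so the builder can tell whether
+        # this batch falls inside the cudagraph capture range (checked below).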
         query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_cpu.max().item()
@@ -102,13 +137,37 @@ def _build_decode(
             causal=True,
         )
 
+        # For FA3 + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and scheduler_metadata is not None:
+            n = scheduler_metadata.shape[0]
+            # Ensure the persistent buffer is large enough.
+            assert n <= self.scheduler_metadata.shape[0], \
+                f"Scheduler metadata size {n} exceeds buffer size " + \
+                f"{self.scheduler_metadata.shape[0]}"
+            self.scheduler_metadata[:n] = scheduler_metadata
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
+            if num_decode_tokens <= self.max_cudagraph_size:
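+                # Larger batches run eagerly (uncaptured), so they can leave
+                # max_num_splits at 0 and let the kernel pick its own splits.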
+                # NOTE(woosuk): Setting num_splits > 1 may increase memory
+                # usage, because intermediate buffers of size [num_splits,
+                # num_heads, num_tokens, head_size] are allocated. Therefore,
+                # we only set num_splits when using cuda graphs.
+                max_num_splits = self.max_num_splits
+
         return FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
             query_start_loc=query_start_loc_device,
             max_query_len=max_query_len,
             max_seq_len=max_seq_len,
             scheduler_metadata=scheduler_metadata,
+            max_num_splits=max_num_splits,
         )
 
 
@@ -175,12 +234,17 @@ def _forward_decode(
         kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
         k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
 
+        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but
+        # the kernel uses this to calculate grid dimensions. Ensure it is at
+        # least 1 to prevent an invalid grid configuration during capture.
+        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
+
         o = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
             q_v=q_nope,
-            max_seqlen_q=attn_metadata.decode.max_query_len,
+            max_seqlen_q=max_seqlen_q,
             cu_seqlens_q=attn_metadata.decode.query_start_loc,
             max_seqlen_k=attn_metadata.decode.max_seq_len,
             seqused_k=attn_metadata.decode.seq_lens,
@@ -189,6 +253,7 @@ def _forward_decode(
             causal=True,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
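+            # Nonzero only when this batch runs under a full cudagraph;
+            # 0 lets the kernel choose its own split count.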
+            num_splits=attn_metadata.decode.max_num_splits,
         )
 
         return self._v_up_proj(o)