From a8cdabaa52f9894e2839ac77eda8fefacadbfef0 Mon Sep 17 00:00:00 2001
From: qqma
Date: Fri, 17 Oct 2025 14:58:53 -0700
Subject: [PATCH 1/3] bugfix for Flash Attention MLA with full cuda graph IMA
 following pr-25490

Signed-off-by: qqma
---
 .../attention/backends/mla/flashattn_mla.py  | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 71f5473bc9de..88454e75aa03 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -89,10 +89,9 @@ def __init__(
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
 
         if self.use_full_cuda_graph and self.fa_aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
-
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
@@ -114,7 +113,7 @@ def __init__(
             self.max_num_splits = 1
 
     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
+        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal, max_num_splits
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -130,7 +129,7 @@ def _schedule_decode(
                 page_size=self.page_size,
                 cu_seqlens_q=cu_query_lens,
                 causal=causal,
-                num_splits=self.max_num_splits,
+                num_splits=max_num_splits,
             )
         return None
 
@@ -148,6 +147,17 @@ def _build_decode(
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()
 
+
+        # For Flash Attention MLA + full cudagraph
+        max_num_splits = 0
+        if self.use_full_cuda_graph and \
+            num_decode_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         scheduler_metadata = self._schedule_decode(
             num_reqs=seq_lens_cpu.numel(),
             cu_query_lens=query_start_loc_device,
@@ -155,10 +165,9 @@ def _build_decode(
             seqlens=seq_lens_device,
             max_seq_len=max_seq_len,
             causal=True,
+            max_num_splits=max_num_splits,
         )
 
-        # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             # Ensure the persistent buffer is large enough
@@ -174,13 +183,6 @@ def _build_decode(
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-            if num_decode_tokens <= self.max_cudagraph_size:
-                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-                # usage, because the intermediate buffers of size [num_splits,
-                # num_heads, num_tokens, head_size] are allocated. Therefore,
-                # we only set num_splits when using cuda graphs.
-                max_num_splits = self.max_num_splits
-
         if vllm_is_batch_invariant():
             max_num_splits = 1
 

From 376c203932025dca2baa8a86135a4e0ce03f1160 Mon Sep 17 00:00:00 2001
From: qqma
Date: Fri, 17 Oct 2025 15:36:56 -0700
Subject: [PATCH 2/3] fix linting check

Signed-off-by: qqma
---
 vllm/v1/attention/backends/mla/flashattn_mla.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 88454e75aa03..d094e67d3804 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -113,7 +113,14 @@ def __init__(
             self.max_num_splits = 1
 
     def _schedule_decode(
-        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal, max_num_splits
+        self,
+        num_reqs,
+        cu_query_lens,
+        max_query_len,
+        seqlens,
+        max_seq_len,
+        causal,
+        max_num_splits,
     ):
         if self.fa_aot_schedule:
             return get_scheduler_metadata(
@@ -150,8 +157,7 @@ def _build_decode(
 
         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0
-        if self.use_full_cuda_graph and \
-            num_decode_tokens <= self.max_cudagraph_size:
+        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
             # NOTE(woosuk): Setting num_splits > 1 may increase the memory
             # usage, because the intermediate buffers of size [num_splits,
             # num_heads, num_tokens, head_size] are allocated. Therefore,

From 8c2514577e3a38315a26fe180f8374be964dc8d3 Mon Sep 17 00:00:00 2001
From: qqma
Date: Fri, 17 Oct 2025 15:40:59 -0700
Subject: [PATCH 3/3] fix linting check

Signed-off-by: qqma
---
 vllm/v1/attention/backends/mla/flashattn_mla.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index d094e67d3804..18e5908d6ef1 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -154,7 +154,6 @@ def _build_decode(
         max_query_len = query_lens_cpu.max().item()
         max_seq_len = seq_lens_device.max().item()
 
-
         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0
         if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
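
Taken together, the series moves the choice of max_num_splits ahead of the FA3 AOT schedule: the value is computed from the full-cuda-graph settings first and passed into _schedule_decode as a parameter, so the schedule is generated with the same num_splits that the rest of the decode path uses (previously get_scheduler_metadata always saw self.max_num_splits, while the local value was only derived afterwards, inside the "scheduler_metadata is not None" branch). A rough, self-contained sketch of that selection logic follows; DecodeSplitConfig and decide_max_num_splits are simplified, hypothetical stand-ins, not the actual vLLM classes or signatures.

    from dataclasses import dataclass


    @dataclass
    class DecodeSplitConfig:
        # Hypothetical stand-ins for the builder attributes touched by the patch.
        use_full_cuda_graph: bool  # cudagraph_mode.has_full_cudagraphs()
        max_cudagraph_size: int    # compilation_config.max_capture_size
        max_num_splits: int        # configured upper bound for FA3 splits


    def decide_max_num_splits(cfg: DecodeSplitConfig, num_decode_tokens: int) -> int:
        """Pick num_splits before the FA3 scheduler metadata is built.

        num_splits > 1 allocates intermediate buffers shaped
        [num_splits, num_heads, num_tokens, head_size], so splitting is only
        enabled when the decode batch fits inside a captured CUDA graph.
        """
        if cfg.use_full_cuda_graph and num_decode_tokens <= cfg.max_cudagraph_size:
            return cfg.max_num_splits
        return 0


    if __name__ == "__main__":
        cfg = DecodeSplitConfig(use_full_cuda_graph=True,
                                max_cudagraph_size=512,
                                max_num_splits=8)
        print(decide_max_num_splits(cfg, num_decode_tokens=256))   # 8: fits in the captured graph
        print(decide_max_num_splits(cfg, num_decode_tokens=2048))  # 0: falls back to no splitting

In the diff above, the equivalent value is handed straight to _schedule_decode(..., max_num_splits=max_num_splits) and from there to get_scheduler_metadata(num_splits=...), rather than being read from self.max_num_splits unconditionally.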