From 26e28a67886e79b4d656e6ec42e76a362876b536 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 11 Feb 2025 15:28:51 +0000 Subject: [PATCH 01/20] dont pad with new FA3 Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 54278f5f608e..be8e0fd23b7b 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1254,8 +1254,11 @@ def _forward_prefill( # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) + v_dim = v.shape[-1] + pad_v = self.vllm_flash_attn_version < 3 + if pad_v: + v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], + value=0) if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: output = self.triton_fa_func( @@ -1278,7 +1281,7 @@ def _forward_prefill( output = self.flash_attn_varlen_func( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, @@ -1291,7 +1294,7 @@ def _forward_prefill( output = self.flash_attn_varlen_func( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.query_start_loc, max_seqlen_q=prefill_metadata.max_prefill_seq_len, @@ -1317,10 +1320,11 @@ def _forward_prefill( ) # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) + if pad_v: + attn_output = attn_output\ + .view(-1, self.num_heads, q.shape[-1])[..., :v_dim] + attn_output = attn_output.reshape(-1, self.num_heads * v_dim) return self.o_proj(output)[0] @abstractmethod From 7009cf5e1b08f578451c591e54b6e014416f7ca0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 27 Feb 2025 01:41:41 +0000 Subject: [PATCH 02/20] no pad Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 30 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index be8e0fd23b7b..85c91e3e7c03 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1055,6 +1055,26 @@ def __init__( functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + # rest in case we have softmax lse output + attn_output, *rest = self._flash_attn_varlen_func( + q, k, maybe_padded_v, **kwargs) + if self._pad_v: + attn_output = attn_output[:, :, :v.shape[-1]], *rest + return attn_output, *rest + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -1291,7 +1311,7 @@ def _forward_prefill( return_softmax_lse=has_context, ) else: - output = self.flash_attn_varlen_func( + output = self.flash_attn_varlen_diff_headdims( q=q, k=k, v=v, @@ -1319,13 +1339,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - if pad_v: - attn_output = attn_output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v_dim] - - attn_output = attn_output.reshape(-1, self.num_heads * v_dim) - return self.o_proj(output)[0] + return self.o_proj(output[0].flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( From af7598c5bbf3159dd82cbad5e4e207e76f237b90 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 27 Feb 2025 01:48:28 +0000 Subject: [PATCH 03/20] minor fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 85c91e3e7c03..388b17d87f64 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1072,7 +1072,7 @@ def flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): attn_output, *rest = self._flash_attn_varlen_func( q, k, maybe_padded_v, **kwargs) if self._pad_v: - attn_output = attn_output[:, :, :v.shape[-1]], *rest + attn_output = attn_output[..., :v.shape[-1]] return attn_output, *rest def _v_up_proj_and_o_proj(self, x): From 7520a1310fa306bb7d7739d6bdb4583eb4c9abc3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 27 Feb 2025 21:51:36 +0000 Subject: [PATCH 04/20] bug fix Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 388b17d87f64..e8908a4a8d8d 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1063,17 +1063,25 @@ def __init__( self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) - def flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): + def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) - # rest in case we have softmax lse output - attn_output, *rest = self._flash_attn_varlen_func( - q, k, maybe_padded_v, **kwargs) - if self._pad_v: - attn_output = attn_output[..., :v.shape[-1]] - return attn_output, *rest + + attn_out = self._flash_attn_varlen_func(q, k, maybe_padded_v, **kwargs) + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + # only unpack if it is a tuple to avoid unpacking tensors by accident + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + return attn_out, *rest + else: + return attn_out[..., :v.shape[-1]] if self._pad_v else attn_out def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -1311,7 +1319,7 @@ def _forward_prefill( return_softmax_lse=has_context, ) else: - output = self.flash_attn_varlen_diff_headdims( + output = self._flash_attn_varlen_diff_headdims( q=q, k=k, v=v, @@ -1339,7 +1347,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - return self.o_proj(output[0].flatten(start_dim=-2))[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( From 8803b3f477b19086da2777ea01cdb625943118a8 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 27 Feb 2025 22:16:23 +0000 Subject: [PATCH 05/20] update FA Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index afd7c47e8ac0..0ef462e723f8 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22 + GIT_TAG a582527281dba7ca7fc96c7a2781a173a5ee7674 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 841ae573fb919f96301dd38ec4d4258df4b4bd08 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 28 Feb 2025 07:46:37 +0000 Subject: [PATCH 06/20] sync with amd, support v1 Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/mla/common.py | 68 ++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 8c7179ba0a8a..0c95b1b9a3aa 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -207,15 +207,20 @@ try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func + is_vllm_fa = False +from vllm.attention.ops.triton_flash_attention import triton_attention if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner +is_hip = current_platform.is_rocm() + logger = init_logger(__name__) @@ -626,15 +631,78 @@ def __init__( self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0) + + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: + assert return_softmax_lse is False + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + sm_scale=softmax_scale, + **kwargs, + ) + elif is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + # only unpack if it is a tuple to avoid unpacking tensors by accident + if isinstance(attn_out, tuple): + attn_out, *rest = attn_out + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + return attn_out, *rest + else: + return attn_out[..., :v.shape[-1]] if self._pad_v else attn_out + def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) From f538d1cf65b4332d7e05cfcaf241ea3868358c1e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 4 Mar 2025 01:43:00 +0000 Subject: [PATCH 07/20] review comments Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 119 ++++++++++++-------------- 1 file changed, 57 insertions(+), 62 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index e8908a4a8d8d..cf2c77055b7e 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1069,19 +1069,53 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) - attn_out = self._flash_attn_varlen_func(q, k, maybe_padded_v, **kwargs) + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + sm_scale=softmax_scale, + **kwargs, + ) + elif is_vllm_fa: + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) + else: + # Use return_attn_probs instead of return_softmax_lse for RoCM + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_attn_probs=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. - # only unpack if it is a tuple to avoid unpacking tensors by accident + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None if isinstance(attn_out, tuple): attn_out, *rest = attn_out - # unpad if necessary - if self._pad_v: - attn_out = attn_out[..., :v.shape[-1]] - return attn_out, *rest - else: - return attn_out[..., :v.shape[-1]] if self._pad_v else attn_out + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + return attn_out, rest[0] + return attn_out def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -1280,61 +1314,22 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_dim = v.shape[-1] - pad_v = self.vllm_flash_attn_version < 3 - if pad_v: - v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context: - output = self.triton_fa_func( - q, - k, - v_padded, - None, - prefill_metadata.query_start_loc, - prefill_metadata.query_start_loc, - prefill_metadata.max_prefill_seq_len, - prefill_metadata.max_prefill_seq_len, - True, # causal - self.scale, - None, # attn_mask is None unless applying ALiBi mask - ) - ## triton flash attention always return 2 objects - if not has_context: - output = output[0] - elif is_vllm_fa: - output = self.flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - else: - output = self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_attn_probs=has_context, - ) + output = self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.query_start_loc, + max_seqlen_q=prefill_metadata.max_prefill_seq_len, + max_seqlen_k=prefill_metadata.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=has_context, + ) if has_context: # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse, *rest = output + suffix_output, suffix_lse = output context_output, context_lse = self._compute_prefill_context( \ q, kv_c_and_k_pe_cache, attn_metadata) From 615364fbc92b0c7f3b50fd4a04f3c81010132efa Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 03:42:38 +0000 Subject: [PATCH 08/20] fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 1 + vllm/v1/attention/backends/mla/common.py | 35 +++++++++++++----------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index cf2c77055b7e..eb8b651e3bb7 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1114,6 +1114,7 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): # Remain consistent with old `flash_attn_varlen_func` where there # is only one output tensor if `return_softmax_lse` is False. if return_softmax_lse: + assert rest is not None return attn_out, rest[0] return attn_out diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0c95b1b9a3aa..ad638c311a0a 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -663,7 +663,6 @@ def _flash_attn_varlen_diff_headdims(self, v, [0, q.shape[-1] - v.shape[-1]], value=0) if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: - assert return_softmax_lse is False attn_out = self.triton_fa_func( q, k, @@ -691,17 +690,26 @@ def _flash_attn_varlen_diff_headdims(self, **kwargs, ) - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. - # only unpack if it is a tuple to avoid unpacking tensors by accident + # Unpack the output if there is multiple results, + # triton always returns (output, softmax_lse), + # vllm_flash_attn returns (output, softmax_lse) when + # `return_softmax_lse = True` + # flash_attn (RoCM) returns (output, softmax_lse, ...) when + # `return_attn_probs = True` + rest = None if isinstance(attn_out, tuple): attn_out, *rest = attn_out - # unpad if necessary - if self._pad_v: - attn_out = attn_out[..., :v.shape[-1]] - return attn_out, *rest - else: - return attn_out[..., :v.shape[-1]] if self._pad_v else attn_out + + # unpad if necessary + if self._pad_v: + attn_out = attn_out[..., :v.shape[-1]] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. + if return_softmax_lse: + assert rest is not None + return attn_out, rest[0] + return attn_out def _v_up_proj_and_o_proj(self, x): # Convert from (B, N, L) to (N, B, L) @@ -903,12 +911,7 @@ def _forward_prefill( suffix_lse=suffix_lse, ) - # slice by `:v.shape[-1]` in order to remove v headdim padding - output = output\ - .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ - .reshape(-1, self.num_heads * v.shape[-1]) - - return self.o_proj(output)[0] + return self.o_proj(output.flatten(start_dim=-2))[0] @abstractmethod def _forward_decode( From be5c3e4be088427d77eee1f7c51faf44ddc225de Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 15:37:21 +0000 Subject: [PATCH 09/20] update vllm_flash_attn Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 0ef462e723f8..9c01aeefb0aa 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG a582527281dba7ca7fc96c7a2781a173a5ee7674 + GIT_TAG eec5715c3ebc85167f1205788d3005bce4c4a931 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 417a4eac25bc1e35d61a0676ef5c051d7f72fe05 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 20 Mar 2025 22:19:22 +0000 Subject: [PATCH 10/20] update FA Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 9c01aeefb0aa..dd837c770fe3 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG eec5715c3ebc85167f1205788d3005bce4c4a931 + GIT_TAG 6d21ae21bbd9587e1b6f99f7e50eb83f14e1ec4f GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From f3d862a0d4b82068782948ff8587758ef9fb587b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 22 Mar 2025 04:59:23 +0000 Subject: [PATCH 11/20] update mla Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 4 +- vllm/attention/backends/utils.py | 26 ++++++++- vllm/v1/attention/backends/mla/common.py | 57 +++++++++++++++++-- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index dd837c770fe3..d7f1eef054d0 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR) else() FetchContent_Declare( vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 6d21ae21bbd9587e1b6f99f7e50eb83f14e1ec4f + GIT_REPOSITORY https://github.com/neuralmagic/vllm-flash-attention.git + GIT_TAG 38d6823cef5ac15bd710e8ebbbc1c8bc696219b5 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index b4413c36b64a..89f1ea9b8a57 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -2,8 +2,10 @@ """Attention backend utils""" from collections import defaultdict from contextlib import contextmanager +from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, + TypeVar, Union) import numpy as np import torch @@ -11,6 +13,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, AttentionState) from vllm.attention.backends.abstract import AttentionType +from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalPlaceholderMap from vllm.utils import async_tensor_h2d, make_tensor_with_pad @@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens( return (num_prefill_query_tokens, num_prefill_kv_tokens, num_decode_query_tokens) + + +@dataclass +class MLADims: + q_lora_rank: Optional[int] + kv_lora_rank: int + qk_nope_head_dim: int + qk_rope_head_dim: int + v_head_dim: int + + +def get_mla_dims(model_config: ModelConfig) -> MLADims: + hf_text_config = model_config.hf_text_config + + return MLADims( + q_lora_rank=getattr(hf_text_config, "q_lora_rank", None), + kv_lora_rank=hf_text_config.kv_lora_rank, + qk_nope_head_dim=hf_text_config.qk_nope_head_dim, + qk_rope_head_dim=hf_text_config.qk_rope_head_dim, + v_head_dim=hf_text_config.v_head_dim, + ) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ad638c311a0a..14d9602b15bd 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,6 +195,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) + +from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -206,7 +208,8 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -355,6 +358,9 @@ def __init__(self, model_config = runner.model_config cache_config = runner.cache_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.mla_dims = get_mla_dims(model_config) if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -471,7 +477,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(device, non_blocking=True) - max_query_len = seq_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() + + aot_schedule = (get_flash_attn_version() == 3) prefill_metadata = None if self._num_prefills > 0: @@ -482,6 +490,25 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, num_computed_tokens_cpu_tensor[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + prefill_query_start_loc = query_start_loc[ + reqs_start:] - query_start_loc[reqs_start] + + scheduler_metadata = None + if aot_schedule: + scheduler_metadata = get_scheduler_metadata( + batch_size=self._num_prefills, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seq_lens, + num_heads_q=self.num_heads, + num_heads_kv=self.num_heads, + headdim=self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.v_head_dim, + page_size=self.page_size, + cu_seqlens_q=prefill_query_start_loc, + causal=True, + ) chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ @@ -519,6 +546,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + max_chunk_seq_lens = chunk_seq_lens.max(dim=1) cu_seq_lens_cpu = torch.zeros(num_chunks, self._num_prefills + 1, @@ -529,6 +557,26 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + scheduler_metadatas = None + if aot_schedule: + scheduler_metadatas = [] + for i in range(num_chunks): + scheduler_metadatas.append( + get_scheduler_metadata( + batch_size=self._num_prefills, + max_seqlen_q=max_query_len, + max_seqlen_k=max_chunk_seq_lens[i], + cache_seqlens=chunk_seq_lens[i], + num_heads_q=self.num_heads, + num_heads_kv=self.num_heads, + headdim=self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim, + headdim_v=self.mla_dims.v_head_dim, + page_size=self.page_size, + cu_seqlens_q=prefill_query_start_loc, + causal=False, + )) + chunked_context_metadata = \ MLACommonPrefillMetadata.ChunkedContextMetadata( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), @@ -536,6 +584,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), workspace=self.chunked_prefill_workspace, + scheduler_metadatas=scheduler_metadatas, ) assert max(chunked_context_metadata.max_seq_lens) <= \ @@ -544,10 +593,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], - query_start_loc=query_start_loc[reqs_start:] - - query_start_loc[reqs_start], + query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, + scheduler_metadata=scheduler_metadata, ) decode_metadata = None From b49136750d7dba988dd0b90f32d2177b3afeebb2 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Mar 2025 04:34:24 +0000 Subject: [PATCH 12/20] add scheduler code Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flash_attn.py | 58 +++++++++++++++++++++++- vllm/v1/attention/backends/mla/common.py | 9 ++-- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b4c7708daab9..5e06f9a11384 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -23,7 +23,8 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_cuda(): - from vllm.vllm_flash_attn import flash_attn_varlen_func + from vllm.vllm_flash_attn import (flash_attn_varlen_func, + get_scheduler_metadata) logger = init_logger(__name__) @@ -85,6 +86,7 @@ class FlashAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor + scheduler_metadata: Optional[torch.Tensor] # For cascade attention. use_cascade: bool @@ -92,6 +94,7 @@ class FlashAttentionMetadata: cu_prefix_query_lens: Optional[torch.Tensor] prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] + prefix_scheduler_metadata: Optional[torch.Tensor] # For logging. num_input_tokens: int = 0 # Number of tokens including padding. @@ -277,7 +280,14 @@ def make_local_attention_virtual_batches( class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): + model_config = runner.model_config + self.runner = runner + self.aot_schedule = (get_flash_attn_version() == 3) + self.num_heads = model_config.get_num_attention_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -319,6 +329,25 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) use_cascade = common_prefix_len > 0 + + def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len, + causal): + if self.aot_schedule: + scheduler_metadata = get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + cache_seqlens=seqlens, + num_heads_q=self.num_heads, + num_heads_kv=self.num_heads, + headdim=self.headdim, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return scheduler_metadata + return None + if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, @@ -330,10 +359,28 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len) suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( self.runner.device) + prefix_scheduler_metadata = schedule( + cu_query_lens=cu_prefix_query_lens, + max_query_len=num_actual_tokens, + seqlens=prefix_kv_lens, + max_seq_len=common_prefix_len, + causal=False) + scheduler_metadata = schedule(cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - + common_prefix_len, + causal=True) else: cu_prefix_query_lens = None prefix_kv_lens = None suffix_kv_lens = None + prefix_scheduler_metadata = None + scheduler_metadata = schedule(cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=True) attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -345,10 +392,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, + scheduler_metadata=scheduler_metadata, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, ) return attn_metadata @@ -515,6 +564,7 @@ def forward( window_size=self.sliding_window, block_table=block_table, softcap=self.logits_soft_cap, + scheduler_metadata=attn_metadata.scheduler_metadata, fa_version=self.vllm_flash_attn_version, q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), @@ -543,6 +593,8 @@ def forward( block_table=attn_metadata.block_table, common_prefix_len=attn_metadata.common_prefix_len, fa_version=self.vllm_flash_attn_version, + prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, + suffix_scheduler_metadata=attn_metadata.scheduler_metadata, q_descale=layer._q_scale, k_descale=layer._k_scale, v_descale=layer._v_scale, @@ -636,6 +688,8 @@ def cascade_attention( block_table: torch.Tensor, common_prefix_len: int, fa_version: int, + prefix_scheduler_metadata: Optional[torch.Tensor] = None, + suffix_scheduler_metadata: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, @@ -667,6 +721,7 @@ def cascade_attention( block_table=block_table[:1], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, @@ -693,6 +748,7 @@ def cascade_attention( block_table=block_table[:, num_common_kv_blocks:], softcap=logits_soft_cap, return_softmax_lse=True, + scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 14d9602b15bd..d1d915249984 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -361,6 +361,8 @@ def __init__(self, self.num_heads = model_config.get_num_attention_heads( runner.parallel_config) self.mla_dims = get_mla_dims(model_config) + self.aot_schedule = (get_flash_attn_version() == 3) + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -386,7 +388,6 @@ def __init__(self, dtype=model_config.dtype, device=runner.device, ) - self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -479,8 +480,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens = seq_lens_cpu.to(device, non_blocking=True) max_seq_len = seq_lens_cpu.max().item() - aot_schedule = (get_flash_attn_version() == 3) - prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start @@ -494,7 +493,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, reqs_start:] - query_start_loc[reqs_start] scheduler_metadata = None - if aot_schedule: + if self.aot_schedule: scheduler_metadata = get_scheduler_metadata( batch_size=self._num_prefills, max_seqlen_q=max_query_len, @@ -558,7 +557,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, dtype=torch.int32) scheduler_metadatas = None - if aot_schedule: + if self.aot_schedule: scheduler_metadatas = [] for i in range(num_chunks): scheduler_metadatas.append( From d6888b43f25b62d13273c747329cbcc58859ebfb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Mar 2025 21:48:14 +0000 Subject: [PATCH 13/20] cleanups Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flash_attn.py | 3 +- vllm/v1/attention/backends/mla/common.py | 93 ++++-------------------- 2 files changed, 14 insertions(+), 82 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5e06f9a11384..9d8b3babc233 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -333,7 +333,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.aot_schedule: - scheduler_metadata = get_scheduler_metadata( + return get_scheduler_metadata( batch_size=num_reqs, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, @@ -345,7 +345,6 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len, cu_seqlens_q=cu_query_lens, causal=causal, ) - return scheduler_metadata return None if use_cascade: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index d1d915249984..acafc5c4d245 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -208,8 +208,7 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) + from vllm.vllm_flash_attn import flash_attn_varlen_func is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -478,7 +477,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] seq_lens = seq_lens_cpu.to(device, non_blocking=True) - max_seq_len = seq_lens_cpu.max().item() prefill_metadata = None if self._num_prefills > 0: @@ -492,23 +490,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] - scheduler_metadata = None - if self.aot_schedule: - scheduler_metadata = get_scheduler_metadata( - batch_size=self._num_prefills, - max_seqlen_q=max_query_len, - max_seqlen_k=max_seq_len, - cache_seqlens=seq_lens, - num_heads_q=self.num_heads, - num_heads_kv=self.num_heads, - headdim=self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim, - headdim_v=self.mla_dims.v_head_dim, - page_size=self.page_size, - cu_seqlens_q=prefill_query_start_loc, - causal=True, - ) - chunked_context_metadata = None if self.chunked_prefill_enabled and self._num_prefills > 0 \ and max_context_len_cpu > 0: @@ -545,7 +526,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - max_chunk_seq_lens = chunk_seq_lens.max(dim=1) cu_seq_lens_cpu = torch.zeros(num_chunks, self._num_prefills + 1, @@ -556,26 +536,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) - scheduler_metadatas = None - if self.aot_schedule: - scheduler_metadatas = [] - for i in range(num_chunks): - scheduler_metadatas.append( - get_scheduler_metadata( - batch_size=self._num_prefills, - max_seqlen_q=max_query_len, - max_seqlen_k=max_chunk_seq_lens[i], - cache_seqlens=chunk_seq_lens[i], - num_heads_q=self.num_heads, - num_heads_kv=self.num_heads, - headdim=self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim, - headdim_v=self.mla_dims.v_head_dim, - page_size=self.page_size, - cu_seqlens_q=prefill_query_start_loc, - causal=False, - )) - chunked_context_metadata = \ MLACommonPrefillMetadata.ChunkedContextMetadata( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), @@ -583,7 +543,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), workspace=self.chunked_prefill_workspace, - scheduler_metadatas=scheduler_metadatas, ) assert max(chunked_context_metadata.max_seq_lens) <= \ @@ -595,7 +554,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, - scheduler_metadata=scheduler_metadata, ) decode_metadata = None @@ -710,43 +668,19 @@ def _flash_attn_varlen_diff_headdims(self, maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: - attn_out = self.triton_fa_func( - q, - k, - maybe_padded_v, - sm_scale=softmax_scale, - **kwargs, - ) - elif is_vllm_fa: - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - else: - # Use return_attn_probs instead of return_softmax_lse for RoCM - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_attn_probs=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + return_softmax_lse=return_softmax_lse, + softmax_scale=softmax_scale, + **kwargs, + ) - # Unpack the output if there is multiple results, - # triton always returns (output, softmax_lse), - # vllm_flash_attn returns (output, softmax_lse) when - # `return_softmax_lse = True` - # flash_attn (RoCM) returns (output, softmax_lse, ...) when - # `return_attn_probs = True` - rest = None + # Unpack the output if there is multiple results + lse = None if isinstance(attn_out, tuple): - attn_out, *rest = attn_out + attn_out, lse = attn_out[0], attn_out[1] # unpad if necessary if self._pad_v: @@ -755,8 +689,7 @@ def _flash_attn_varlen_diff_headdims(self, # Remain consistent with old `flash_attn_varlen_func` where there # is only one output tensor if `return_softmax_lse` is False. if return_softmax_lse: - assert rest is not None - return attn_out, rest[0] + return attn_out, lse return attn_out def _v_up_proj_and_o_proj(self, x): From a8292a3effcf9c73fef16c61c9189ea7da856633 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Mar 2025 23:08:55 +0000 Subject: [PATCH 14/20] update fa Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index d7f1eef054d0..44858ea6bb0d 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -37,8 +37,8 @@ if(VLLM_FLASH_ATTN_SRC_DIR) else() FetchContent_Declare( vllm-flash-attn - GIT_REPOSITORY https://github.com/neuralmagic/vllm-flash-attention.git - GIT_TAG 38d6823cef5ac15bd710e8ebbbc1c8bc696219b5 + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 8cbd1b42140ced0510ce3e6467674b7de39391bf GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 11f61eea1707df64dec1720fe1098d70d972d7f3 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Apr 2025 04:22:41 +0000 Subject: [PATCH 15/20] precommit + missing args Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 5 +++-- vllm/v1/attention/backends/flash_attn.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index eb8b651e3bb7..af3c5b100072 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -78,7 +78,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - v + v ) return spda_o @ W_O @@ -1063,7 +1063,8 @@ def __init__( self.vllm_flash_attn_version == 3 and current_platform.get_device_capability()[0] == 9) - def _flash_attn_varlen_diff_headdims(self, q, k, v, **kwargs): + def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, + return_softmax_lse, **kwargs): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9d8b3babc233..c039cd8067f3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -86,7 +86,6 @@ class FlashAttentionMetadata: seq_lens: torch.Tensor block_table: torch.Tensor slot_mapping: torch.Tensor - scheduler_metadata: Optional[torch.Tensor] # For cascade attention. use_cascade: bool @@ -94,7 +93,10 @@ class FlashAttentionMetadata: cu_prefix_query_lens: Optional[torch.Tensor] prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] - prefix_scheduler_metadata: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None # For logging. num_input_tokens: int = 0 # Number of tokens including padding. From 2d5dd631c0e1afdea79d5891be7083f3267d4ced Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Apr 2025 04:54:38 +0000 Subject: [PATCH 16/20] mla fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 47 +++++++----------------- vllm/v1/attention/backends/mla/common.py | 20 +++------- 2 files changed, 18 insertions(+), 49 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index af3c5b100072..af32d134e87f 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1240,40 +1240,19 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - if is_vllm_fa: - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - else: - attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func( - q=q, - k=k, - v=v_padded, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata. - context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_attn_probs=True, - ) + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) if output is None: output = attn_output diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index acafc5c4d245..4f9287384f04 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -804,16 +804,11 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad - # out v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, - [0, q.shape[-1] - v.shape[-1]], - value=0) - - attn_output, attn_softmax_lse = self.flash_attn_varlen_func( + attn_output, attn_softmax_lse = \ + self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=prefill_metadata.query_start_loc, cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], max_seqlen_q=prefill_metadata.max_query_len, @@ -860,15 +855,10 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = self.flash_attn_varlen_func( + output = self._flash_attn_varlen_diff_headdims( q=q, k=k, - v=v_padded, + v=v, cu_seqlens_q=attn_metadata.prefill.query_start_loc, cu_seqlens_k=attn_metadata.prefill.query_start_loc, max_seqlen_q=attn_metadata.prefill.max_query_len, From 286c79f1d87f64300ec91a614023076aad860e07 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Apr 2025 04:59:55 +0000 Subject: [PATCH 17/20] cleanup Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index af32d134e87f..0ea299296b76 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -78,7 +78,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - v + v ) return spda_o @ W_O From 4308a5e8726f2df282f5d00d47de01ec44538b25 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Apr 2025 19:02:46 +0000 Subject: [PATCH 18/20] update git hash Signed-off-by: Lucas Wilkinson --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 44858ea6bb0d..110ef266c665 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 8cbd1b42140ced0510ce3e6467674b7de39391bf + GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 31be47bb37860aab7c569321790e35939a826e81 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 10 Apr 2025 19:47:34 +0000 Subject: [PATCH 19/20] amd fixes Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 16 +--------------- vllm/v1/attention/backends/mla/common.py | 9 +++++---- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0ea299296b76..81b7f1ae2387 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -196,7 +196,6 @@ import torch from vllm import _custom_ops as ops -from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, @@ -216,10 +215,6 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version -if HAS_TRITON: - from vllm.attention.ops.triton_flash_attention import triton_attention -else: - triton_attention = None try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -1043,7 +1038,6 @@ def __init__( self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj - self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -1070,15 +1064,7 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN: - attn_out = self.triton_fa_func( - q, - k, - maybe_padded_v, - sm_scale=softmax_scale, - **kwargs, - ) - elif is_vllm_fa: + if is_vllm_fa: attn_out = self.flash_attn_varlen_func( q=q, k=k, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 4f9287384f04..2da5a34e4538 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -214,7 +214,6 @@ # For rocm use upstream flash attention from flash_attn import flash_attn_varlen_func is_vllm_fa = False -from vllm.attention.ops.triton_flash_attention import triton_attention if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -360,8 +359,11 @@ def __init__(self, self.num_heads = model_config.get_num_attention_heads( runner.parallel_config) self.mla_dims = get_mla_dims(model_config) - self.aot_schedule = (get_flash_attn_version() == 3) - self.page_size = self.runner.block_size + self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + + # Dont try to access the runner on AMD + if self.aot_schedule: + self.page_size = self.runner.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -637,7 +639,6 @@ def __init__( self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() - self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 From fcc54d803a5d95cc8d7e9e3a9b22be651485028a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 11 Apr 2025 14:37:18 +0000 Subject: [PATCH 20/20] more amd tweaks Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/mla/common.py | 14 ++++++++++++++ vllm/v1/attention/backends/mla/common.py | 3 --- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 81b7f1ae2387..2ec771a64557 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -196,6 +196,7 @@ import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, AttentionMetadataBuilder, @@ -215,6 +216,10 @@ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version +if HAS_TRITON: + from vllm.attention.ops.triton_flash_attention import triton_attention +else: + triton_attention = None try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -1039,6 +1044,7 @@ def __init__( self.kv_b_proj = kv_b_proj self.o_proj = o_proj + self.triton_fa_func = triton_attention # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 @@ -1064,6 +1070,14 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, maybe_padded_v = torch.nn.functional.pad( v, [0, q.shape[-1] - v.shape[-1]], value=0) + if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ + and not return_softmax_lse: + attn_out = self.triton_fa_func( + q, + k, + maybe_padded_v, + **kwargs, + ) if is_vllm_fa: attn_out = self.flash_attn_varlen_func( q=q, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2da5a34e4538..6e1512896971 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -195,7 +195,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) - from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.logger import init_logger @@ -220,8 +219,6 @@ from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner -is_hip = current_platform.is_rocm() - logger = init_logger(__name__)