From 95d6186dddc3f097351b458bd3a85f3f099b049e Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 1 Jul 2025 00:03:49 +0000
Subject: [PATCH] [Misc] Add num_splits input arg to flash_attn_varlen_func

Signed-off-by: Woosuk Kwon
---
 vllm_flash_attn/flash_attn_interface.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index cfeda8520e..ba21c49d4d 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -142,6 +142,7 @@ def flash_attn_varlen_func(
     q_descale=None,
     k_descale=None,
     v_descale=None,
+    num_splits: int = 0,
     # Version selector
     fa_version: int = DEFAULT_FA_VERSION,
 ):
@@ -224,6 +225,8 @@ def flash_attn_varlen_func(
                 "FA2 does not support scheduler_metadata, q_descale, "
                 "k_descale, v_descale"
             )
+        if num_splits > 1:
+            raise NotImplementedError("FA2 does not support num_splits > 1")
         out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
             q, k, v,
             out,
@@ -270,7 +273,7 @@ def flash_attn_varlen_func(
             softcap,
             True, # rotary_interleaved
             scheduler_metadata,
-            0, # num_splits
+            num_splits,
             None, # pack_gqa
             0, # sm_margin
         )