@@ -1043,8 +1043,8 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.triton_fa_func = triton_attention

+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on ROCm and the
         # latter has an additional parameter to control FA2 vs FA3.
@@ -1055,6 +1055,70 @@ def __init__(
                 functools.partial(flash_attn_varlen_func,
                                   fa_version=self.vllm_flash_attn_version)

+        # For MLA the v head dim is smaller than the qk head dim, so we pad
+        # out v with 0s to match the qk head dim for attention backends that
+        # do not support different head dims.
+        # We don't need to pad v if we are on a Hopper system with FA3.
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                         return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        elif is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # ROCm's flash_attn takes return_attn_probs instead of
+            # return_softmax_lse
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there are multiple results:
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse=True`, and
+        # flash_attn (ROCm) returns (output, softmax_lse, ...) when
+        # `return_attn_probs=True`.
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # Unpad the output if necessary.
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with the old `flash_attn_varlen_func`, where
+        # there is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
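The zero-padding trick in `_flash_attn_varlen_diff_headdims` is lossless: attention output is a weighted sum of value rows, so padded value columns stay exactly zero and can be sliced off afterwards. A minimal sketch of the idea (the 192/128 head dims are illustrative assumptions, not taken from this diff):

```python
import torch

# Illustrative MLA-style head dims (assumed): qk larger than v.
qk_dim, v_dim, n_tokens = 192, 128, 4
v = torch.randn(n_tokens, v_dim)
probs = torch.softmax(torch.randn(n_tokens, n_tokens), dim=-1)  # attn weights

# Pad v with zeros up to the qk head dim so kernels that require equal
# q/k/v head dims will accept it.
v_padded = torch.nn.functional.pad(v, [0, qk_dim - v_dim], value=0)

# The padded columns contribute only zeros to the weighted sum, so
# slicing them off recovers the exact unpadded result.
out = (probs @ v_padded)[..., :v_dim]
assert torch.allclose(out, probs @ v)
```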
@@ -1176,40 +1240,19 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)

-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )

             if output is None:
                 output = attn_output
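Downstream of this loop, the per-chunk outputs are combined using the returned softmax log-sum-exps, which is why `return_softmax_lse=True` is mandatory here. A sketch of that kind of merge (a generic restatement of the standard LSE recombination, not vLLM's actual merge routine):

```python
import torch

def merge_attn_chunks(out_a, lse_a, out_b, lse_b):
    """Combine partial attention results over two disjoint KV chunks.
    out_*: (tokens, heads, dim); lse_*: (tokens, heads)."""
    lse = torch.logaddexp(lse_a, lse_b)  # combined log normalizer
    # Each chunk's output was normalized by its own softmax sum, so
    # reweighting by exp(lse_chunk - lse_total) renormalizes globally.
    return (torch.exp(lse_a - lse).unsqueeze(-1) * out_a +
            torch.exp(lse_b - lse).unsqueeze(-1) * out_b)
```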
@@ -1252,58 +1295,22 @@ def _forward_prefill(

         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            # triton flash attention always returns 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_softmax_lse=has_context,
-            )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
+        output = self._flash_attn_varlen_diff_headdims(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=prefill_metadata.query_start_loc,
+            cu_seqlens_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax_lse=has_context,
+        )

         if has_context:
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)

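The two-element unpack above is safe because the helper collapses every backend's return shape to one contract: a bare output tensor, or `(output, softmax_lse)` when the LSE is requested. A toy restatement of that normalization, standalone and using placeholder values rather than real tensors:

```python
def normalize(attn_out, return_softmax_lse):
    # Backends may return a bare output, (output, lse), or
    # (output, lse, probs); keep only output and, if asked, the lse.
    rest = None
    if isinstance(attn_out, tuple):
        attn_out, *rest = attn_out
    if return_softmax_lse:
        assert rest is not None
        return attn_out, rest[0]
    return attn_out

assert normalize(("out", "lse", "probs"), True) == ("out", "lse")
assert normalize("out", False) == "out"
```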
@@ -1316,12 +1323,7 @@ def _forward_prefill(
                 suffix_lse=suffix_lse,
             )

-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output \
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]] \
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]

     @abstractmethod
     def _forward_decode(
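The final simplification works because `_flash_attn_varlen_diff_headdims` already slices off the v padding, so the old slice-and-reshape dance reduces to flattening the last two dims. A quick equivalence check (shapes illustrative):

```python
import torch

num_heads, v_dim, n_tokens = 16, 128, 10
output = torch.randn(n_tokens, num_heads, v_dim)  # already unpadded

# Old path: view, slice (now a no-op), reshape to (tokens, heads * dim).
old = output.view(-1, num_heads, v_dim)[..., :v_dim] \
            .reshape(-1, num_heads * v_dim)
# New path: flatten the trailing (heads, dim) pair in one call.
new = output.flatten(start_dim=-2)
assert torch.equal(old, new)
```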