     from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False

-    from vllm.attention.ops.triton_flash_attention import (triton_attention)
+    from vllm.attention.ops.triton_flash_attention import triton_attention

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -1330,9 +1330,9 @@ def _compute_prefill_context(
                     prefill_metadata.context_chunk_cu_seq_lens[i],
                     prefill_metadata.max_query_len,
                     prefill_metadata.context_chunk_max_seq_lens[i],
-                    False,  # causal
+                    False,  # causal
                     self.scale,
-                    None,  # attn_mask is None unless applying ALiBi mask
+                    None,  # attn_mask is None unless applying ALiBi mask
                 )
             elif is_vllm_fa:
                 attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
@@ -1342,7 +1342,8 @@ def _compute_prefill_context(
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
                     softmax_scale=self.scale,
                     causal=False,  # Context is unmasked
                     return_softmax_lse=True,
@@ -1355,7 +1356,8 @@ def _compute_prefill_context(
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
                     softmax_scale=self.scale,
                     causal=False,  # Context is unmasked
                     return_attn_probs=True,
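
Note (not part of the commit): the two hunks above differ only in how the log-sum-exp is requested. vLLM's bundled flash-attention fork accepts return_softmax_lse and returns (output, lse), while upstream flash-attn only exposes return_attn_probs and returns extra tensors after the output and lse. A minimal sketch of that distinction, assuming q, k, v, cu_seqlens, max_seqlen and scale are placeholder varlen tensors/values rather than the names used in this file:

# Illustrative sketch only; is_vllm_fa comes from the import fallback shown above.
if is_vllm_fa:
    # vLLM's fork returns exactly (output, softmax_lse) when asked.
    out, lse = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        softmax_scale=scale, causal=False,
        return_softmax_lse=True)
else:
    # Upstream flash-attn returns additional debug tensors; the output and
    # softmax_lse are the first two elements of the returned tuple.
    out, lse, *_rest = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        softmax_scale=scale, causal=False,
        return_attn_probs=True)
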
@@ -1417,9 +1419,9 @@ def _forward_prefill(
             prefill_metadata.query_start_loc,
             prefill_metadata.max_prefill_seq_len,
             prefill_metadata.max_prefill_seq_len,
-            True,  # causal
+            True,  # causal
             self.scale,
-            None,  # attn_mask is None unless applying ALiBi mask
+            None,  # attn_mask is None unless applying ALiBi mask
         )
         ## triton flash attention always return 2 objects
         if not has_context: