@@ -17,6 +17,8 @@
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
+from vllm.attention.backends.xformers import _make_alibi_bias
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -345,20 +347,26 @@ def ref_multi_query_kv_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     scale: float,
+    alibi_bias: Optional[list[torch.Tensor]],
     dtype: torch.dtype,
 ) -> torch.Tensor:
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs: list[torch.Tensor] = []
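+    # When ALiBi is used, the caller supplies one bias tensor per sequence.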
+    if alibi_bias:
+        assert len(alibi_bias) == num_seqs
     for i in range(num_seqs):
         start_idx = cu_seq_lens[i]
         end_idx = cu_seq_lens[i + 1]
         seq_len = end_idx - start_idx
 
-        # Create attention mask.
-        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype)
+        # Create attention mask. ALiBi already includes a tril causal mask.
+        if alibi_bias:
+            attn_mask = alibi_bias[i]
+        else:
+            attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                                   diagonal=1)
+            attn_mask = attn_mask * torch.finfo(dtype).min
+            attn_mask = attn_mask.to(dtype=dtype)
 
         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -372,7 +380,6 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -389,6 +396,7 @@ def test_multi_query_kv_attention(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    use_alibi: bool = False,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
@@ -414,16 +422,40 @@ def test_multi_query_kv_attention(
     # Handle MQA and GQA
     key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
     value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
-    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
-    output = xops.memory_efficient_attention_forward(
-        query.unsqueeze(0),
-        key.unsqueeze(0),
-        value.unsqueeze(0),
-        attn_bias=attn_bias,
-        p=0.0,
-        scale=scale,
-    )
-    output = output.squeeze(0)
+    alibi_bias = None
+    if use_alibi:
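+        # Random slopes suffice for the test; real models use fixed
+        # geometric per-head slopes as in the ALiBi paper.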
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
+                                     seq_lens)
+        output = torch.empty_like(query)
+        start = 0
+        # Dynamic sequence length not supported with custom attn_bias,
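+        # so run the xformers kernel once per sequence.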
+        for i, seq_len in enumerate(seq_lens):
+            end = start + seq_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=attn_bias[i],
+                p=0.0,
+                scale=scale)
+            output[start:end].copy_(out.view_as(query[start:end]))
+            start += seq_len
+        # xformers.AttentionBias to Tensor for use in reference impl.
+        alibi_bias = [
+            b.materialize(b.shape, device=device).squeeze() for b in attn_bias
+        ]
+    else:
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
+        output = xops.memory_efficient_attention_forward(
+            query.unsqueeze(0),
+            key.unsqueeze(0),
+            value.unsqueeze(0),
+            attn_bias=attn_bias,
+            p=0.0,
+            scale=scale,
+        )
+        output = output.squeeze(0)
 
     cu_seq_lens = [0]
     for seq_len in seq_lens:
@@ -434,8 +466,37 @@ def test_multi_query_kv_attention(
         key,
         value,
         scale,
+        alibi_bias,
         dtype,
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(current_platform.is_rocm(),
+                    reason="Xformers backend is not supported on ROCm.")
+@torch.inference_mode()
+def test_multi_query_kv_attention_with_alibi(
+    num_seqs: int,
+    num_heads: tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
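+    # Delegate to the shared test body with ALiBi enabled.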
+    return test_multi_query_kv_attention(
+        num_seqs,
+        num_heads,
+        head_size,
+        dtype,
+        seed,
+        device,
+        use_alibi=True,
+    )