Skip to content

Commit 44b8397

Browse files
committed
separate alibi test for lighter load
Signed-off-by: NickLucche <[email protected]>
1 parent 6cc676d commit 44b8397

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

tests/kernels/test_attention.py

Lines changed: 30 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -383,7 +383,6 @@ def ref_multi_query_kv_attention(
383383
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
384384
@pytest.mark.parametrize("num_heads", NUM_HEADS)
385385
@pytest.mark.parametrize("head_size", HEAD_SIZES)
386-
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
387386
@pytest.mark.parametrize("dtype", DTYPES)
388387
@pytest.mark.parametrize("seed", SEEDS)
389388
@pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -394,10 +393,10 @@ def test_multi_query_kv_attention(
394393
num_seqs: int,
395394
num_heads: tuple[int, int],
396395
head_size: int,
397-
use_alibi: bool,
398396
dtype: torch.dtype,
399397
seed: int,
400398
device: str,
399+
use_alibi: bool = False,
401400
) -> None:
402401
current_platform.seed_everything(seed)
403402
torch.set_default_device(device)
@@ -472,4 +471,32 @@ def test_multi_query_kv_attention(
472471
)
473472
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
474473
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
475-
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
474+
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
475+
476+
477+
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
    num_seqs: int,
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """ALiBi variant of ``test_multi_query_kv_attention``.

    Split out as its own test with a single head size (64) so the ALiBi
    path runs under a lighter parametrization load than the main test.
    """
    # Delegate to the shared implementation with the ALiBi path enabled.
    test_multi_query_kv_attention(
        num_seqs=num_seqs,
        num_heads=num_heads,
        head_size=head_size,
        dtype=dtype,
        seed=seed,
        device=device,
        use_alibi=True,
    )

0 commit comments

Comments
 (0)