diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py
index 18995545552e..6e5468969bf2 100644
--- a/tests/kernels/attention/test_flash_attn.py
+++ b/tests/kernels/attention/test_flash_attn.py
@@ -9,7 +9,6 @@
 from vllm.vllm_flash_attn import (
     fa_version_unsupported_reason,
     flash_attn_varlen_func,
-    flash_attn_with_kvcache,
     is_fa_version_supported,
 )
 
@@ -83,124 +82,6 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
-@pytest.mark.parametrize("use_out", [True, False])
-@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
-@pytest.mark.parametrize("num_heads", NUM_HEADS)
-@pytest.mark.parametrize("head_size", HEAD_SIZES)
-@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
-@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
-@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
-@pytest.mark.parametrize("fa_version", [2, 3])
-@pytest.mark.parametrize("q_dtype", QDTYPES)
-@torch.inference_mode()
-def test_flash_attn_with_paged_kv(
-    use_out: bool,
-    kv_lens: list[int],
-    num_heads: tuple[int, int],
-    head_size: int,
-    dtype: torch.dtype,
-    block_size: int,
-    soft_cap: float | None,
-    num_blocks: int,
-    sliding_window: int | None,
-    fa_version: int,
-    q_dtype: torch.dtype | None,
-) -> None:
-    torch.set_default_device("cuda")
-    if not is_fa_version_supported(fa_version):
-        pytest.skip(
-            f"Flash attention version {fa_version} not supported due "
-            f'to: "{fa_version_unsupported_reason(fa_version)}"'
-        )
-    if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
-        pytest.skip(
-            "Flash attention with quantized inputs is only "
-            "supported on version 3 with bfloat16 base type"
-        )
-
-    current_platform.seed_everything(0)
-    num_seqs = len(kv_lens)
-    num_query_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_query_heads % num_kv_heads == 0
-    max_kv_len = max(kv_lens)
-    scale = head_size**-0.5
-    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
-
-    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
-    key_cache = torch.randn(
-        num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
-    )
-    value_cache = torch.randn_like(key_cache)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
-
-    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
-    block_tables = torch.randint(
-        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
-    )
-
-    q = query.unsqueeze(1)
-    out = torch.empty_like(q) if use_out else None
-
-    maybe_quantized_query = q
-    maybe_quantized_key_cache = key_cache
-    maybe_quantized_value_cache = value_cache
-    q_descale = None
-    k_descale = None
-    v_descale = None
-    if q_dtype is not None:
-        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
-        maybe_quantized_query = q.to(q_dtype)
-        maybe_quantized_key_cache = key_cache.to(q_dtype)
-        maybe_quantized_value_cache = value_cache.to(q_dtype)
-
-        scale_shape = (num_seqs, num_kv_heads)
-        q_descale = torch.ones(scale_shape, dtype=torch.float32)
-        k_descale = torch.ones(scale_shape, dtype=torch.float32)
-        v_descale = torch.ones(scale_shape, dtype=torch.float32)
-
-    output = flash_attn_with_kvcache(
-        q=maybe_quantized_query,
-        k_cache=maybe_quantized_key_cache,
-        v_cache=maybe_quantized_value_cache,
-        out=out,
-        softmax_scale=scale,
-        causal=True,
-        block_table=block_tables,
-        cache_seqlens=kv_lens_tensor,
-        softcap=soft_cap if soft_cap is not None else 0,
-        window_size=window_size,
-        fa_version=fa_version,
-        q_descale=q_descale,
-        k_descale=k_descale,
-        v_descale=v_descale,
-    )
-    output = output if not use_out else out
-    output = output.squeeze(1)
-
-    atol, rtol = 1.5e-2, 1e-2
-    if q_dtype is not None:
-        atol, rtol = 1.5e-1, 1.5e-1
-
-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-        sliding_window=sliding_window,
-    )
-    (
-        torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
-        f"{torch.max(torch.abs(output - ref_output))}",
-    )
-
-
 @pytest.mark.parametrize("use_out", [True, False])
 @pytest.mark.parametrize(
     "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]