diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 7ce086a6ac..f14c57b1f1 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -577,47 +577,7 @@ def test_trtllm_batch_prefill_bs1(
     )
 
 
-@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
-@pytest.mark.parametrize(
-    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
-    [
-        (4, 1, 16, 2, 1),
-        (4, 1, 32, 2, 5),
-        (4, 2, 64, 2, 5),
-        (4, 3, 32, 2, 5),
-        (4, 3, 64, 2, 1),
-        (4, 4, 64, 4, 1),
-        (4, 5, 64, 4, 8),
-        (128, 1, 64, 2, 5),
-        (128, 2, 32, 4, 1),
-        (128, 3, 16, 4, 8),
-        (128, 4, 16, 2, 5),
-        (128, 5, 16, 2, 5),
-        (256, 1, 64, 4, 8),
-        (256, 2, 16, 2, 8),
-        (256, 3, 64, 4, 5),
-        (256, 4, 32, 2, 8),
-        (256, 5, 32, 2, 1),
-    ],
-)
-@pytest.mark.parametrize("window_left", [-1, 127])
-@pytest.mark.parametrize(
-    "q_dtype,kv_dtype,o_dtype",
-    [
-        ("bf16", "bf16", "bf16"),
-        ("fp16", "fp16", "fp16"),
-        ("bf16", "fp8", "bf16"),
-        ("fp16", "fp8", "fp16"),
-        ("fp8", "fp8", "bf16"),
-        ("fp8", "fp8", "fp16"),
-        ("fp8", "fp8", "fp8"),
-        ("fp8", "fp8", "nvfp4"),
-    ],
-)
-@pytest.mark.parametrize("enable_pdl", [True, False, None])
-@pytest.mark.parametrize("enable_sink", [True, False])
-@pytest.mark.parametrize("max_in_kv_len", [110])
-def test_trtllm_batch_decode(
+def _test_trtllm_batch_decode(
     kv_layout,
     batch_size,
     q_len_per_req,
@@ -631,7 +591,13 @@ def test_trtllm_batch_decode(
     enable_pdl,
     enable_sink,
     max_in_kv_len,
+    head_dim,
 ):
+    """
+    Common helper for testing trtllm-gen decode.
+
+    Parameter combinations are exercised by test_trtllm_batch_decode() and the test_trtllm_batch_decode_*() variants below.
+    """
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
@@ -642,7 +608,6 @@ def test_trtllm_batch_decode(
 
     # Set up test parameters
     torch.manual_seed(0)
-    head_dim = 128
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
@@ -858,6 +823,82 @@ def test_trtllm_batch_decode(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (4, 1, 16, 2, 1),
+        (4, 1, 32, 2, 5),
+        (4, 2, 64, 2, 5),
+        (4, 3, 32, 2, 5),
+        (4, 3, 64, 2, 1),
+        (4, 4, 64, 4, 1),
+        (4, 5, 64, 4, 8),
+        (128, 1, 64, 2, 5),
+        (128, 2, 32, 4, 1),
+        (128, 3, 16, 4, 8),
+        (128, 4, 16, 2, 5),
+        (128, 5, 16, 2, 5),
+        (256, 1, 64, 4, 8),
+        (256, 2, 16, 2, 8),
+        (256, 3, 64, 4, 5),
+        (256, 4, 32, 2, 8),
+        (256, 5, 32, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1, 127])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("bf16", "fp8", "bf16"),
+        ("fp16", "fp8", "fp16"),
+        ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "fp16"),
+        ("fp8", "fp8", "fp8"),
+        ("fp8", "fp8", "nvfp4"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [128])
+def test_trtllm_batch_decode(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    # General set of tests for trtllm-gen decode
+    _test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+    )
+
+
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
 @pytest.mark.parametrize(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
@@ -875,6 +916,7 @@ def test_trtllm_batch_decode(
 @pytest.mark.parametrize("enable_pdl", [None])
 @pytest.mark.parametrize("enable_sink", [False])
 @pytest.mark.parametrize("max_in_kv_len", [8192])
+@pytest.mark.parametrize("head_dim", [128])
 def test_trtllm_batch_decode_bs1(
     kv_layout,
     batch_size,
@@ -889,9 +931,134 @@ def test_trtllm_batch_decode_bs1(
     enable_pdl,
     enable_sink,
     max_in_kv_len,
+    head_dim,
 ):
+    # Small number of test cases for batch size 1
     pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
-    test_trtllm_batch_decode(
+    _test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+    )
+
+
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (4, 1, 16, 2, 1),
+        (4, 1, 32, 2, 5),
+        (4, 3, 64, 2, 1),
+        (4, 4, 64, 4, 1),
+        (128, 3, 16, 4, 8),
+        (128, 4, 16, 2, 5),
+        (256, 4, 32, 2, 8),
+        (256, 5, 32, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("fp8", "fp8", "fp16"),
+        ("fp8", "fp8", "fp8"),
+        ("fp8", "fp8", "nvfp4"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [256])
+def test_trtllm_batch_decode_head_dim_256(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    # Small number of test cases for head_dim = 256
+    pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256")
+    _test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+    )
+
+
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 1, 16, 2, 1),
+        (1, 1, 32, 2, 5),
+        (1, 3, 64, 2, 1),
+        (1, 4, 64, 4, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp8", "fp8", "fp8"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072])
+@pytest.mark.parametrize("head_dim", [128])
+def test_trtllm_batch_decode_long_sequence_length(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    # Small number of test cases for long sequence length
+    pytest.xfail("trtllm-gen decode gets incorrect output with long sequence length")
+    _test_trtllm_batch_decode(
         kv_layout,
         batch_size,
         q_len_per_req,
@@ -905,6 +1072,7 @@ def test_trtllm_batch_decode_bs1(
         enable_pdl,
         enable_sink,
         max_in_kv_len,
+        head_dim,
     )
 
 
@@ -1053,8 +1221,3 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
     )
-
-
-if __name__ == "__main__":
-    test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False)
-    test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True)