From 06b331ceb58c4e1e0f3a3c24b1675efcc0a54903 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 28 Oct 2025 18:52:20 +0000
Subject: [PATCH 1/3] Add head dim 256 test cases and mark as xfail

---
 tests/attention/test_trtllm_gen_attention.py | 75 ++++++++++++++++++--
 1 file changed, 69 insertions(+), 6 deletions(-)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 7ce086a6ac..d79d8678e5 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -617,6 +617,7 @@ def test_trtllm_batch_prefill_bs1(
 @pytest.mark.parametrize("enable_pdl", [True, False, None])
 @pytest.mark.parametrize("enable_sink", [True, False])
 @pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [128])
 def test_trtllm_batch_decode(
     kv_layout,
     batch_size,
@@ -631,6 +632,7 @@ def test_trtllm_batch_decode(
     enable_pdl,
     enable_sink,
     max_in_kv_len,
+    head_dim,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
@@ -642,7 +644,6 @@ def test_trtllm_batch_decode(
 
     # Set up test parameters
     torch.manual_seed(0)
-    head_dim = 128
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
@@ -875,6 +876,7 @@ def test_trtllm_batch_decode(
 @pytest.mark.parametrize("enable_pdl", [None])
 @pytest.mark.parametrize("enable_sink", [False])
 @pytest.mark.parametrize("max_in_kv_len", [8192])
+@pytest.mark.parametrize("head_dim", [128])
 def test_trtllm_batch_decode_bs1(
     kv_layout,
     batch_size,
@@ -889,6 +891,7 @@ def test_trtllm_batch_decode_bs1(
     enable_pdl,
     enable_sink,
     max_in_kv_len,
+    head_dim,
 ):
     pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
     test_trtllm_batch_decode(
@@ -905,6 +908,71 @@ def test_trtllm_batch_decode_bs1(
         enable_pdl,
         enable_sink,
         max_in_kv_len,
+        head_dim,
+    )
+
+
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (4, 1, 16, 2, 1),
+        (4, 1, 32, 2, 5),
+        (4, 3, 64, 2, 1),
+        (4, 4, 64, 4, 1),
+        (128, 3, 16, 4, 8),
+        (128, 4, 16, 2, 5),
+        (256, 4, 32, 2, 8),
+        (256, 5, 32, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("fp8", "fp8", "fp16"),
+        ("fp8", "fp8", "fp8"),
+        ("fp8", "fp8", "nvfp4"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [256])
+def test_trtllm_batch_decode_head_dim_256(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256")
+    test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
     )
 
 
@@ -1053,8 +1121,3 @@ def test_trtllm_gen_prefill_deepseek_bs1(
     test_trtllm_gen_prefill_deepseek(
         batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
    )
-
-
-if __name__ == "__main__":
-    test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False)
-    test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True)
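The pattern this first patch establishes: the previously hard-coded head_dim = 128 becomes a @pytest.mark.parametrize axis, and the known-bad head_dim = 256 configuration gets its own wrapper test that calls pytest.xfail() before delegating. A minimal self-contained sketch of that shape (fake_decode and its behavior are hypothetical stand-ins, not flashinfer code):

    import pytest


    def fake_decode(head_dim: int) -> str:
        # Hypothetical stand-in for the trtllm-gen decode kernel under test.
        return "wrong" if head_dim == 256 else "ok"


    @pytest.mark.parametrize("head_dim", [128])
    def test_decode(head_dim):
        assert fake_decode(head_dim) == "ok"


    @pytest.mark.parametrize("head_dim", [256])
    def test_decode_head_dim_256(head_dim):
        # pytest.xfail() raises immediately, so the delegated call below is
        # not reached; the case is reported as XFAIL instead of failing CI.
        pytest.xfail("decode gets incorrect output with head_dim = 256")
        test_decode(head_dim)

Having one test function call another collected test works here but is fragile once the callee carries its own marks, which is what the next patch addresses.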
From 9b6e17c1b31dd203fd1297696c28cc79321c8bc8 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 28 Oct 2025 20:31:32 +0000
Subject: [PATCH 2/3] Refactor tests to share a base test function

---
 tests/attention/test_trtllm_gen_attention.py | 130 ++++++++++++-------
 1 file changed, 86 insertions(+), 44 deletions(-)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index d79d8678e5..c825d8fb18 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -577,48 +577,7 @@ def test_trtllm_batch_prefill_bs1(
     )
 
 
-@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
-@pytest.mark.parametrize(
-    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
-    [
-        (4, 1, 16, 2, 1),
-        (4, 1, 32, 2, 5),
-        (4, 2, 64, 2, 5),
-        (4, 3, 32, 2, 5),
-        (4, 3, 64, 2, 1),
-        (4, 4, 64, 4, 1),
-        (4, 5, 64, 4, 8),
-        (128, 1, 64, 2, 5),
-        (128, 2, 32, 4, 1),
-        (128, 3, 16, 4, 8),
-        (128, 4, 16, 2, 5),
-        (128, 5, 16, 2, 5),
-        (256, 1, 64, 4, 8),
-        (256, 2, 16, 2, 8),
-        (256, 3, 64, 4, 5),
-        (256, 4, 32, 2, 8),
-        (256, 5, 32, 2, 1),
-    ],
-)
-@pytest.mark.parametrize("window_left", [-1, 127])
-@pytest.mark.parametrize(
-    "q_dtype,kv_dtype,o_dtype",
-    [
-        ("bf16", "bf16", "bf16"),
-        ("fp16", "fp16", "fp16"),
-        ("bf16", "fp8", "bf16"),
-        ("fp16", "fp8", "fp16"),
-        ("fp8", "fp8", "bf16"),
-        ("fp8", "fp8", "fp16"),
-        ("fp8", "fp8", "fp8"),
-        ("fp8", "fp8", "nvfp4"),
-    ],
-)
-@pytest.mark.parametrize("enable_pdl", [True, False, None])
-@pytest.mark.parametrize("enable_sink", [True, False])
-@pytest.mark.parametrize("max_in_kv_len", [110])
-@pytest.mark.parametrize("head_dim", [128])
-def test_trtllm_batch_decode(
+def _test_trtllm_batch_decode(
     kv_layout,
     batch_size,
     q_len_per_req,
@@ -634,6 +593,11 @@ def test_trtllm_batch_decode(
     max_in_kv_len,
     head_dim,
 ):
+    """
+    Shared implementation for the trtllm-gen decode tests.
+
+    Parameter combinations are supplied by test_trtllm_batch_decode() and the test_trtllm_batch_decode_*() variants.
+    """
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
@@ -859,6 +823,82 @@ def test_trtllm_batch_decode(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (4, 1, 16, 2, 1),
+        (4, 1, 32, 2, 5),
+        (4, 2, 64, 2, 5),
+        (4, 3, 32, 2, 5),
+        (4, 3, 64, 2, 1),
+        (4, 4, 64, 4, 1),
+        (4, 5, 64, 4, 8),
+        (128, 1, 64, 2, 5),
+        (128, 2, 32, 4, 1),
+        (128, 3, 16, 4, 8),
+        (128, 4, 16, 2, 5),
+        (128, 5, 16, 2, 5),
+        (256, 1, 64, 4, 8),
+        (256, 2, 16, 2, 8),
+        (256, 3, 64, 4, 5),
+        (256, 4, 32, 2, 8),
+        (256, 5, 32, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1, 127])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp16", "fp16", "fp16"),
+        ("bf16", "fp8", "bf16"),
+        ("fp16", "fp8", "fp16"),
+        ("fp8", "fp8", "bf16"),
+        ("fp8", "fp8", "fp16"),
+        ("fp8", "fp8", "fp8"),
+        ("fp8", "fp8", "nvfp4"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [True, False, None])
+@pytest.mark.parametrize("enable_sink", [True, False])
+@pytest.mark.parametrize("max_in_kv_len", [110])
+@pytest.mark.parametrize("head_dim", [128])
+def test_trtllm_batch_decode(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    # General set of tests for trtllm-gen decode
+    _test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+    )
+
+
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only support HND
 @pytest.mark.parametrize(
     "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
@@ -893,8 +933,9 @@ def test_trtllm_batch_decode_bs1(
     max_in_kv_len,
     head_dim,
 ):
+    # Small number of test cases for batch size 1
     pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
-    test_trtllm_batch_decode(
+    _test_trtllm_batch_decode(
         kv_layout,
         batch_size,
         q_len_per_req,
@@ -957,8 +998,9 @@ def test_trtllm_batch_decode_head_dim_256(
     max_in_kv_len,
     head_dim,
 ):
+    # Small number of test cases for head_dim = 256
     pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256")
-    test_trtllm_batch_decode(
+    _test_trtllm_batch_decode(
         kv_layout,
         batch_size,
         q_len_per_req,

From 097308f9007ae8ae18b57b55984e212ad77d5f38 Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Tue, 28 Oct 2025 21:46:38 +0000
Subject: [PATCH 3/3] Add long seqlen tests

---
 tests/attention/test_trtllm_gen_attention.py | 58 ++++++++++++++++++++
 1 file changed, 58 insertions(+)

diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index c825d8fb18..f14c57b1f1 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -1018,6 +1018,64 @@ def test_trtllm_batch_decode_head_dim_256(
     )
 
 
+@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
+@pytest.mark.parametrize(
+    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
+    [
+        (1, 1, 16, 2, 1),
+        (1, 1, 32, 2, 5),
+        (1, 3, 64, 2, 1),
+        (1, 4, 64, 4, 1),
+    ],
+)
+@pytest.mark.parametrize("window_left", [-1])
+@pytest.mark.parametrize(
+    "q_dtype,kv_dtype,o_dtype",
+    [
+        ("bf16", "bf16", "bf16"),
+        ("fp8", "fp8", "fp8"),
+    ],
+)
+@pytest.mark.parametrize("enable_pdl", [None])
+@pytest.mark.parametrize("enable_sink", [False])
+@pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072])
+@pytest.mark.parametrize("head_dim", [128])
+def test_trtllm_batch_decode_long_sequence_length(
+    kv_layout,
+    batch_size,
+    q_len_per_req,
+    page_size,
+    num_kv_heads,
+    head_grp_size,
+    window_left,
+    q_dtype,
+    o_dtype,
+    kv_dtype,
+    enable_pdl,
+    enable_sink,
+    max_in_kv_len,
+    head_dim,
+):
+    # Small number of test cases for long sequence lengths
+    pytest.xfail("trtllm-gen decode gets incorrect output with long sequence lengths")
+    _test_trtllm_batch_decode(
+        kv_layout,
+        batch_size,
+        q_len_per_req,
+        page_size,
+        num_kv_heads,
+        head_grp_size,
+        window_left,
+        q_dtype,
+        o_dtype,
+        kv_dtype,
+        enable_pdl,
+        enable_sink,
+        max_in_kv_len,
+        head_dim,
+    )
+
+
 @pytest.mark.parametrize("batch_size", [4, 128, 256])
 @pytest.mark.parametrize("s_qo", [32, 64, 87])
 @pytest.mark.parametrize("s_kv", [32, 64, 87])
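Taken together, the series converges on one idiom: a private base function (_test_trtllm_batch_decode, which pytest's default test_* collection pattern skips because of the leading underscore) holds the shared body, and each public test is a thin parametrized wrapper that either delegates directly or calls pytest.xfail() first, so known-bad configurations stay visible in reports as XFAIL without failing CI. A minimal sketch of that end state, with hypothetical names (_run_decode_case and its assertion stand in for the real kernel checks):

    import pytest


    def _run_decode_case(batch_size: int, max_in_kv_len: int) -> None:
        # Shared body; a trivial check standing in for the real kernel test.
        # The leading underscore keeps it out of pytest's test_* collection.
        assert batch_size >= 1 and max_in_kv_len >= 1


    @pytest.mark.parametrize("batch_size", [4, 128, 256])
    @pytest.mark.parametrize("max_in_kv_len", [110])
    def test_decode(batch_size, max_in_kv_len):
        # General coverage delegates straight to the base function.
        _run_decode_case(batch_size, max_in_kv_len)


    @pytest.mark.parametrize("batch_size", [1])
    @pytest.mark.parametrize("max_in_kv_len", [4096, 131072])
    def test_decode_long_sequence_length(batch_size, max_in_kv_len):
        # Known-bad configuration: mark as expected failure, then delegate.
        pytest.xfail("decode gets incorrect output with long sequence lengths")
        _run_decode_case(batch_size, max_in_kv_len)

Because every wrapper forwards the same argument list, each known-bad axis (bs1, head_dim 256, long sequence lengths) can later be re-enabled simply by deleting its pytest.xfail() line.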