Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 211 additions & 48 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,47 +577,7 @@ def test_trtllm_batch_prefill_bs1(
)


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@pytest.mark.parametrize(
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
[
(4, 1, 16, 2, 1),
(4, 1, 32, 2, 5),
(4, 2, 64, 2, 5),
(4, 3, 32, 2, 5),
(4, 3, 64, 2, 1),
(4, 4, 64, 4, 1),
(4, 5, 64, 4, 8),
(128, 1, 64, 2, 5),
(128, 2, 32, 4, 1),
(128, 3, 16, 4, 8),
(128, 4, 16, 2, 5),
(128, 5, 16, 2, 5),
(256, 1, 64, 4, 8),
(256, 2, 16, 2, 8),
(256, 3, 64, 4, 5),
(256, 4, 32, 2, 8),
(256, 5, 32, 2, 1),
],
)
@pytest.mark.parametrize("window_left", [-1, 127])
@pytest.mark.parametrize(
"q_dtype,kv_dtype,o_dtype",
[
("bf16", "bf16", "bf16"),
("fp16", "fp16", "fp16"),
("bf16", "fp8", "bf16"),
("fp16", "fp8", "fp16"),
("fp8", "fp8", "bf16"),
("fp8", "fp8", "fp16"),
("fp8", "fp8", "fp8"),
("fp8", "fp8", "nvfp4"),
],
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_in_kv_len", [110])
def test_trtllm_batch_decode(
def _test_trtllm_batch_decode(
kv_layout,
batch_size,
q_len_per_req,
Expand All @@ -631,7 +591,13 @@ def test_trtllm_batch_decode(
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
):
"""
Common function for testing trtllm-gen decode.

Combinations of parameters are tested in test_trtllm_batch_decode() and test_trtllm_batch_decode_...()
"""
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 10:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
Expand All @@ -642,7 +608,6 @@ def test_trtllm_batch_decode(

# Set up test parameters
torch.manual_seed(0)
head_dim = 128

# Generate random sequence lengths
num_qo_heads = num_kv_heads * head_grp_size
Expand Down Expand Up @@ -858,6 +823,82 @@ def test_trtllm_batch_decode(
assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()


@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
@pytest.mark.parametrize(
    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
    [
        (4, 1, 16, 2, 1),
        (4, 1, 32, 2, 5),
        (4, 2, 64, 2, 5),
        (4, 3, 32, 2, 5),
        (4, 3, 64, 2, 1),
        (4, 4, 64, 4, 1),
        (4, 5, 64, 4, 8),
        (128, 1, 64, 2, 5),
        (128, 2, 32, 4, 1),
        (128, 3, 16, 4, 8),
        (128, 4, 16, 2, 5),
        (128, 5, 16, 2, 5),
        (256, 1, 64, 4, 8),
        (256, 2, 16, 2, 8),
        (256, 3, 64, 4, 5),
        (256, 4, 32, 2, 8),
        (256, 5, 32, 2, 1),
    ],
)
@pytest.mark.parametrize("window_left", [-1, 127])
@pytest.mark.parametrize(
    "q_dtype,kv_dtype,o_dtype",
    [
        ("bf16", "bf16", "bf16"),
        ("fp16", "fp16", "fp16"),
        ("bf16", "fp8", "bf16"),
        ("fp16", "fp8", "fp16"),
        ("fp8", "fp8", "bf16"),
        ("fp8", "fp8", "fp16"),
        ("fp8", "fp8", "fp8"),
        ("fp8", "fp8", "nvfp4"),
    ],
)
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [128])
def test_trtllm_batch_decode(
    kv_layout,
    batch_size,
    q_len_per_req,
    page_size,
    num_kv_heads,
    head_grp_size,
    window_left,
    q_dtype,
    o_dtype,
    kv_dtype,
    enable_pdl,
    enable_sink,
    max_in_kv_len,
    head_dim,
):
    """Run the general parameter sweep for trtllm-gen decode.

    This wrapper only supplies the parameter grid (``head_dim`` fixed at 128
    and ``max_in_kv_len`` fixed at 110); every combination is forwarded
    unchanged to the shared ``_test_trtllm_batch_decode`` helper, which does
    the actual work.
    """
    # Arguments are passed positionally, in the same order as this signature.
    _test_trtllm_batch_decode(
        kv_layout,
        batch_size,
        q_len_per_req,
        page_size,
        num_kv_heads,
        head_grp_size,
        window_left,
        q_dtype,
        o_dtype,
        kv_dtype,
        enable_pdl,
        enable_sink,
        max_in_kv_len,
        head_dim,
    )


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@pytest.mark.parametrize(
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
Expand All @@ -875,6 +916,7 @@ def test_trtllm_batch_decode(
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [8192])
@pytest.mark.parametrize("head_dim", [128])
def test_trtllm_batch_decode_bs1(
kv_layout,
batch_size,
Expand All @@ -889,9 +931,134 @@ def test_trtllm_batch_decode_bs1(
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
):
# Small number of test cases for batch size 1
pytest.xfail("trtllm-gen decode gets incorrect output with bs1")
test_trtllm_batch_decode(
_test_trtllm_batch_decode(
kv_layout,
batch_size,
q_len_per_req,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
)


# Declarative xfail: unlike an imperative ``pytest.xfail()`` call at the top of
# the body, the marker lets the test body execute, so a future kernel fix shows
# up as XPASS instead of staying hidden behind an unconditional early exit.
# ``strict=False`` because it is not established here that *every* parameter
# combination fails; passing combinations must not turn the suite red.
@pytest.mark.xfail(
    reason="trtllm-gen decode gets incorrect output with head_dim = 256",
    strict=False,
)
@pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
@pytest.mark.parametrize(
    "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
    [
        (4, 1, 16, 2, 1),
        (4, 1, 32, 2, 5),
        (4, 3, 64, 2, 1),
        (4, 4, 64, 4, 1),
        (128, 3, 16, 4, 8),
        (128, 4, 16, 2, 5),
        (256, 4, 32, 2, 8),
        (256, 5, 32, 2, 1),
    ],
)
@pytest.mark.parametrize("window_left", [-1])
@pytest.mark.parametrize(
    "q_dtype,kv_dtype,o_dtype",
    [
        ("bf16", "bf16", "bf16"),
        ("fp16", "fp16", "fp16"),
        ("fp8", "fp8", "fp16"),
        ("fp8", "fp8", "fp8"),
        ("fp8", "fp8", "nvfp4"),
    ],
)
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [256])
def test_trtllm_batch_decode_head_dim_256(
    kv_layout,
    batch_size,
    q_len_per_req,
    page_size,
    num_kv_heads,
    head_grp_size,
    window_left,
    q_dtype,
    o_dtype,
    kv_dtype,
    enable_pdl,
    enable_sink,
    max_in_kv_len,
    head_dim,
):
    """Run a small set of trtllm-gen decode cases with ``head_dim = 256``.

    The whole test is marked as an expected failure (see the ``xfail`` marker
    above) because decode is reported to produce incorrect output for this
    head dimension. All parameters are forwarded unchanged, positionally, to
    the shared ``_test_trtllm_batch_decode`` helper.
    """
    _test_trtllm_batch_decode(
        kv_layout,
        batch_size,
        q_len_per_req,
        page_size,
        num_kv_heads,
        head_grp_size,
        window_left,
        q_dtype,
        o_dtype,
        kv_dtype,
        enable_pdl,
        enable_sink,
        max_in_kv_len,
        head_dim,
    )


@pytest.mark.parametrize("kv_layout", ["HND"]) # trtllm-gen only support HND
@pytest.mark.parametrize(
"batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size",
[
(1, 1, 16, 2, 1),
(1, 1, 32, 2, 5),
(1, 3, 64, 2, 1),
(1, 4, 64, 4, 1),
],
)
@pytest.mark.parametrize("window_left", [-1])
@pytest.mark.parametrize(
"q_dtype,kv_dtype,o_dtype",
[
("bf16", "bf16", "bf16"),
("fp8", "fp8", "fp8"),
],
)
@pytest.mark.parametrize("enable_pdl", [None])
@pytest.mark.parametrize("enable_sink", [False])
@pytest.mark.parametrize("max_in_kv_len", [4096, 8192, 16384, 32768, 65536, 131072])
@pytest.mark.parametrize("head_dim", [128])
def test_trtllm_batch_decode_long_sequence_length(
kv_layout,
batch_size,
q_len_per_req,
page_size,
num_kv_heads,
head_grp_size,
window_left,
q_dtype,
o_dtype,
kv_dtype,
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
):
# Small number of test cases for long sequence length
pytest.xfail("trtllm-gen decode gets incorrect output with Long sequence length")
_test_trtllm_batch_decode(
kv_layout,
batch_size,
q_len_per_req,
Expand All @@ -905,6 +1072,7 @@ def test_trtllm_batch_decode_bs1(
enable_pdl,
enable_sink,
max_in_kv_len,
head_dim,
)


Expand Down Expand Up @@ -1053,8 +1221,3 @@ def test_trtllm_gen_prefill_deepseek_bs1(
test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
)


if __name__ == "__main__":
test_trtllm_batch_prefill("HND", 128, 32, 2, 5, -1, "fp16", "fp16", "fp16", False)
test_trtllm_batch_decode("HND", 256, 3, 64, 4, 5, -1, "fp8", "fp8", "fp8", True)