flashinfer/artifacts.py (4 changes: 2 additions & 2 deletions)

@@ -87,7 +87,7 @@ class ArtifactPath:
     When compiling new cubins for backend directories, update the corresponding path.
     """
 
-    TRTLLM_GEN_FMHA: str = "1e49deb33ec20018ae0acf1d956a579578069da1/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "c108f5cc46420e11805467898186533fb48d6a6f/batched_gemm-0d28130-7b26988"
     )

@@ -107,7 +107,7 @@ class CheckSumHash:
     """
 
     TRTLLM_GEN_FMHA: str = (
-        "66757498f573430583d63b04c02bf9e38306eefe2ce31df9b5d923d99bd15d84"
+        "88f6898813aa61414e0d84545fd34543f77eaaf1b1042489796e9a80fb7233cd"
     )
     TRTLLM_GEN_BMM: str = (
         "85a4516b7ab25b1a6495398ae934a00e30ccd6662b9ec27be1330d7bba5e1ddf"
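The CheckSumHash values are 64 hexadecimal characters, the length of a SHA-256 digest. As a rough sketch of how a pinned checksum like this can be verified after download (hypothetical helper names, assuming SHA-256 over the artifact bytes; not FlashInfer's actual loader code):

    import hashlib
    from pathlib import Path

    def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
        """Stream a file through SHA-256 and return its hex digest."""
        digest = hashlib.sha256()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def verify_artifact(path: Path, expected: str) -> None:
        """Raise if the artifact on disk does not match the pinned hash.

        Assumes the checksum covers the artifact's raw bytes; the real
        FlashInfer loader may hash a different payload.
        """
        actual = sha256_of(path)
        if actual != expected:
            raise RuntimeError(
                f"checksum mismatch for {path}: expected {expected}, got {actual}"
            )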
tests/attention/test_trtllm_gen_attention.py (9 changes: 7 additions & 2 deletions)

@@ -414,13 +414,13 @@ def _test_trtllm_batch_prefill(
     max_q_len,
     max_kv_len,
     device_scale,
+    head_dim,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
     # Set up test parameters
     torch.manual_seed(0)
-    head_dim = 128
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
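This hunk lifts head_dim out of the helper body and into its signature, so callers can exercise both 128 and 256. The skip at the top gates on get_compute_capability, running only on major version 10 (SM100/SM103, i.e. Blackwell). A minimal sketch of what such a helper can look like, assuming it wraps torch's capability query (the real FlashInfer utility may differ):

    import torch

    def get_compute_capability(device: torch.device) -> tuple[int, int]:
        # torch returns (major, minor), e.g. (10, 0) on SM100 hardware.
        return torch.cuda.get_device_capability(device)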
@@ -639,6 +639,7 @@ def _test_trtllm_batch_prefill(
 @pytest.mark.parametrize("enable_sink", [True, False])
 @pytest.mark.parametrize("max_q_len", [511])
 @pytest.mark.parametrize("max_kv_len", [2047])
+@pytest.mark.parametrize("head_dim", [128, 256])
 def test_trtllm_batch_prefill(
     kv_layout,
     batch_size,

@@ -653,6 +654,7 @@ def test_trtllm_batch_prefill(
     enable_sink,
     max_q_len,
     max_kv_len,
+    head_dim,
 ):
     _test_trtllm_batch_prefill(
         kv_layout,

@@ -669,6 +671,7 @@ def test_trtllm_batch_prefill(
         max_q_len,
         max_kv_len,
         kv_dtype == "fp8",
+        head_dim,
     )
 
 
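Stacked @pytest.mark.parametrize decorators expand to the Cartesian product of their value lists, so the new head_dim line doubles the number of generated prefill cases. A standalone toy example of the mechanics (not part of this test suite):

    import pytest

    @pytest.mark.parametrize("max_kv_len", [2047])
    @pytest.mark.parametrize("head_dim", [128, 256])
    def test_cross_product(max_kv_len, head_dim):
        # Collected twice: (2047, 128) and (2047, 256).
        assert head_dim in (128, 256)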
@@ -690,6 +693,7 @@ def test_trtllm_batch_prefill(
 @pytest.mark.parametrize("enable_sink", [False])
 @pytest.mark.parametrize("max_q_len", [8192])
 @pytest.mark.parametrize("max_kv_len", [8192])
+@pytest.mark.parametrize("head_dim", [128, 256])
 def test_trtllm_batch_prefill_bs1(
     kv_layout,
     batch_size,

@@ -704,6 +708,7 @@ def test_trtllm_batch_prefill_bs1(
     enable_sink,
     max_q_len,
     max_kv_len,
+    head_dim,
 ):
     _test_trtllm_batch_prefill(
         kv_layout,

@@ -720,6 +725,7 @@ def test_trtllm_batch_prefill_bs1(
         max_q_len,
         max_kv_len,
         False,
+        head_dim,
     )
 
 
@@ -1202,7 +1208,6 @@ def test_trtllm_batch_decode_head_dim_256(
     device_scale,
 ):
     # Small number of test cases for head_dim = 256
-    pytest.xfail("trtllm-gen decode gets incorrect output with head_dim = 256")
     _test_trtllm_batch_decode(
         "trtllm-gen",
         kv_layout,
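Context for the removed line: calling pytest.xfail() imperatively aborts the test at that point and reports it as xfailed, so the head_dim = 256 decode path was never actually exercised. Deleting the call lets those cases run for real, presumably because the updated trtllm-gen cubins pinned above fix the previously incorrect output. A standalone illustration of the imperative xfail behavior:

    import pytest

    def test_imperative_xfail():
        pytest.xfail("everything below is skipped")
        assert False  # never reached; the test is reported as xfailed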