diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index ddace9283d..ad55fe6b7f 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -3394,7 +3394,7 @@ def mm_fp4(
     return out
 
 
-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
 def _cudnn_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -3408,7 +3408,7 @@ def _cudnn_bmm_fp8_requirement(
     return True
 
 
-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
 def _cublas_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
diff --git a/tests/gemm/test_bmm_fp8.py b/tests/gemm/test_bmm_fp8.py
index e20b5bc756..f44191865a 100644
--- a/tests/gemm/test_bmm_fp8.py
+++ b/tests/gemm/test_bmm_fp8.py
@@ -29,7 +29,12 @@ def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_t
         pytest.skip("Invalid combination: cutlass does not support e5m2")
     if auto_tuning and backend != "cutlass":
         pytest.skip("Invalid combination: auto_tuning only supported for cutlass")
-
+    if compute_capability[0] == 11 and (
+        input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2
+    ):
+        pytest.skip(
+            "Invalid combination: only cutlass supports SM110 which does not support e5m2"
+        )
     input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
     input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
 