Merged
flashinfer/gemm/gemm_base.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -3394,7 +3394,7 @@ def mm_fp4(
     return out


-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
```
Contributor comment (severity: high):

The corresponding test test_bmm_fp8 skips all tests on SM110 that use e5m2 dtypes. This suggests that the cudnn backend for bmm_fp8 might not support e5m2 on SM110. If this is the case, it would be better to add a check inside _cudnn_bmm_fp8_requirement to explicitly reject this combination and raise a ValueError, for consistency with _cutlass_bmm_fp8_requirement. This would make the API more robust. For example:

```python
from ..utils import get_compute_capability

# ... inside _cudnn_bmm_fp8_requirement
major, _ = get_compute_capability(A.device)
if major == 11 and (A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2):
    raise ValueError("bmm_fp8 with e5m2 is not supported on SM110 for cudnn backend")
```
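The proposed guard can be exercised in isolation with a torch-free sketch. The function name, string dtype encoding, and plain-integer compute capability below are illustrative assumptions, not FlashInfer's actual API:

```python
# Hypothetical, torch-free sketch of the reviewer's proposed check.
# In the real code, A and B are torch tensors and the major version
# comes from get_compute_capability(A.device); here dtypes are plain
# strings and the major version is passed in directly so the logic
# can run without a GPU or torch installed.
def cudnn_bmm_fp8_requirement(major: int, a_dtype: str, b_dtype: str) -> bool:
    """Return True if supported; raise ValueError for rejected combinations."""
    if major == 11 and "float8_e5m2" in (a_dtype, b_dtype):
        raise ValueError(
            "bmm_fp8 with e5m2 is not supported on SM110 for cudnn backend"
        )
    return True
```

Raising a ValueError here mirrors _cutlass_bmm_fp8_requirement, so callers would get the same failure mode from every backend instead of a silent capability mismatch.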

```diff
 def _cudnn_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -3408,7 +3408,7 @@ def _cudnn_bmm_fp8_requirement(
     return True
```


```diff
-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
```
Contributor comment (severity: high):

Similar to the cudnn backend, the test test_bmm_fp8 skips all tests on SM110 with e5m2 dtypes. If the cublas backend for bmm_fp8 also doesn't support e5m2 on SM110, this requirement function should be updated to raise a ValueError for this combination. This would improve API robustness and clarity. For example:

```python
from ..utils import get_compute_capability

# ... inside _cublas_bmm_fp8_requirement
major, _ = get_compute_capability(A.device)
if major == 11 and (A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2):
    raise ValueError("bmm_fp8 with e5m2 is not supported on SM110 for cublas backend")
```

```diff
 def _cublas_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
```
tests/gemm/test_bmm_fp8.py (7 changes: 6 additions & 1 deletion)
Expand Up @@ -29,7 +29,12 @@ def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_t
pytest.skip("Invalid combination: cutlass does not support e5m2")
if auto_tuning and backend != "cutlass":
pytest.skip("Invalid combination: auto_tuning only supported for cutlass")

if compute_capability[0] == 11 and (
input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2
):
pytest.skip(
"Invalid combination: only cutlass supports SM110 which does not support e5m2"
)
Comment on lines +32 to +37
Contributor comment (severity: medium):

The skip message is a bit confusing. It says "only cutlass supports SM110", but this pull request appears to add SM110 support for cublas and cudnn backends as well. If the intention is that none of the backends support e5m2 on SM110, a clearer message would be helpful to avoid confusion.

Suggested change:

```diff
     if compute_capability[0] == 11 and (
         input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2
     ):
         pytest.skip(
-            "Invalid combination: only cutlass supports SM110 which does not support e5m2"
+            "e5m2 is not supported on SM110 for bmm_fp8 by any of the available backends."
         )
```
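The skip conditions in this test can be factored into a single predicate, which makes the message wording easy to review in one place. The sketch below mirrors the three checks visible in the diff; the helper name, the string dtype encoding, and the exact condition for the cutlass/e5m2 branch (whose `if` line sits above the hunk) are assumptions for illustration:

```python
# Hypothetical helper mirroring the skip logic in test_bmm_fp8.
# Dtypes are modeled as plain strings instead of torch dtypes so the
# decision logic can be exercised without a GPU or torch installed.
E5M2 = "float8_e5m2"

def skip_reason(cc_major, input_dtype, mat2_dtype, backend, auto_tuning):
    """Return the skip message for an invalid combination, or None if it may run."""
    if backend == "cutlass" and E5M2 in (input_dtype, mat2_dtype):
        return "Invalid combination: cutlass does not support e5m2"
    if auto_tuning and backend != "cutlass":
        return "Invalid combination: auto_tuning only supported for cutlass"
    if cc_major == 11 and E5M2 in (input_dtype, mat2_dtype):
        return "e5m2 is not supported on SM110 for bmm_fp8 by any of the available backends."
    return None
```

Centralizing the messages this way would also have surfaced the confusing "only cutlass supports SM110" wording earlier, since all backend/dtype constraints sit side by side.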

```diff
     input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
     input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
```