tests: bmm_fp8 for SM110 #2538
Changes from all commits
```diff
@@ -3394,7 +3394,7 @@ def mm_fp4(
     return out


-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
 def _cudnn_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -3408,7 +3408,7 @@ def _cudnn_bmm_fp8_requirement(
     return True


-@supported_compute_capability([89, 90, 100, 103, 120, 121])
+@supported_compute_capability([89, 90, 100, 103, 110, 120, 121])
```
|
Contributor

A similar explicit check could reject `e5m2` on SM110 in the cublas requirement:

```python
from ..utils import get_compute_capability

# ... inside _cublas_bmm_fp8_requirement
major, _ = get_compute_capability(A.device)
if major == 11 and (A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2):
    raise ValueError("bmm_fp8 with e5m2 is not supported on SM110 for cublas backend")
```
```diff
 def _cublas_bmm_fp8_requirement(
     A: torch.Tensor,
     B: torch.Tensor,
```
```diff
@@ -29,7 +29,12 @@ def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_t
         pytest.skip("Invalid combination: cutlass does not support e5m2")
     if auto_tuning and backend != "cutlass":
         pytest.skip("Invalid combination: auto_tuning only supported for cutlass")
+
+    if compute_capability[0] == 11 and (
+        input_dtype == torch.float8_e5m2 or mat2_dtype == torch.float8_e5m2
+    ):
+        pytest.skip(
+            "Invalid combination: only cutlass supports SM110 which does not support e5m2"
+        )
```
|
Comment on lines +32 to +37

Contributor

The skip message is a bit confusing. It says "only cutlass supports SM110", but this pull request appears to add SM110 support for the cudnn and cublas backends as well. The message should state what is actually unsupported: e5m2 on SM110.

Suggested change
```diff
     input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
     input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
```
The corresponding test `test_bmm_fp8` skips all tests on SM110 that use `e5m2` dtypes. This suggests that the `cudnn` backend for `bmm_fp8` might not support `e5m2` on SM110. If this is the case, it would be better to add a check inside `_cudnn_bmm_fp8_requirement` to explicitly reject this combination and raise a `ValueError`, for consistency with `_cutlass_bmm_fp8_requirement`. This would make the API more robust. For example:
test_bmm_fp8skips all tests on SM110 that usee5m2dtypes. This suggests that thecudnnbackend forbmm_fp8might not supporte5m2on SM110. If this is the case, it would be better to add a check inside_cudnn_bmm_fp8_requirementto explicitly reject this combination and raise aValueError, for consistency with_cutlass_bmm_fp8_requirement. This would make the API more robust. For example: