Skip to content

Commit 9f1813d

Browse files
committed
Fix issues in testMmMxfp8
Signed-off-by: Daniel Serebrenik <[email protected]>
1 parent a647f59 commit 9f1813d

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

benchmarks/routines/gemm.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,9 +1267,6 @@ def testMmMxfp8(args):
12671267

12681268
## Parse input arguments
12691269
backends = args.backends
1270-
if "cutlass" not in backends:
1271-
# Only cutlass is supported for mm_mxfp8
1272-
backends = ["cutlass"]
12731270
m = args.m
12741271
n = args.n
12751272
k = args.k
@@ -1308,11 +1305,17 @@ def testMmMxfp8(args):
13081305

13091306
## Prepare input tensors
13101307
# Use swizzled layout for optimal performance
1308+
is_sf_swizzled_layout = True
1309+
13111310
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
1312-
input_mxfp8, input_scale = mxfp8_quantize(input, is_sf_swizzled_layout=True)
1311+
input_mxfp8, input_scale = mxfp8_quantize(
1312+
input, is_sf_swizzled_layout=is_sf_swizzled_layout
1313+
)
13131314

13141315
mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
1315-
mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout=True)
1316+
mat2_mxfp8, mat2_scale = mxfp8_quantize(
1317+
mat2, is_sf_swizzled_layout=is_sf_swizzled_layout
1318+
)
13161319

13171320
if args.verbose >= 2:
13181321
print(f"[VVERBOSE] {input_mxfp8.shape = }")
@@ -1379,7 +1382,8 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
13791382
input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
13801383
)
13811384

1382-
min_cos_sim = 0.9
1385+
# Minimum cosine similarity for swizzled layout
1386+
min_cos_sim = 0.98
13831387

13841388
tested_backends = list(outputs.keys())
13851389
tested_outputs = list(outputs.values())
@@ -1393,13 +1397,15 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
13931397
)
13941398
if cos_sim < min_cos_sim:
13951399
print(
1396-
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
1400+
"[ERROR] Output tensor mismatch between reference "
1401+
f"{tested_backends[0]} and backend {tested_backends[i]}"
13971402
)
13981403
if not args.allow_output_mismatch:
13991404
raise AssertionError(
1400-
f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}"
1405+
"[ERROR] Output tensor mismatch between reference "
1406+
f"{tested_backends[0]} and backend {tested_backends[i]} "
1407+
f"with {cos_sim=} (expected >= {min_cos_sim})"
14011408
)
1402-
14031409
for backend in backends:
14041410
backend_name = backend + (
14051411
"_autotune"

0 commit comments

Comments
 (0)