@@ -1267,9 +1267,6 @@ def testMmMxfp8(args):
12671267
12681268 ## Parse input arguments
12691269 backends = args .backends
1270- if "cutlass" not in backends :
1271- # Only cutlass is supported for mm_mxfp8
1272- backends = ["cutlass" ]
12731270 m = args .m
12741271 n = args .n
12751272 k = args .k
@@ -1308,11 +1305,17 @@ def testMmMxfp8(args):
13081305
13091306 ## Prepare input tensors
13101307 # Use swizzled layout for optimal performance
1308+ is_sf_swizzled_layout = True
1309+
13111310 input = torch .randn ([m , k ], device = device , dtype = torch .bfloat16 )
1312- input_mxfp8 , input_scale = mxfp8_quantize (input , is_sf_swizzled_layout = True )
1311+ input_mxfp8 , input_scale = mxfp8_quantize (
1312+ input , is_sf_swizzled_layout = is_sf_swizzled_layout
1313+ )
13131314
13141315 mat2 = torch .randn ([n , k ], device = device , dtype = torch .bfloat16 )
1315- mat2_mxfp8 , mat2_scale = mxfp8_quantize (mat2 , is_sf_swizzled_layout = True )
1316+ mat2_mxfp8 , mat2_scale = mxfp8_quantize (
1317+ mat2 , is_sf_swizzled_layout = is_sf_swizzled_layout
1318+ )
13161319
13171320 if args .verbose >= 2 :
13181321 print (f"[VVERBOSE] { input_mxfp8 .shape = } " )
@@ -1379,7 +1382,8 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
13791382 input_args = (cur_backend , input_mxfp8 , mat2_mxfp8 , input_scale , mat2_scale ),
13801383 )
13811384
1382- min_cos_sim = 0.9
1385+ # Minimum cosine similarity for swizzled layout
1386+ min_cos_sim = 0.98
13831387
13841388 tested_backends = list (outputs .keys ())
13851389 tested_outputs = list (outputs .values ())
@@ -1393,13 +1397,15 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
13931397 )
13941398 if cos_sim < min_cos_sim :
13951399 print (
1396- f"[ERROR] Output tensor mismatch between backends { tested_backends [0 ]} and { tested_backends [i ]} "
1400+ "[ERROR] Output tensor mismatch between reference "
1401+ f"{ tested_backends [0 ]} and backend { tested_backends [i ]} "
13971402 )
13981403 if not args .allow_output_mismatch :
13991404 raise AssertionError (
1400- f"[ERROR] Backend { tested_backends [i ]} output mismatch with cos_sim={ cos_sim } "
1405+ "[ERROR] Output tensor mismatch between reference "
1406+ f"{ tested_backends [0 ]} and backend { tested_backends [i ]} "
1407+ f"with { cos_sim = } (expected >= { min_cos_sim } )"
14011408 )
1402-
14031409 for backend in backends :
14041410 backend_name = backend + (
14051411 "_autotune"
0 commit comments