@@ -2513,6 +2513,45 @@ def _check_mm_mxfp8_problem_size(
25132513 if b_descale .dtype != torch .uint8 :
25142514 raise ValueError (f"b_descale must be a uint8 tensor, got { b_descale .dtype = } " )
25152515
2516+ sf_vec_size = 32
2517+ if a_descale .ndim == 1 :
2518+ expected_len = _mxfp8_swizzled_scale_len (a .shape [0 ], a .shape [1 ])
2519+ if a_descale .shape [0 ] != expected_len :
2520+ raise ValueError (
2521+ "a_descale shape mismatch for swizzled layout. "
2522+ f"Expected { (expected_len ,)} , got { a_descale .shape } ."
2523+ )
2524+ elif a_descale .ndim == 2 :
2525+ expected_shape = (a .shape [0 ], a .shape [1 ] // sf_vec_size )
2526+ if a_descale .shape != expected_shape :
2527+ raise ValueError (
2528+ "a_descale shape mismatch for non-swizzled layout. "
2529+ f"Expected { expected_shape } , got { a_descale .shape } ."
2530+ )
2531+ else :
2532+ raise ValueError (
2533+ f"a_descale must be 1D (swizzled) or 2D (non-swizzled), got { a_descale .shape } ."
2534+ )
2535+
2536+ if b_descale .ndim == 1 :
2537+ expected_len = _mxfp8_swizzled_scale_len (b .shape [1 ], b .shape [0 ])
2538+ if b_descale .shape [0 ] != expected_len :
2539+ raise ValueError (
2540+ "b_descale shape mismatch for swizzled layout. "
2541+ f"Expected { (expected_len ,)} , got { b_descale .shape } ."
2542+ )
2543+ elif b_descale .ndim == 2 :
2544+ expected_shape = (b .shape [0 ] // sf_vec_size , b .shape [1 ])
2545+ if b_descale .shape != expected_shape :
2546+ raise ValueError (
2547+ "b_descale shape mismatch for non-swizzled layout. "
2548+ f"Expected { expected_shape } , got { b_descale .shape } ."
2549+ )
2550+ else :
2551+ raise ValueError (
2552+ f"b_descale must be 1D (swizzled) or 2D (non-swizzled), got { b_descale .shape } ."
2553+ )
2554+
25162555 if out is not None :
25172556 expected_shape = (a .shape [0 ], b .shape [1 ])
25182557 if out .shape != expected_shape :
0 commit comments