Skip to content

Commit 2dd82ea

Browse files
committed
Add MXFP8 scale checks in _check_mm_mxfp8_problem_size
1 parent b0455bf commit 2dd82ea

2 files changed

Lines changed: 40 additions & 1 deletion

File tree

flashinfer/gemm/gemm_base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2513,6 +2513,45 @@ def _check_mm_mxfp8_problem_size(
25132513
if b_descale.dtype != torch.uint8:
25142514
raise ValueError(f"b_descale must be a uint8 tensor, got {b_descale.dtype=}")
25152515

2516+
sf_vec_size = 32
2517+
if a_descale.ndim == 1:
2518+
expected_len = _mxfp8_swizzled_scale_len(a.shape[0], a.shape[1])
2519+
if a_descale.shape[0] != expected_len:
2520+
raise ValueError(
2521+
"a_descale shape mismatch for swizzled layout. "
2522+
f"Expected {(expected_len,)}, got {a_descale.shape}."
2523+
)
2524+
elif a_descale.ndim == 2:
2525+
expected_shape = (a.shape[0], a.shape[1] // sf_vec_size)
2526+
if a_descale.shape != expected_shape:
2527+
raise ValueError(
2528+
"a_descale shape mismatch for non-swizzled layout. "
2529+
f"Expected {expected_shape}, got {a_descale.shape}."
2530+
)
2531+
else:
2532+
raise ValueError(
2533+
f"a_descale must be 1D (swizzled) or 2D (non-swizzled), got {a_descale.shape}."
2534+
)
2535+
2536+
if b_descale.ndim == 1:
2537+
expected_len = _mxfp8_swizzled_scale_len(b.shape[1], b.shape[0])
2538+
if b_descale.shape[0] != expected_len:
2539+
raise ValueError(
2540+
"b_descale shape mismatch for swizzled layout. "
2541+
f"Expected {(expected_len,)}, got {b_descale.shape}."
2542+
)
2543+
elif b_descale.ndim == 2:
2544+
expected_shape = (b.shape[0] // sf_vec_size, b.shape[1])
2545+
if b_descale.shape != expected_shape:
2546+
raise ValueError(
2547+
"b_descale shape mismatch for non-swizzled layout. "
2548+
f"Expected {expected_shape}, got {b_descale.shape}."
2549+
)
2550+
else:
2551+
raise ValueError(
2552+
f"b_descale must be 1D (swizzled) or 2D (non-swizzled), got {b_descale.shape}."
2553+
)
2554+
25162555
if out is not None:
25172556
expected_shape = (a.shape[0], b.shape[1])
25182557
if out.shape != expected_shape:

tests/gemm/test_mm_mxfp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_mm_mxfp8_invalid_ndim():
214214
a_descale = a_scale.view(1, -1, 1)
215215
b_descale = b_scale.view(1, -1, 1)
216216
with pytest.raises(
217-
AssertionError, match="a_descale must be 1D \(swizzled\) or 2D \(non-swizzled\)"
217+
ValueError, match="a_descale must be 1D \(swizzled\) or 2D \(non-swizzled\)"
218218
):
219219
mm_mxfp8(
220220
a_mx,

0 commit comments

Comments
 (0)