farzad-openai commented Oct 31, 2025

This PR improves the throughput of the mxfp upcast and downcast operations. I included a commit from @jongsoo-openai (original PR here) and added the improvements listed below on top of it. The PR is functionally a no-op, which is verified by the tests in python/triton_kernels/tests/test_mxfp.py.

Upcast improvements:

  • Added native packed e2m1 → fp16 conversion (for Blackwell+); see the decoding sketch after this list.
  • Added tensor descriptors to utilize TMA for reading the input mxfp value tensor and writing the output.
    • Note that this required padding the innermost dimension of I/O tensors that do not meet the tensor-descriptor requirements, and unpadding the output afterwards.
  • Tuned tile dimensions and num_warps.
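
For reference, here is a minimal host-side sketch of the packed e2m1 → fp16 value mapping. The names, the lookup-table approach, and the low-nibble-first ordering are illustrative assumptions, not taken from the kernel; on Blackwell+ the kernel uses the native packed conversion instead of a table.

```python
import torch

# The eight non-negative e2m1 magnitudes (1 sign, 2 exponent, 1 mantissa bit).
_E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Nibbles 0-7 are positive, 8-15 carry the sign bit.
_E2M1_LUT = torch.tensor(
    _E2M1_MAGNITUDES + [-m for m in _E2M1_MAGNITUDES], dtype=torch.float16
)

def unpack_e2m1_to_fp16(packed: torch.Tensor) -> torch.Tensor:
    """packed: uint8 tensor holding two e2m1 codes per byte (low nibble first, by assumption)."""
    lo = (packed & 0x0F).long()
    hi = ((packed >> 4) & 0x0F).long()
    # Decode both nibbles via the lookup table and interleave so the output
    # preserves the original element order.
    out = torch.stack([_E2M1_LUT[lo], _E2M1_LUT[hi]], dim=-1)
    return out.reshape(*packed.shape[:-1], packed.shape[-1] * 2)
```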

Downcast improvements:

  • Enabled grouped stores of mxfp4 value tensors instead of byte-level stores; see the packing sketch after this list.
  • Tuned the tile dimensions as well as num_warps.
  • Unlike the upcast path, tensor descriptors unfortunately did not yield a consistent performance improvement here.
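
To make the grouped-store idea concrete, here is a minimal host-side sketch. The function name, the low-nibble-first ordering, and the 4-byte group width are illustrative assumptions; the actual kernel performs the packing and the wide stores in Triton.

```python
import torch

def pack_e2m1_grouped(codes: torch.Tensor) -> torch.Tensor:
    """codes: uint8 tensor of 4-bit e2m1 codes (0..15); last dim divisible by 8."""
    lo = codes[..., 0::2]                    # even-indexed codes -> low nibbles (order is an assumption)
    hi = codes[..., 1::2]                    # odd-indexed codes  -> high nibbles
    packed = (lo | (hi << 4)).contiguous()   # two e2m1 codes per output byte
    # Reinterpret every run of 4 packed bytes as one 32-bit word so the result
    # can be written with a single 4-byte store per group instead of four
    # byte-level stores.
    return packed.view(torch.int32)
```

The reinterpretation at the end is only there to show the payoff: once neighboring packed bytes are kept together, they can be written with one wider store rather than several byte-sized ones.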

I left further performance tuning as a TODO for a subsequent PR.

Performance comparison (BW, in GB/s)

Done via python/triton_kernels/tests/test_mxfp.py.

Before -- GB200

MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1985.94             2053.35                2154.61               2347.56
4096  8192  torch.float8_e4m3fn              3479.79             3518.71                3243.02               3753.85

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8                808.089             815.124                647.589               713.9
4096  8192  torch.uint8               1045.23             1041.91                 811.089               888.624

After -- GB200

MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              2259.86             2404.99                2119.76               2361.66
4096  8192  torch.float8_e4m3fn              4106.69             4268.29                4038.16               4059

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8                1334.75             1332.03                1424.7                1397.36
4096  8192  torch.uint8                2027.41             2028.98                2097.15               2275.56

Before -- H100

MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1250.29             1244.35                1595.2                1588.75
4096  8192  torch.float8_e4m3fn              1805.81             1799.62                2080.51               2118.34

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8                418.493             416.102                572.367               627.739
4096  8192  torch.uint8                489.531             490.08                 687.861               758.08

After -- H100

MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1604.96             1624.86                1732.23               1751.52
4096  8192  torch.float8_e4m3fn              2347.56             2337.09                2386.74               2292.8

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8                731.429             745.575                892.861               917.871
4096  8192  torch.uint8                882.343             894.995               1102.37               1165.08

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because test_mxfp.py already has coverage.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

farzad-openai changed the title from "Mxfp conversions speedup" to "[MXFP] mxfp conversions speedup" on Nov 1, 2025
farzad-openai force-pushed the mxfp_conversions_speedup branch from 49d6715 to e5855e7 on November 4, 2025