
Conversation

@phambinhfin

This PR fixes FP8 training issues on the gfx950 (MI350) architecture, where ROCm's FP8 implementation doesn't properly handle conversion operations. The current ROCm FP8 support skips conversion handling, which causes NaNs during training. This fix makes FP8 operations fall back to FP16 until XLA supports the FP8 convert, improving training stability and avoiding the NaN issue; see the linked ticket.
(The NaN issue does not happen on MI300 because MI300 handles the FP8 dot through FP16.)
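To make the intended behavior concrete, here is a minimal, self-contained C++ sketch of the fallback policy described above. This is not the PR's actual diff: the names (`PickGemmComputeType`, `GemmComputeType`) and the plain-string architecture check are assumptions for illustration only; the real change lives in XLA's ROCm FP8 handling.

```cpp
#include <iostream>
#include <string>

// Illustrative sketch only -- not the PR's diff. It models the policy from
// the description: on gfx950 (MI350), FP8 GEMMs currently produce NaNs when
// combined with quantization, so force the FP16 fallback until XLA supports
// the FP8 convert properly. All names here are hypothetical.
enum class GemmComputeType { kFp8, kFp16 };

GemmComputeType PickGemmComputeType(const std::string& gfx_arch) {
  if (gfx_arch == "gfx950") {
    // Temporary workaround: fall back to FP16 to avoid training NaNs.
    return GemmComputeType::kFp16;
  }
  // Other architectures (e.g., MI300/gfx942, which already handles the FP8
  // dot through FP16 and does not hit the NaN issue) keep their usual path.
  return GemmComputeType::kFp8;
}

int main() {
  for (const std::string arch : {"gfx942", "gfx950"}) {
    std::cout << arch << " -> "
              << (PickGemmComputeType(arch) == GemmComputeType::kFp16
                      ? "FP16 fallback"
                      : "FP8 GEMM")
              << "\n";
  }
}
```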

Collaborator

@i-chaochen left a comment


If these are gfx-specific changes, please check for the related gfx architectures instead of skipping in general.

@ScXfjiang

XLA is supposed to support such a scenario; it should not fall back to an FP16 GEMM.

@yeandy

yeandy commented Oct 6, 2025

Hi @phambinhfin, to echo @ScXfjiang's question, I'd like to understand under what conditions the FP8 rewriter would not be able to rewrite a dot with FP8 inputs into a cublasLt custom call on MI35X. And similarly, for MI300, under what conditions would it not rewrite a dot with nanoo_fp8 inputs into the custom GEMM?

In a properly installed ROCm environment, should nanoo_fp8/fp8 GEMMs get used?

@ScXfjiang

> Hi @phambinhfin, to echo @ScXfjiang's question, I'd like to understand under what conditions the FP8 rewriter would not be able to rewrite a dot with FP8 inputs into a cublasLt custom call on MI35X. And similarly, for MI300, under what conditions would it not rewrite a dot with nanoo_fp8 inputs into the custom GEMM?
>
> In a properly installed ROCm environment, should nanoo_fp8/fp8 GEMMs get used?

There are multiple factors that decide whether a hipBLASLt FP8 GEMM custom call can be generated, e.g., the data types, the ROCm version, and whether a specific pattern is matched.

You can check the main logic here:
https://github.com/openxla/xla/blob/e80be278d2b01f6b1b92102785b1f74ad10dfc92/xla/service/gpu/transforms/gemm_rewriter.cc#L1078

But if you only care about the results, you can enable this log:
https://github.com/openxla/xla/blob/e80be278d2b01f6b1b92102785b1f74ad10dfc92/xla/service/gpu/transforms/gemm_rewriter.cc#L1429
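(In case it helps others check this: VLOG output in XLA goes through the TSL logging setup, so, assuming a default build, it can usually be surfaced with the standard environment variables, e.g. `TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=gemm_rewriter=1` set before running the workload; the right verbosity number is whatever level the linked VLOG line uses.)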

    Adds a temporary workaround to disable FP8 GEMM operations on the
    gfx950 (MI355X) architecture, because FP8 operations combined with
    quantization produce NaN issues. While the root cause is being
    investigated, they are temporarily forced to fall back to FP16.