Force FP8 gemms into F16 dot #394
base: rocm-jaxlib-v0.6.0
Conversation
i-chaochen left a comment
If these are gfx-related changes, please restrict them to the related gfx arch instead of a general skip.
XLA is supposed to support such a scenario, and it should not fall back to FP16 GEMM.
Hi @phambinhfin, to echo @ScXfjiang's question, I'd like to understand under what conditions the FP8 rewriter would not be able to rewrite a dot with FP8 inputs into a cublasLt custom call on MI35X. Similarly, for MI300, under what conditions would it not rewrite a dot with nanoo_fp8 inputs into the custom GEMM? In a properly installed ROCm environment, should nanoo_fp8/FP8 GEMMs be used?
There are multiple factors that decide whether a hipBLASLt FP8 GEMM custom call can be generated, e.g., data types, ROCm versions, and whether a specific pattern is triggered. You can check the main logic here: But if you only care about the results, you can enable this log:
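(The specific log flag isn't quoted above. As an alternative way to check the result, here is a minimal sketch that inspects the compiled HLO for a GEMM custom call. It assumes standard XLA dump flags and recent JAX FP8 dtypes; the exact dtype for nanoo FP8 on MI300 vs. OCP FP8 on MI355X, and whether this simple dot triggers the rewriter's pattern, are assumptions.)

```python
# Sketch: check whether an FP8 dot lowers to a hipBLASLt/cublasLt GEMM
# custom call or falls back to an F16 dot. Set XLA_FLAGS before importing jax.
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"

import jax
import jax.numpy as jnp

def fp8_matmul(a, b):
    # Cast inputs to FP8 and accumulate in a wider type, mimicking a quantized GEMM.
    a8 = a.astype(jnp.float8_e4m3fn)   # float8_e4m3fnuz may be needed for nanoo FP8 on MI300
    b8 = b.astype(jnp.float8_e4m3fn)
    return jnp.dot(a8, b8, preferred_element_type=jnp.float32)

x = jnp.ones((128, 128), jnp.float32)
y = jnp.ones((128, 128), jnp.float32)

# Inspect the compiled HLO (or the *after_optimizations* dump in /tmp/xla_dump)
# for a "custom-call" targeting the BLASLt matmul vs. a plain f16 dot.
print(jax.jit(fp8_matmul).lower(x, y).compile().as_text()[:2000])
```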
fdb0fdc to e11a81b
Adds a temporary workaround to disable FP8 GEMM operations on the gfx950 (MI355X) architecture, because FP8 operations combined with quantization produce NaN issues. While the root cause is being investigated and fixed, FP8 GEMMs are temporarily forced to fall back to FP16.
e11a81b to 9919fea
This PR fixes FP8 training issues on the gfx950 (MI350) architecture, where ROCm's FP8 implementation doesn't properly handle conversion operations. The current ROCm FP8 support skips conversion handling, causing NaNs during training. This fix ensures FP8 operations fall back to FP16 while waiting for XLA support of FP8 convert, improving training performance and avoiding the NaN issue. See the linked ticket.
(The NaN issue does not currently happen on MI300 because MI300 already handles the FP8 dot as an FP16 dot.)
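For reference, the effect of this fallback is roughly equivalent to the following JAX-level sketch (illustrative only, not the actual compiler change in this PR): upcast the FP8 operands to F16 and run a regular F16 dot, trading FP8 GEMM throughput for numerically stable results on gfx950.

```python
import jax.numpy as jnp

def fp8_dot_via_f16(a_fp8, b_fp8):
    # Upcast both FP8 operands to F16 before the dot so no FP8 GEMM
    # custom call is generated; accumulate in F32 for stability.
    a16 = a_fp8.astype(jnp.float16)
    b16 = b_fp8.astype(jnp.float16)
    return jnp.dot(a16, b16, preferred_element_type=jnp.float32)
```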