Add support for float8_e4m3fnuz and float8_e5m2fnuz. #3200
I see where that test is failing... I'll get that fixed.
@cantonios can you review the TSL changes? (I'll also take a look at them.) For some reason, I cannot assign you as a reviewer to this PR. @burmako is there anything we need to do on the StableHLO side before merging this?
Probably because I'm not part of the openxla team. But yeah, happy to.
cantonios
left a comment
This change will likely need to wait until TensorFlow/TSL switches over to use ml_dtypes (as described in another comment). I'm in the process of doing this... my best estimate is within the next week or so. At that point, none of the TSL changes here will be necessary.
@reedwm Thank you for reaching out! These types have gone through the StableHLO RFC process and are now part of the StableHLO spec, so I don't think anything further is needed on the StableHLO side.
reedwm
left a comment
Haven't reviewed elemental_ir_emitter.cc yet, I'll try to get to that tomorrow.
reedwm
left a comment
Please add tests to convert_test.cc and constants_test.cc similar to existing F8 tests in those files.
Also note I didn't review TSL changes based on @cantonios's comments that the dependency on TSL float types would go away.
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.

Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#3200 from jakeh-gc:fp8_fnuz 3b96f8fe219c1ac1bec5c4b99ff9c51148706981
PiperOrigin-RevId: 543802274
@jakeh-gc, I'm still working on merging this. Please don't commit to the PR in the meantime, as it's hard to update the internal changes with the PR changes.
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.

Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

PiperOrigin-RevId: 544198797
Imported from GitHub PR openxla/xla#16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:

```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:

- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)

Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
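The bit layouts quoted in the f8E4M3 and f8E3M4 commit message can be checked with a small decoder. What follows is only an illustrative Python sketch, not XLA or ml_dtypes code: the function name `decode_ieee_fp8` and its parameters are hypothetical, and it assumes the plain IEEE 754 rules listed there (bias of 2^(E-1) - 1, all-ones exponent reserved for infinities and NaNs, zero exponent for zeros and subnormals).

```python
import math


def decode_ieee_fp8(byte: int, exp_bits: int, man_bits: int) -> float:
    """Decode one byte as a 1-sign/exp_bits/man_bits IEEE-754-style float.

    Hypothetical helper for illustration; not part of XLA or ml_dtypes.
    """
    assert exp_bits + man_bits == 7 and 0 <= byte <= 0xFF
    bias = (1 << (exp_bits - 1)) - 1          # 7 for E4M3, 3 for E3M4
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)

    if exp == (1 << exp_bits) - 1:            # all-ones exponent: Inf or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:                              # zero or subnormal: no implicit 1
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


# Values taken from the "Additional details" lists in the commit message.
print(decode_ieee_fp8(0b0_1110_111, 4, 3))   # max normal f8E4M3: 240.0
print(decode_ieee_fp8(0b0_110_1111, 3, 4))   # max normal f8E3M4: 15.5
print(decode_ieee_fp8(0b0_0000_001, 4, 3))   # min subnormal f8E4M3: 2^(-9)
```

Passing the exponent and mantissa widths as parameters keeps one routine usable for both formats, since the two differ only in how the seven non-sign bits are split.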
This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
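For readers unfamiliar with the `fnuz` suffix, here is a rough Python sketch of how such a byte decodes. The thread itself does not spell out the encoding, so the rules below are an assumption based on the usual reading of FNUZ (finite values only, a single NaN stored in the bit pattern that would otherwise be negative zero, and an exponent bias one larger than the IEEE-style bias: 8 for e4m3fnuz, 16 for e5m2fnuz). The helper name `decode_fnuz_fp8` is hypothetical and is not an XLA or ml_dtypes API.

```python
import math


def decode_fnuz_fp8(byte: int, exp_bits: int, man_bits: int) -> float:
    """Decode one byte under the assumed FNUZ rules (illustrative only)."""
    assert exp_bits + man_bits == 7 and 0 <= byte <= 0xFF
    if byte == 0x80:                          # assumed: sole NaN, no negative zero
        return math.nan
    bias = 1 << (exp_bits - 1)                # assumed: 8 for e4m3fnuz, 16 for e5m2fnuz
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)

    if exp == 0:                              # zero or subnormal
        return sign * man * 2.0 ** (1 - bias - man_bits)
    # No exponent value is reserved for infinities, so every other
    # bit pattern is an ordinary normal number.
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


print(decode_fnuz_fp8(0x7F, 4, 3))   # largest e4m3fnuz value under these rules
print(decode_fnuz_fp8(0x7F, 5, 2))   # largest e5m2fnuz value under these rules
print(decode_fnuz_fp8(0x40, 4, 3))   # 1.0
```

Under these assumptions the all-ones exponent is an ordinary exponent, which is why the largest finite values sit at byte `0x7F` rather than one exponent step lower.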