Add support for float8_e4m3fnuz and float8_e5m2fnuz. #3200

Closed
jakeh-gc wants to merge 1 commit into openxla:main from jakeh-gc:fp8_fnuz

@jakeh-gc
Contributor

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA, similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
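
For readers unfamiliar with the FNUZ variants, here is a minimal Python sketch of how a byte in these formats could be decoded. It is not taken from this PR. It assumes the commonly documented FNUZ conventions: no infinities, a single unsigned zero, the `0x80` bit pattern reserved for NaN, and an exponent bias of `2^(E-1)` (8 for `float8_e4m3fnuz`, 16 for `float8_e5m2fnuz`). The helper name `decode_fnuz` is hypothetical.

```python
import math


def decode_fnuz(byte: int, exp_bits: int, man_bits: int) -> float:
    """Decode one byte of an FNUZ-style 8-bit float (assumed layout, see above)."""
    assert 0 <= byte <= 0xFF and 1 + exp_bits + man_bits == 8
    if byte == 0x80:
        # The would-be "negative zero" pattern is the only NaN encoding.
        return math.nan
    bias = 1 << (exp_bits - 1)  # assumed: 8 for e4m3fnuz, 16 for e5m2fnuz
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    if exp == 0:
        # Subnormal (or zero): no implicit leading 1.
        return sign * man * 2.0 ** (1 - bias - man_bits)
    # Every nonzero exponent, including all-ones, encodes a finite value.
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


print(decode_fnuz(0x7F, 4, 3))  # largest finite e4m3fnuz value
print(decode_fnuz(0x7F, 5, 2))  # largest finite e5m2fnuz value
```

Under these assumptions the all-ones exponent is an ordinary finite binade, which is why the largest finite values differ from those of the non-FNUZ types.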

@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label May 25, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 25, 2023
@jakeh-gc
Contributor Author

I see where that test is failing... I'll get that fixed.

@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label May 25, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 25, 2023
@reedwm reedwm self-requested a review May 25, 2023 18:23
@reedwm
Contributor

reedwm commented May 25, 2023

@cantonios can you review the TSL changes (I'll also take a look at them). For some reason, I cannot assign you as a reviewer to this PR.

@burmako is there anything we need to do on the StableHLO side before merging this?

@cantonios
Contributor

> @cantonios can you review the TSL changes (I'll also take a look at them). For some reason, I cannot assign you as a reviewer to this PR.

Probably because I'm not part of the openxla team. But yeah, happy to.

Contributor

@cantonios cantonios left a comment

This change will likely need to wait until TensorFlow/TSL switches over to use ml_dtypes (as described in another comment). I'm in the process of doing this... my best estimate is within the next week or so. At that point, none of the TSL changes here will be necessary.

Comment thread third_party/tsl/tsl/platform/float8.h Outdated
Comment thread third_party/tsl/tsl/python/lib/core/custom_casts.cc Outdated
Comment thread third_party/tsl/tsl/python/lib/core/float8.cc Outdated
Comment thread third_party/tsl/tsl/python/lib/core/float8.h Outdated
@burmako

burmako commented May 25, 2023

@reedwm Thank you for reaching out! These types have gone through the StableHLO RFC process and are now part of the StableHLO spec, so I don't think anything further is needed on the StableHLO side.

Comment thread xla/translate/hlo_to_mhlo/hlo_utils.cc Outdated
Comment thread xla/translate/mhlo_to_hlo/type_to_shape.cc Outdated
@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label May 25, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 25, 2023
Contributor

@reedwm reedwm left a comment

Haven't reviewed elemental_ir_emitter.cc yet, I'll try to get to that tomorrow.

Comment thread xla/primitive_util.h Outdated
Comment thread xla/primitive_util_test.cc Outdated
Comment thread xla/util_test.cc Outdated
Comment thread xla/xla_data.proto Outdated
@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label May 26, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 26, 2023
@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label May 26, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label May 26, 2023

@burmako burmako left a comment

LGTM for MHLO/HLO parity

Contributor

@reedwm reedwm left a comment

Please add tests to convert_test.cc and constants_test.cc similar to existing F8 tests in those files.

Also note I didn't review TSL changes based on @cantonios's comments that the dependency on TSL float types would go away.

Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/service/elemental_ir_emitter.cc Outdated
Comment thread xla/primitive_util.h Outdated
Comment thread xla/util_test.cc Outdated
@github-actions github-actions Bot added the kokoro:force-run Forces CI to rerun label Jun 15, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Jun 15, 2023
copybara-service Bot pushed a commit to google/tsl that referenced this pull request Jun 27, 2023
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#3200 from jakeh-gc:fp8_fnuz 3b96f8fe219c1ac1bec5c4b99ff9c51148706981
PiperOrigin-RevId: 543802274
@reedwm
Contributor

reedwm commented Jun 27, 2023

@jakeh-gc, I'm still working on merging this. Please don't commit to the PR in the meantime, as it's hard to update the internal changes with the PR changes.

copybara-service Bot pushed a commit to google/tsl that referenced this pull request Jun 28, 2023
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#3200 from jakeh-gc:fp8_fnuz 3b96f8fe219c1ac1bec5c4b99ff9c51148706981
PiperOrigin-RevId: 543802274
copybara-service Bot pushed a commit to google/tsl that referenced this pull request Jun 28, 2023
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

PiperOrigin-RevId: 544198797
copybara-service Bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jun 29, 2023
Imported from GitHub PR openxla/xla#3200

This adds support for the two FP8 types `float8_e4m3fnuz` and `float8_e5m2fnuz` to XLA similar to `float8_e4m3fn`, `float8_e4m3b11`, and `float8_e5m2`.
Copybara import of the project:

--
3b96f8fe219c1ac1bec5c4b99ff9c51148706981 by Jake Hall <[email protected]>:

Add support for float8_e4m3fnuz and float8_e5m2fnuz.

Merging this change closes #3200

PiperOrigin-RevId: 544198797
copybara-service Bot pushed a commit that referenced this pull request Jun 29, 2023
FUTURE_COPYBARA_INTEGRATE_REVIEW=#3200 from jakeh-gc:fp8_fnuz 3b96f8f
PiperOrigin-RevId: 544197768
@copybara-service copybara-service Bot mentioned this pull request Jun 29, 2023
copybara-service Bot pushed a commit to google/tsl that referenced this pull request Sep 30, 2024
Imported from GitHub PR openxla/xla#16585

This PR adds support for the f8E4M3 and f8E3M4 types to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```
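
The encoding rules listed above can be checked with a small decoder. The following Python sketch is illustrative only and is not code from the PR; the function name `decode_f8e4m3` is hypothetical. It applies the bias of 7, treats the all-ones exponent as infinity or NaN, and handles subnormals without the implicit leading bit.

```python
import math


def decode_f8e4m3(byte: int) -> float:
    """Decode one byte laid out as S.EEEE.MMM with IEEE 754 special values."""
    assert 0 <= byte <= 0xFF
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF  # 4 stored exponent bits
    man = byte & 0x7         # 3 stored mantissa bits
    if exp == 0xF:
        # All-ones exponent: infinity if the mantissa is zero, otherwise NaN.
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        # Zero or subnormal: value is (man / 8) * 2^(-6) = man * 2^(-9).
        return sign * man * 2.0 ** -9
    # Normal number: implicit leading 1, exponent bias 7.
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


print(decode_f8e4m3(0b0_1110_111))  # max normal
print(decode_f8e4m3(0b0_0001_000))  # min normal
print(decode_f8e4m3(0b0_0000_111))  # max subnormal
print(decode_f8e4m3(0b0_0000_001))  # min subnormal
```

The four printed values correspond to the max/min normal and max/min subnormal entries in the list above.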

### `f8E3M4` type follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
```
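
The limits in both lists follow from the exponent and mantissa widths alone. As a rough cross-check (a sketch, not part of the PR; `ieee_limits` is a hypothetical helper), the Python snippet below derives them with exact rational arithmetic, assuming the usual IEEE 754 rules: bias `2^(E-1) - 1`, and the all-ones exponent reserved for infinities and NaNs.

```python
from fractions import Fraction


def ieee_limits(exp_bits: int, man_bits: int) -> dict:
    """Derive range limits of an IEEE 754-style format from its bit widths."""
    bias = (1 << (exp_bits - 1)) - 1
    max_stored = (1 << exp_bits) - 2   # all-ones exponent is reserved
    emax = max_stored - bias           # largest unbiased exponent
    emin = 1 - bias                    # smallest unbiased (normal) exponent
    ulp = Fraction(1, 1 << man_bits)   # weight of the last mantissa bit
    return {
        "bias": bias,
        "emax": emax,
        "emin": emin,
        "max_normal": Fraction(2) ** emax * (2 - ulp),
        "min_normal": Fraction(2) ** emin,
        "max_subnormal": Fraction(2) ** emin * (1 - ulp),
        "min_subnormal": Fraction(2) ** emin * ulp,
    }


for name, (e, m) in {"f8E4M3": (4, 3), "f8E3M4": (3, 4)}.items():
    print(name, ieee_limits(e, m))
```

Using `Fraction` keeps the results exact, so they can be compared directly against the values written out in the two lists.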

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](openxla/xla#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](openxla/xla#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723027012a816d7e17f268c5f034863696e6 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16585 from apivovarov:float8_e4m3 ec1c723027012a816d7e17f268c5f034863696e6
PiperOrigin-RevId: 680651037
Labels

ready to pull PR ready for merge process

6 participants