diff --git a/third_party/py/ml_dtypes/workspace.bzl b/third_party/py/ml_dtypes/workspace.bzl index 0047319ec..962fb487c 100644 --- a/third_party/py/ml_dtypes/workspace.bzl +++ b/third_party/py/ml_dtypes/workspace.bzl @@ -7,8 +7,8 @@ float8 varieties, and int4. load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - ML_DTYPES_COMMIT = "215c9f02a121e6286662b2efd30546c71054d5e5" - ML_DTYPES_SHA256 = "4a03237ef6345e1467a33d126176b9c6a7539b0f60a34b344f39b3c9e8b82438" + ML_DTYPES_COMMIT = "0fa5313b65efe848c5968a15dd37dd220cc29567" + ML_DTYPES_SHA256 = "69c562bb961a21d92357c7709430553c226caac75a751c0aa52955ca14ce8641" tf_http_archive( name = "ml_dtypes", build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", diff --git a/tsl/platform/BUILD b/tsl/platform/BUILD index 8355483dc..09adde466 100644 --- a/tsl/platform/BUILD +++ b/tsl/platform/BUILD @@ -984,6 +984,7 @@ cc_library( deps = [ "@ml_dtypes//:float8", "@ml_dtypes//:intn", + "@ml_dtypes//:mxfloat", ], ) diff --git a/tsl/platform/ml_dtypes.h b/tsl/platform/ml_dtypes.h index a6a1b56af..a03fa0244 100644 --- a/tsl/platform/ml_dtypes.h +++ b/tsl/platform/ml_dtypes.h @@ -18,8 +18,10 @@ limitations under the License. #include "ml_dtypes/include/float8.h" // from @ml_dtypes #include "ml_dtypes/include/intn.h" // from @ml_dtypes +#include "ml_dtypes/include/mxfloat.h" // from @ml_dtypes namespace tsl { +using float4_e2m1fn = ::ml_dtypes::float4_e2m1fn; using float8_e3m4 = ::ml_dtypes::float8_e3m4; using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; @@ -27,6 +29,7 @@ using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; +using float8_e8m0fnu = ::ml_dtypes::float8_e8m0fnu; using int1 = ::ml_dtypes::int1; using uint1 = ::ml_dtypes::uint1;