diff --git a/CHANGELOG.md b/CHANGELOG.md index 9baeb540..e6c0eb72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Added new 8-bit float type following IEEE 754 convention: `ml_dtypes.float8_e4m3`. +* Fix outputs of float `divmod` and `floor_divide` when denominator is zero. ## [0.4.0] - 2024-04-1 diff --git a/ml_dtypes/_src/ufuncs.h b/ml_dtypes/_src/ufuncs.h index e3262091..ef6f07e1 100644 --- a/ml_dtypes/_src/ufuncs.h +++ b/ml_dtypes/_src/ufuncs.h @@ -168,7 +168,13 @@ struct TrueDivide { inline std::pair divmod(float a, float b) { if (b == 0.0f) { float nan = std::numeric_limits::quiet_NaN(); - return {nan, nan}; + float inf = std::numeric_limits::infinity(); + + if (std::isnan(a) || (a == 0.0f)) { + return {nan, nan}; + } else { + return {std::signbit(a) == std::signbit(b) ? inf : -inf, nan}; + } } float mod = std::fmod(a, b); float div = (a - mod) / b; diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index cd76b3ec..b94dc0da 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -830,6 +830,47 @@ def testDivmod(self, float_type): float_type=float_type, ) + @ignore_warning(category=RuntimeWarning, message="invalid value encountered") + @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") + def testDivmodCornerCases(self, float_type): + x = np.array( + [-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], + dtype=float_type, + ) + xf32 = x.astype("float32") + out = np.divmod.outer(x, x) + expected = np.divmod.outer(xf32, xf32) + numpy_assert_allclose( + out[0], + truncate(expected[0], float_type=float_type), + rtol=0.0, + float_type=float_type, + ) + numpy_assert_allclose( + out[1], + truncate(expected[1], float_type=float_type), + rtol=0.0, + float_type=float_type, + ) + + @ignore_warning(category=RuntimeWarning, message="invalid value encountered") + @ignore_warning(category=RuntimeWarning, message="divide by zero encountered") + def testFloordivCornerCases(self, float_type): + # Regression test for https://github.com/jax-ml/ml_dtypes/issues/170 + x = np.array( + [-np.nan, -np.inf, -1.0, -0.0, 0.0, 1.0, np.inf, np.nan], + dtype=float_type, + ) + xf32 = x.astype("float32") + out = np.floor_divide.outer(x, x) + expected = np.floor_divide.outer(xf32, xf32) + numpy_assert_allclose( + out, + truncate(expected, float_type=float_type), + rtol=0.0, + float_type=float_type, + ) + def testModf(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type)