diff --git a/mlx/backend/cpu/simd/math.h b/mlx/backend/cpu/simd/math.h
index f9fc8317a5..3670c5fc24 100644
--- a/mlx/backend/cpu/simd/math.h
+++ b/mlx/backend/cpu/simd/math.h
@@ -190,4 +190,56 @@ Simd<float, N> erfinv(Simd<float, N> a_) {
   }
 }
+/**
+ * Modified Bessel function of the first kind, order zero: I0(x).
+ * Cephes polynomial approximation in two domains:
+ *   |x| <= 3.75 → polynomial in (x/3.75)^2
+ *   |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
+ */
+template <int N>
+Simd<float, N> i0(Simd<float, N> x_) {
+  Simd<float, N> x = x_;
+  Simd<float, N> y = abs(x);
+
+  // Branch 1: y <= 3.75
+  auto small = [](Simd<float, N> y) {
+    Simd<float, N> t = y / 3.75f;
+    t = t * t;
+    // Horner evaluation of the inner polynomial
+    Simd<float, N> r(0.0045813f);
+    r = fma(r, t, Simd<float, N>(0.0360768f));
+    r = fma(r, t, Simd<float, N>(0.2659732f));
+    r = fma(r, t, Simd<float, N>(1.2067492f));
+    r = fma(r, t, Simd<float, N>(3.0899424f));
+    r = fma(r, t, Simd<float, N>(3.5156229f));
+    r = fma(r, t, Simd<float, N>(1.0f));
+    return r;
+  };
+
+  // Branch 2: y > 3.75
+  auto large = [](Simd<float, N> y) {
+    Simd<float, N> t = Simd<float, N>(3.75f) / y;
+    Simd<float, N> p(0.00392377f);
+    p = fma(p, t, Simd<float, N>(-0.01647633f));
+    p = fma(p, t, Simd<float, N>(0.02635537f));
+    p = fma(p, t, Simd<float, N>(-0.02057706f));
+    p = fma(p, t, Simd<float, N>(0.00916281f));
+    p = fma(p, t, Simd<float, N>(-0.00157565f));
+    p = fma(p, t, Simd<float, N>(0.00225319f));
+    p = fma(p, t, Simd<float, N>(0.01328592f));
+    p = fma(p, t, Simd<float, N>(0.39894228f));
+    return (exp(y) / sqrt(y)) * p;
+  };
+
+  if constexpr (N == 1) {
+    if ((y <= 3.75f).value) {
+      return Simd<float, N>(small(y));
+    } else {
+      return Simd<float, N>(large(y));
+    }
+  } else {
+    return Simd<float, N>(select(y <= 3.75f, small(y), large(y)));
+  }
+}
+
 
 } // namespace mlx::core::simd
diff --git a/mlx/backend/cpu/unary.cpp b/mlx/backend/cpu/unary.cpp
index eafe98866f..063cca67b2 100644
--- a/mlx/backend/cpu/unary.cpp
+++ b/mlx/backend/cpu/unary.cpp
@@ -103,6 +103,12 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
   unary_real_fp(in, out, detail::ErfInv(), stream());
 }
 
+void I0::eval_cpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  const auto& in = inputs[0];
+  unary_real_fp(in, out, detail::I0(), stream());
+}
+
 void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
   const auto& in = inputs[0];
diff --git a/mlx/backend/cpu/unary_ops.h b/mlx/backend/cpu/unary_ops.h
index f441e88bd5..411cef5976 100644
--- a/mlx/backend/cpu/unary_ops.h
+++ b/mlx/backend/cpu/unary_ops.h
@@ -44,6 +44,7 @@ DEFAULT_OP(ErfInv, erfinv)
 DEFAULT_OP(Exp, exp)
 DEFAULT_OP(Expm1, expm1)
 DEFAULT_OP(Floor, floor);
+DEFAULT_OP(I0, i0)
 DEFAULT_OP(Log, log);
 DEFAULT_OP(Log2, log2);
 DEFAULT_OP(Log10, log10);
diff --git a/mlx/backend/metal/kernels/i0.h b/mlx/backend/metal/kernels/i0.h
new file mode 100644
index 0000000000..fab1704cf8
--- /dev/null
+++ b/mlx/backend/metal/kernels/i0.h
@@ -0,0 +1,39 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+#include <metal_math>
+
+/*
+ * Modified Bessel function of the first kind, order zero: I0(x).
+ * Uses the Cephes polynomial approximation in two domains.
+ *
+ * Domain 1: |x| <= 3.75 → polynomial in (x/3.75)^2
+ * Domain 2: |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
+ *
+ * Reference: Cephes Math Library (netlib.org/cephes)
+ */
+float i0_impl(float x) {
+  float y = metal::abs(x);
+
+  if (y <= 3.75f) {
+    float t = y / 3.75f;
+    t = t * t;
+    return 1.0f +
+        t * (3.5156229f +
+             t * (3.0899424f +
+                  t * (1.2067492f +
+                       t * (0.2659732f + t * (0.0360768f + t * 0.0045813f)))));
+  } else {
+    float t = 3.75f / y;
+    float p = 0.00392377f;
+    p = metal::fma(p, t, -0.01647633f);
+    p = metal::fma(p, t, 0.02635537f);
+    p = metal::fma(p, t, -0.02057706f);
+    p = metal::fma(p, t, 0.00916281f);
+    p = metal::fma(p, t, -0.00157565f);
+    p = metal::fma(p, t, 0.00225319f);
+    p = metal::fma(p, t, 0.01328592f);
+    p = metal::fma(p, t, 0.39894228f);
+    return (metal::precise::exp(y) / metal::precise::sqrt(y)) * p;
+  }
+}
diff --git a/mlx/backend/metal/kernels/unary.metal b/mlx/backend/metal/kernels/unary.metal
index 54a0f566c8..f3abcc0c5d 100644
--- a/mlx/backend/metal/kernels/unary.metal
+++ b/mlx/backend/metal/kernels/unary.metal
@@ -67,6 +67,7 @@ instantiate_unary_types(Negative)
 instantiate_unary_float(Sigmoid)
 instantiate_unary_float(Erf)
 instantiate_unary_float(ErfInv)
+instantiate_unary_float(I0)
 instantiate_unary_types(Sign)
 instantiate_unary_float(Sin)
 instantiate_unary_float(Sinh)
diff --git a/mlx/backend/metal/kernels/unary_ops.h b/mlx/backend/metal/kernels/unary_ops.h
index 327bb5a940..3199576377 100644
--- a/mlx/backend/metal/kernels/unary_ops.h
+++ b/mlx/backend/metal/kernels/unary_ops.h
@@ -9,6 +9,7 @@
 #include "mlx/backend/metal/kernels/erf.h"
 #include "mlx/backend/metal/kernels/expm1f.h"
 #include "mlx/backend/metal/kernels/fp8.h"
+#include "mlx/backend/metal/kernels/i0.h"
 
 namespace {
 constant float inf = metal::numeric_limits<float>::infinity();
@@ -174,6 +175,13 @@ struct ErfInv {
   };
 };
 
+struct I0 {
+  template <typename T>
+  T operator()(T x) {
+    return static_cast<T>(i0_impl(static_cast<float>(x)));
+  };
+};
+
 struct Exp {
   template <typename T>
   T operator()(T x) {
diff --git a/mlx/backend/metal/unary.cpp b/mlx/backend/metal/unary.cpp
index 833b23f632..ccb9cc7d78 100644
--- a/mlx/backend/metal/unary.cpp
+++ b/mlx/backend/metal/unary.cpp
@@ -125,6 +125,7 @@ UNARY_GPU(Cos)
 UNARY_GPU(Cosh)
 UNARY_GPU(Erf)
 UNARY_GPU(ErfInv)
+UNARY_GPU(I0)
 UNARY_GPU(Exp)
 UNARY_GPU(Expm1)
 UNARY_GPU(Imag)
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index ae51dd9b2f..ec35c4455c 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -60,6 +60,7 @@ NO_CPU_MULTI(Eigh)
 NO_CPU(Equal)
 NO_CPU(Erf)
 NO_CPU(ErfInv)
+NO_CPU(I0)
 NO_CPU(Exp)
 NO_CPU(ExpandDims)
 NO_CPU(Expm1)
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp
index 4819ed2724..a6139dedee 100644
--- a/mlx/backend/no_gpu/primitives.cpp
+++ b/mlx/backend/no_gpu/primitives.cpp
@@ -87,6 +87,7 @@ NO_GPU(Remainder)
 NO_GPU(Equal)
 NO_GPU(Erf)
 NO_GPU(ErfInv)
+NO_GPU(I0)
 NO_GPU(Exp)
 NO_GPU(ExpandDims)
NO_GPU(Expm1) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index c7af8834fe..1f1f328dc1 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3011,6 +3011,35 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) { {astype(a, dtype, s)}); } +array i0(const array& a, StreamOrDevice s /* = {} */) { + auto dtype = at_least_float(a.dtype()); + return array( + a.shape(), + dtype, + std::make_shared(to_stream(s)), + {astype(a, dtype, s)}); +} + +array kaiser(int M, float beta, StreamOrDevice s /* = {} */) { + if (M < 1) { + return array({}); + } + if (M == 1) { + return ones({1}, float32, s); + } + + // w(n) = I0(beta * sqrt(1 - ((2n/(M-1)) - 1)^2)) / I0(beta) + auto n = arange(0, M, float32, s); + auto alpha = array((M - 1) / 2.0f, float32); + auto x = divide(subtract(n, alpha, s), alpha, s); // (2n/(M-1)) - 1 + auto arg = multiply( // beta * sqrt(1 - x^2) + array(beta, float32), + sqrt(subtract(array(1.0f, float32), square(x, s), s), s), + s); + auto denom = i0(array(beta, float32), s); + return divide(i0(arg, s), denom, s); +} + array stop_gradient(const array& a, StreamOrDevice s /* = {} */) { return array( a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); diff --git a/mlx/ops.h b/mlx/ops.h index 74032c01e0..73d6bbbf77 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -951,6 +951,12 @@ MLX_API array erf(const array& a, StreamOrDevice s = {}); /** Computes the inverse error function of the elements of an array. */ MLX_API array erfinv(const array& a, StreamOrDevice s = {}); +/** Computes the modified Bessel function of the first kind, order zero. */ +MLX_API array i0(const array& a, StreamOrDevice s = {}); + +/** Returns the Kaiser window of size M with shape parameter beta. */ +MLX_API array kaiser(int M, float beta, StreamOrDevice s = {}); + /** Computes the expm1 function of the elements of an array. 
*/ MLX_API array expm1(const array& a, StreamOrDevice s = {}); diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 92e54f9991..16520a2c32 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -1930,6 +1930,88 @@ std::pair, std::vector> ErfInv::vmap( return {{erfinv(inputs[0], stream())}, axes}; } +// --------------------------------------------------------------------------- +// Helper: compute I1(x) – modified Bessel function of the first kind, order 1 +// This is the derivative of I0 and is needed for I0 gradients. +// Cephes polynomial approximation (same two-domain split as I0). +// --------------------------------------------------------------------------- +static array i1_impl(const array& x, Stream s) { + auto dtype = x.dtype(); + auto y = abs(x, s); + auto t_small = square(divide(y, array(3.75f, dtype), s), s); // (y/3.75)^2 + // Horner evaluation for |x| <= 3.75: result = x * poly(t) + auto poly_small = [&](const array& t) -> array { + // coefficients [inner … outer] + static const float cs[] = { + 0.00032411f, + 0.00301532f, + 0.02658733f, + 0.15084934f, + 0.51498869f, + 0.87890594f, + 0.5f, + }; + array r(cs[0], dtype); + for (int i = 1; i < 7; ++i) { + r = add(multiply(r, t, s), array(cs[i], dtype), s); + } + return multiply(x, r, s); // I1 is odd: multiply by x (preserves sign) + }; + + auto t_large = divide(array(3.75f, dtype), y, s); // 3.75/|x| + // Horner evaluation for |x| > 3.75: result = + // sign(x)*exp(|x|)/sqrt(|x|)*poly(t) + auto poly_large = [&](const array& t) -> array { + static const float cl[] = { + -0.00420059f, + 0.01787654f, + -0.02895312f, + 0.02282967f, + -0.01031555f, + 0.00163801f, + -0.00362018f, + -0.03988024f, + 0.39894228f, + }; + array r(cl[0], dtype); + for (int i = 1; i < 9; ++i) { + r = add(multiply(r, t, s), array(cl[i], dtype), s); + } + auto env = divide(exp(y, s), sqrt(y, s), s); + auto mag = multiply(env, r, s); + // Restore sign: I1 is odd + return multiply(sign(x, s), mag, s); + }; + + auto mask = 
less_equal(y, array(3.75f, dtype), s); + return where(mask, poly_small(t_small), poly_large(t_large), s); +} + +std::vector I0::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + return jvp(primals, cotangents, argnums); +} + +std::vector I0::jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) { + assert(primals.size() == 1); + assert(argnums.size() == 1); + return {multiply(tangents[0], i1_impl(primals[0], stream()), stream())}; +} + +std::pair, std::vector> I0::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + return {{i0(inputs[0], stream())}, axes}; +} + std::vector Exp::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 4091aafcfb..2f9ff02712 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -1018,6 +1018,20 @@ class ErfInv : public UnaryPrimitive { DEFINE_INPUT_OUTPUT_SHAPE() }; +class I0 : public UnaryPrimitive { + public: + explicit I0(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(I0) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + class MLX_API Exp : public UnaryPrimitive { public: explicit Exp(Stream stream) : UnaryPrimitive(stream) {} diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a4ce55f8b3..1a2bd1ae57 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -919,6 +919,28 @@ void init_ops(nb::module_& m) { Returns: array: The inverse error function of ``a``. 
)pbdoc"); + m.def( + "i0", + [](const ScalarOrArray& a, mx::StreamOrDevice s) { + return mx::i0(to_array(a), s); + }, + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def i0(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Element-wise modified Bessel function of the first kind, order zero. + + .. math:: + I_0(x) = \sum_{k=0}^{\infty} \frac{(x/2)^{2k}}{(k!)^2} + + Args: + a (array): Input array. + + Returns: + array: The modified Bessel function :math:`I_0` evaluated element-wise on ``a``. + )pbdoc"); m.def( "sin", [](const ScalarOrArray& a, mx::StreamOrDevice s) { @@ -1520,6 +1542,31 @@ void init_ops(nb::module_& m) { array: The window, with the maximum value normalized to one (the value one appears only if the number of samples is odd). )pbdoc"); + m.def( + "kaiser", + &mlx::core::kaiser, + "M"_a, + "beta"_a, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def kaiser(M: int, beta: float, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Return the Kaiser window. + + The Kaiser window is a taper formed by using a Bessel function. + + .. math:: + w(n) = \frac{I_0\!\left(\beta\,\sqrt{1 - \!\left(\frac{2n}{M-1} - 1\right)^{\!2}}\right)}{I_0(\beta)} + \qquad 0 \le n \le M-1 + + Args: + M (int): Number of points in the output window. + beta (float): Shape parameter for the window. + + Returns: + array: The Kaiser window of length ``M`` with shape parameter ``beta``. 
+ )pbdoc"); m.def( "linspace", [](Scalar start, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index fd40f3b651..defc58dee2 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1498,6 +1498,106 @@ def test_blackman_general(self): self.assertEqual(a.size, 0) self.assertEqual(a.dtype, mx.float32) + def test_i0(self): + try: + from scipy.special import i0 as ref_i0 + except ImportError: + # numpy.i0 uses the same Cephes approximation + ref_i0 = np.i0 + + # Test both polynomial branches: |x| <= 3.75 and |x| > 3.75 + inputs = np.array([-10.0, -5.0, -3.75, -1.0, 0.0, 1.0, 3.75, 5.0, 10.0]) + x_mx = mx.array(inputs, dtype=mx.float32) + x_np = inputs.astype(np.float32) + + result = np.array(mx.i0(x_mx).tolist()) + expected = ref_i0(x_np).astype(np.float32) + + self.assertTrue( + np.allclose(result, expected, rtol=1e-4, atol=1e-4), + f"i0 mismatch:\n got: {result}\n expected: {expected}", + ) + + # Scalar input: I0(0) == 1 + self.assertAlmostEqual(float(mx.i0(mx.array(0.0))), 1.0, places=5) + + # Symmetry: I0 is even + self.assertTrue( + np.allclose( + np.array(mx.i0(x_mx).tolist()), + np.array(mx.i0(-x_mx).tolist()), + atol=1e-5, + ) + ) + + def test_i0_cpu_gpu_parity(self): + """CPU and GPU evaluations of i0 must agree.""" + inputs = np.linspace(-10, 10, 41, dtype=np.float32) + + # Force CPU execution + with mx.stream(mx.cpu): + x_cpu = mx.array(inputs) + cpu_result = np.array(mx.i0(x_cpu).tolist()) + + # Force GPU execution (default behavior on Apple Silicon, but good to be explicit) + with mx.stream(mx.gpu): + x_gpu = mx.array(inputs) + gpu_result = np.array(mx.i0(x_gpu).tolist()) + + self.assertTrue( + np.allclose(cpu_result, gpu_result, atol=1e-5), + "i0 CPU/GPU parity failed", + ) + + def test_i0_grad(self): + try: + from scipy.special import i1 as ref_i1 + except ImportError: + # SciPy is required to get the ground truth for I1 + self.skipTest("SciPy is required for i0 gradient tests") + + # Test values covering both small 
(<=3.75) and large (>3.75) domains, plus negative + inputs = np.array([-5.0, -1.0, 0.0, 1.0, 3.75, 5.0]) + x_mx = mx.array(inputs, dtype=mx.float32) + + # The derivative of sum(i0(x)) with respect to x is exactly i1(x) + def f(x): + return mx.sum(mx.i0(x)) + + grad_fn = mx.grad(f) + mlx_grad = np.array(grad_fn(x_mx).tolist()) + + expected_grad = ref_i1(inputs).astype(np.float32) + + self.assertTrue( + np.allclose(mlx_grad, expected_grad, rtol=1e-4, atol=1e-4), + f"i0 gradient mismatch:\n got: {mlx_grad}\n expected: {expected_grad}", + ) + + def test_kaiser_general(self): + # Basic comparison with numpy + for M, beta in [(10, 5.0), (20, 8.6), (5, 0.0), (11, 14.0)]: + a = mx.kaiser(M, beta) + expected = np.kaiser(M, beta) + self.assertTrue( + np.allclose(np.array(a.tolist()), expected, atol=1e-4, rtol=1e-4), + f"kaiser(M={M}, beta={beta}) mismatch:\n got {np.array(a.tolist())}\n expected {expected}", + ) + + # Edge cases + a = mx.kaiser(1, 5.0) + self.assertEqual(a.item(), 1.0) + + a = mx.kaiser(0, 5.0) + self.assertEqual(a.size, 0) + self.assertEqual(a.dtype, mx.float32) + + # Symmetry: Kaiser window is symmetric + a = np.array(mx.kaiser(21, 8.6).tolist()) + self.assertTrue( + np.allclose(a, a[::-1], atol=1e-5), "Kaiser window should be symmetric" + ) + def test_unary_ops(self): def test_ops(npop, mlxop, x, y, atol, rtol): r_np = npop(x)