Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mlx/backend/cpu/simd/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,58 @@ Simd<T, N> erfinv(Simd<T, N> a_) {
}
}

/**
 * Modified Bessel function of the first kind, order zero: I0(x).
 *
 * Evaluated in float precision with the polynomial approximations of
 * Abramowitz & Stegun (Handbook of Mathematical Functions, 9.8.1 and 9.8.2)
 * in two domains:
 * |x| <= 3.75 → polynomial in (x/3.75)^2
 * |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
 *
 * I0 is an even function, so only |x| enters the computation. The float
 * result is converted back to Simd<T, N> on return.
 */
template <typename T, int N>
Simd<T, N> i0(Simd<T, N> x_) {
  Simd<float, N> x = x_;
  Simd<float, N> y = abs(x);

  // Branch 1: y <= 3.75
  auto small = [](Simd<float, N> y) {
    Simd<float, N> t = y / 3.75f;
    t = t * t;
    // Horner evaluation in t, highest-order coefficient first.
    Simd<float, N> r(0.0045813f);
    r = fma(r, t, Simd<float, N>(0.0360768f));
    r = fma(r, t, Simd<float, N>(0.2659732f));
    r = fma(r, t, Simd<float, N>(1.2067492f));
    r = fma(r, t, Simd<float, N>(3.0899424f));
    r = fma(r, t, Simd<float, N>(3.5156229f));
    r = fma(r, t, Simd<float, N>(1.0f));
    return r;
  };

  // Branch 2: y > 3.75
  auto large = [](Simd<float, N> y) {
    Simd<float, N> t = Simd<float, N>(3.75f) / y;
    // Horner evaluation in t, highest-order coefficient first.
    Simd<float, N> p(0.00392377f);
    p = fma(p, t, Simd<float, N>(-0.01647633f));
    p = fma(p, t, Simd<float, N>(0.02635537f));
    p = fma(p, t, Simd<float, N>(-0.02057706f));
    p = fma(p, t, Simd<float, N>(0.00916281f));
    p = fma(p, t, Simd<float, N>(-0.00157565f));
    p = fma(p, t, Simd<float, N>(0.00225319f));
    p = fma(p, t, Simd<float, N>(0.01328592f));
    p = fma(p, t, Simd<float, N>(0.39894228f));
    return (exp(y) / sqrt(y)) * p;
  };

  if constexpr (N == 1) {
    // Scalar case: only the branch that applies is evaluated.
    if ((y <= 3.75f).value) {
      return Simd<T, N>(small(y));
    } else {
      return Simd<T, N>(large(y));
    }
  } else {
    // Vector case: both branches are evaluated on the full vector and the
    // comparison mask picks between them.
    return Simd<T, N>(select(y <= 3.75f, small(y), large(y)));
  }
}

} // namespace mlx::core::simd
6 changes: 6 additions & 0 deletions mlx/backend/cpu/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ void ErfInv::eval_cpu(const std::vector<array>& inputs, array& out) {
unary_real_fp(in, out, detail::ErfInv(), stream());
}

void I0::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
unary_real_fp(in, out, detail::I0(), stream());
}

void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
const auto& in = inputs[0];
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/cpu/unary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
DEFAULT_OP(I0, i0)
DEFAULT_OP(Log, log);
DEFAULT_OP(Log2, log2);
DEFAULT_OP(Log10, log10);
Expand Down
39 changes: 39 additions & 0 deletions mlx/backend/metal/kernels/i0.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright © 2025 Apple Inc.

#pragma once
#include <metal_math>

/*
 * Modified Bessel function of the first kind, order zero: I0(x).
 * Uses the polynomial approximations of Abramowitz & Stegun (formulas 9.8.1
 * and 9.8.2) in two domains.
 *
 * Domain 1: |x| <= 3.75 → polynomial in (x/3.75)^2
 * Domain 2: |x| > 3.75 → exp(|x|) / sqrt(|x|) * polynomial in (3.75/|x|)
 *
 * Reference: Abramowitz & Stegun, Handbook of Mathematical Functions, 9.8
 */
float i0_impl(float x) {
  // I0 is even, so only |x| is needed.
  float y = metal::abs(x);

  if (y <= 3.75f) {
    float t = y / 3.75f;
    t = t * t;
    // Horner evaluation in t, highest-order coefficient first.
    float p = 0.0045813f;
    p = metal::fma(p, t, 0.0360768f);
    p = metal::fma(p, t, 0.2659732f);
    p = metal::fma(p, t, 1.2067492f);
    p = metal::fma(p, t, 3.0899424f);
    p = metal::fma(p, t, 3.5156229f);
    p = metal::fma(p, t, 1.0f);
    return p;
  } else {
    float t = 3.75f / y;
    // Horner evaluation in t, highest-order coefficient first.
    float p = 0.00392377f;
    p = metal::fma(p, t, -0.01647633f);
    p = metal::fma(p, t, 0.02635537f);
    p = metal::fma(p, t, -0.02057706f);
    p = metal::fma(p, t, 0.00916281f);
    p = metal::fma(p, t, -0.00157565f);
    p = metal::fma(p, t, 0.00225319f);
    p = metal::fma(p, t, 0.01328592f);
    p = metal::fma(p, t, 0.39894228f);
    return (metal::precise::exp(y) / metal::precise::sqrt(y)) * p;
  }
}
1 change: 1 addition & 0 deletions mlx/backend/metal/kernels/unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ instantiate_unary_types(Negative)
instantiate_unary_float(Sigmoid)
instantiate_unary_float(Erf)
instantiate_unary_float(ErfInv)
instantiate_unary_float(I0)
instantiate_unary_types(Sign)
instantiate_unary_float(Sin)
instantiate_unary_float(Sinh)
Expand Down
8 changes: 8 additions & 0 deletions mlx/backend/metal/kernels/unary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlx/backend/metal/kernels/erf.h"
#include "mlx/backend/metal/kernels/expm1f.h"
#include "mlx/backend/metal/kernels/fp8.h"
#include "mlx/backend/metal/kernels/i0.h"

namespace {
constant float inf = metal::numeric_limits<float>::infinity();
Expand Down Expand Up @@ -174,6 +175,13 @@ struct ErfInv {
};
};

struct I0 {
  // Evaluates I0 in float precision and converts the result back to T.
  template <typename T>
  T operator()(T x) {
    const float arg = static_cast<float>(x);
    const float value = i0_impl(arg);
    return static_cast<T>(value);
  };
};

struct Exp {
template <typename T>
T operator()(T x) {
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/metal/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ UNARY_GPU(Cos)
UNARY_GPU(Cosh)
UNARY_GPU(Erf)
UNARY_GPU(ErfInv)
UNARY_GPU(I0)
UNARY_GPU(Exp)
UNARY_GPU(Expm1)
UNARY_GPU(Imag)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_cpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ NO_CPU_MULTI(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
NO_CPU(I0)
NO_CPU(Exp)
NO_CPU(ExpandDims)
NO_CPU(Expm1)
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ NO_GPU(Remainder)
NO_GPU(Equal)
NO_GPU(Erf)
NO_GPU(ErfInv)
NO_GPU(I0)
NO_GPU(Exp)
NO_GPU(ExpandDims)
NO_GPU(Expm1)
Expand Down
29 changes: 29 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3011,6 +3011,35 @@ array erfinv(const array& a, StreamOrDevice s /* = {} */) {
{astype(a, dtype, s)});
}

array i0(const array& a, StreamOrDevice s /* = {} */) {
  // Non-floating inputs are promoted via at_least_float; the output keeps the
  // input's shape and uses the promoted dtype.
  const auto out_dtype = at_least_float(a.dtype());
  auto promoted = astype(a, out_dtype, s);
  auto primitive = std::make_shared<I0>(to_stream(s));
  return array(a.shape(), out_dtype, primitive, {promoted});
}

array kaiser(int M, float beta, StreamOrDevice s /* = {} */) {
  // Degenerate sizes are handled up front: M < 1 returns the array built from
  // an empty initializer list, and M == 1 returns a single 1.
  if (M < 1) {
    return array({});
  }
  if (M == 1) {
    return ones({1}, float32, s);
  }

  // w(n) = I0(beta * sqrt(1 - ((2n/(M-1)) - 1)^2)) / I0(beta)
  const array one(1.0f, float32);
  const array beta_arr(beta, float32);
  const array half_span((M - 1) / 2.0f, float32);

  auto idx = arange(0, M, float32, s);
  // (n - (M-1)/2) / ((M-1)/2) == (2n/(M-1)) - 1
  auto centered = divide(subtract(idx, half_span, s), half_span, s);
  auto radicand = subtract(one, square(centered, s), s); // 1 - x^2
  auto arg = multiply(beta_arr, sqrt(radicand, s), s);
  auto denom = i0(beta_arr, s);
  return divide(i0(arg, s), denom, s);
}

array stop_gradient(const array& a, StreamOrDevice s /* = {} */) {
return array(
a.shape(), a.dtype(), std::make_shared<StopGradient>(to_stream(s)), {a});
Expand Down
6 changes: 6 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,12 @@ MLX_API array erf(const array& a, StreamOrDevice s = {});
/** Computes the inverse error function of the elements of an array. */
MLX_API array erfinv(const array& a, StreamOrDevice s = {});

/** Computes the modified Bessel function of the first kind, order zero. */
MLX_API array i0(const array& a, StreamOrDevice s = {});

/** Returns the Kaiser window of size M with shape parameter beta. */
MLX_API array kaiser(int M, float beta, StreamOrDevice s = {});

/** Computes the expm1 function of the elements of an array. */
MLX_API array expm1(const array& a, StreamOrDevice s = {});

Expand Down
82 changes: 82 additions & 0 deletions mlx/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,88 @@ std::pair<std::vector<array>, std::vector<int>> ErfInv::vmap(
return {{erfinv(inputs[0], stream())}, axes};
}

// ---------------------------------------------------------------------------
// Helper: compute I1(x) – modified Bessel function of the first kind, order 1
// This is the derivative of I0 and is needed for I0 gradients.
// Polynomial approximations from Abramowitz & Stegun 9.8.3 / 9.8.4 (same
// two-domain split at |x| = 3.75 as I0).
//
// Both branches are built for every element and combined with `where`. Each
// branch is fed an argument clamped into its own domain so that the branch
// that is NOT selected stays finite (e.g. no 3.75/0 at x == 0). Otherwise
// differentiating this helper again would multiply a zero cotangent by
// inf/NaN and produce NaN.
// ---------------------------------------------------------------------------
static array i1_impl(const array& x, Stream s) {
  auto dtype = x.dtype();
  auto y = abs(x, s);
  auto cutoff = array(3.75f, dtype);
  auto mask = less_equal(y, cutoff, s); // true where |x| <= 3.75

  // Domain-clamped arguments: identical to y where the branch is selected,
  // pinned to the cutoff where it is not.
  auto y_small = where(mask, y, cutoff, s);
  auto y_large = where(mask, cutoff, y, s);

  // Horner evaluation of sum_i coeffs[i] * t^(count-1-i), i.e. the
  // highest-order coefficient comes first.
  auto horner = [&](const float* coeffs, int count, const array& t) -> array {
    array r(coeffs[0], dtype);
    for (int i = 1; i < count; ++i) {
      r = add(multiply(r, t, s), array(coeffs[i], dtype), s);
    }
    return r;
  };

  // |x| <= 3.75: I1(x) = x * poly((x/3.75)^2)
  static constexpr float cs[] = {
      0.00032411f,
      0.00301532f,
      0.02658733f,
      0.15084934f,
      0.51498869f,
      0.87890594f,
      0.5f,
  };
  constexpr int n_small = static_cast<int>(sizeof(cs) / sizeof(cs[0]));
  auto t_small = square(divide(y_small, cutoff, s), s); // (y/3.75)^2
  // I1 is odd: multiplying by x (not |x|) preserves the sign.
  auto small = multiply(x, horner(cs, n_small, t_small), s);

  // |x| > 3.75: I1(x) = sign(x) * exp(|x|) / sqrt(|x|) * poly(3.75/|x|)
  static constexpr float cl[] = {
      -0.00420059f,
      0.01787654f,
      -0.02895312f,
      0.02282967f,
      -0.01031555f,
      0.00163801f,
      -0.00362018f,
      -0.03988024f,
      0.39894228f,
  };
  constexpr int n_large = static_cast<int>(sizeof(cl) / sizeof(cl[0]));
  auto t_large = divide(cutoff, y_large, s); // 3.75/|x|
  auto env = divide(exp(y_large, s), sqrt(y_large, s), s);
  auto mag = multiply(env, horner(cl, n_large, t_large), s);
  // Restore sign: I1 is odd
  auto large = multiply(sign(x, s), mag, s);

  return where(mask, small, large, s);
}

std::vector<array> I0::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  // The VJP reuses the JVP rule, with the cotangents standing in for the
  // tangents; the primitive's outputs (last argument) are not needed.
  auto grads = jvp(primals, cotangents, argnums);
  return grads;
}

std::vector<array> I0::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  assert(primals.size() == 1);
  assert(argnums.size() == 1);
  // d/dx I0(x) = I1(x), so the tangent is scaled element-wise by I1(x).
  auto derivative = i1_impl(primals[0], stream());
  return {multiply(tangents[0], derivative, stream())};
}

std::pair<std::vector<array>, std::vector<int>> I0::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  assert(inputs.size() == 1);
  assert(axes.size() == 1);
  // Applying i0 to the batched input is enough; the mapped axes are returned
  // unchanged.
  auto mapped = i0(inputs[0], stream());
  return {{mapped}, axes};
}

std::vector<array> Exp::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
Expand Down
14 changes: 14 additions & 0 deletions mlx/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,20 @@ class ErfInv : public UnaryPrimitive {
DEFINE_INPUT_OUTPUT_SHAPE()
};

// Unary primitive for I0, the modified Bessel function of the first kind,
// order zero.
class I0 : public UnaryPrimitive {
public:
// Binds the primitive to the stream it will be evaluated on.
explicit I0(Stream stream) : UnaryPrimitive(stream) {}

// Backend entry points; each takes the input list and writes into `out`.
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

// The macros below come from the project; their expansions are not visible
// here. NOTE(review): presumably DEFINE_VMAP declares vmap() and
// DEFINE_GRADS declares vjp()/jvp(), matching the I0::vmap / I0::vjp /
// I0::jvp definitions in primitives.cpp — confirm against the macro
// definitions.
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_NAME(I0)
DEFINE_DEFAULT_IS_EQUIVALENT()
DEFINE_INPUT_OUTPUT_SHAPE()
};

class MLX_API Exp : public UnaryPrimitive {
public:
explicit Exp(Stream stream) : UnaryPrimitive(stream) {}
Expand Down
47 changes: 47 additions & 0 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,28 @@ void init_ops(nb::module_& m) {
Returns:
array: The inverse error function of ``a``.
)pbdoc");
m.def(
    "i0",
    // Scalars are converted to arrays before dispatching to mx::i0.
    [](const ScalarOrArray& x, mx::StreamOrDevice stream) {
      auto input = to_array(x);
      return mx::i0(input, stream);
    },
    nb::arg(),
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def i0(a: array, /, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
Element-wise modified Bessel function of the first kind, order zero.

.. math::
I_0(x) = \sum_{k=0}^{\infty} \frac{(x/2)^{2k}}{(k!)^2}

Args:
a (array): Input array.

Returns:
array: The modified Bessel function :math:`I_0` evaluated element-wise on ``a``.
)pbdoc");
m.def(
"sin",
[](const ScalarOrArray& a, mx::StreamOrDevice s) {
Expand Down Expand Up @@ -1520,6 +1542,31 @@ void init_ops(nb::module_& m) {
array: The window, with the maximum value normalized to one (the value one
appears only if the number of samples is odd).
)pbdoc");
m.def(
    "kaiser",
    // Bound directly to the C++ implementation; `mx` is the alias for
    // mlx::core already used by the other bindings in this file.
    &mx::kaiser,
    "M"_a,
    "beta"_a,
    nb::kw_only(),
    "stream"_a = nb::none(),
    nb::sig(
        "def kaiser(M: int, beta: float, *, stream: Union[None, Stream, Device] = None) -> array"),
    R"pbdoc(
Return the Kaiser window.

The Kaiser window is a taper formed by using a Bessel function.

.. math::
w(n) = \frac{I_0\!\left(\beta\,\sqrt{1 - \!\left(\frac{2n}{M-1} - 1\right)^{\!2}}\right)}{I_0(\beta)}
\qquad 0 \le n \le M-1

Args:
M (int): Number of points in the output window.
beta (float): Shape parameter for the window.

Returns:
array: The Kaiser window of length ``M`` with shape parameter ``beta``.
)pbdoc");
m.def(
"linspace",
[](Scalar start,
Expand Down
Loading