Skip to content

Commit ca97af9

Browse files
Change the default implementation of GeLU to a numerically stable formulation.
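The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to catastrophic cancellation: erf(x/sqrt(2)) approaches -1 there, so the sum loses most of its significant digits. The new formulation computes the mathematically identical erfc(-x/sqrt(2)) directly.

PiperOrigin-RevId: 676944344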
1 parent 1b34880 commit ca97af9

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

jax/_src/nn/functions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array:
430430
If ``approximate=False``, computes the element-wise function:
431431
432432
.. math::
433-
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
434-
\frac{x}{\sqrt{2}} \right) \right)
433+
\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left(
434+
\frac{-x}{\sqrt{2}} \right) \right)
435435
436436
If ``approximate=True``, uses the approximate formulation of GELU:
437437
@@ -443,7 +443,7 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array:
443443
<https://arxiv.org/abs/1606.08415>`_, section 2.
444444
445445
Args:
446-
x : input array
446+
x: input array
447447
approximate: whether to use the approximate or exact formulation.
448448
"""
449449
[x_arr] = numpy_util.promote_args_inexact("gelu", x)
@@ -453,8 +453,10 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array:
453453
cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3))))
454454
return x_arr * cdf
455455
else:
456-
sqrt_2 = np.sqrt(2).astype(x_arr.dtype)
457-
return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype)
456+
sqrt_half = np.sqrt(0.5).astype(x_arr.dtype)
457+
return jnp.array(
458+
0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype
459+
)
458460

459461
@partial(jax.jit, static_argnames=("axis",))
460462
def glu(x: ArrayLike, axis: int = -1) -> Array:
@@ -541,7 +543,7 @@ def log_softmax(x: ArrayLike,
541543

542544

543545
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
544-
#@partial(jax.jit, static_argnames=("axis",))
546+
# @partial(jax.jit, static_argnames=("axis",))
545547
def softmax(x: ArrayLike,
546548
axis: int | tuple[int, ...] | None = -1,
547549
where: ArrayLike | None = None,

tests/nn_test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,18 @@ def testGeluIntType(self, approximate):
308308
def testGelu(self, approximate):
309309
def gelu_reference(x):
310310
return x * scipy.stats.norm.cdf(x)
311-
rng = jtu.rand_default(self.rng())
312-
args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
311+
args_maker = lambda: [jnp.linspace(-12, 5, 10000, dtype=jnp.float32)]
312+
rtol = 2e-5
313+
atol = 1e-3 if approximate else 0
313314
self._CheckAgainstNumpy(
314-
gelu_reference, partial(nn.gelu, approximate=approximate), args_maker,
315-
check_dtypes=False, tol=1e-3 if approximate else None)
315+
gelu_reference,
316+
partial(nn.gelu, approximate=approximate),
317+
args_maker,
318+
check_dtypes=False,
319+
tol=0,
320+
rtol=rtol,
321+
atol=atol,
322+
)
316323

317324
@parameterized.parameters(*itertools.product(
318325
(jnp.float32, jnp.bfloat16, jnp.float16),

0 commit comments

Comments
 (0)