Skip to content

Commit 5ce09af

Browse files
wsmosesmcabbott
andauthored
Prevent recursion in _eps (#207)
* Revert "Reactant: add extension to prevent stackoverflow (#206)" This reverts commit 12b7f31. * Change eps to not be recursive * Update src/utils.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> * Update src/utils.jl Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --------- Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent 12b7f31 commit 5ce09af

3 files changed

Lines changed: 10 additions & 14 deletions

File tree

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1313
[weakdeps]
1414
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
1515
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
16-
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1716

1817
[extensions]
1918
OptimisersAdaptExt = ["Adapt"]
2019
OptimisersEnzymeCoreExt = "EnzymeCore"
21-
OptimisersReactantExt = "Reactant"
2220

2321
[compat]
2422
Adapt = "4"

ext/OptimisersReactantExt.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/utils.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@ end
1515

1616
ofeltype(x, y) = convert(float(eltype(x)), y)
1717

18-
_eps(T::Type{<:AbstractFloat}, e) = T(e)
19-
# catch complex and integers
20-
_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
21-
# avoid small e being rounded to zero
18+
"""
19+
_eps(Type{T}, val)
20+
21+
Mostly this produces `real(T)(val)`, so that `_eps(Float32, 1e-8) === 1f-8` will
22+
convert the Float64 parameter epsilon to work nicely with Float32 parameter arrays.
23+
24+
But for Float16, it imposes a minimum of `Float16(1e-7)`, unless `val==0`.
25+
This is basically a hack to increase the default epsilon, to help many optimisers avoid NaN.
26+
"""
27+
_eps(T::Type{<:Number}, e) = real(float(T))(e)
2228
_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))

0 commit comments

Comments
 (0)