diff --git a/Project.toml b/Project.toml index 91a2ee3e..db72e8a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.8" +version = "0.15.9" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 0c2d8903..28c7e850 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,12 +1,12 @@ module BijectorsMooncakeExt using Mooncake: - @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_chainrules using Bijectors: find_alpha, ChainRulesCore -for P in [Float16, Float32, Float64] - @from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) -end +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float16,Float16,Float16}) +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float32,Float32,Float32}) +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float64,Float64,Float64}) # The final argument could be an Integer of some kind. This should be fine provided that # it has tangent type equal to `NoTangent`, which means that it's non-differentiable and @@ -15,6 +15,7 @@ end # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) +# TODO: This needs a corresponding frule!! as well for it to work on forward-mode Mooncake. function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} diff --git a/src/chainrules.jl b/src/chainrules.jl index cfacfdcc..5a749e02 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -292,4 +292,3 @@ end # Fixes AD issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) - diff --git a/test/Project.toml b/test/Project.toml index c6026495..447c36dd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ LazyArrays = "1, 2" LogDensityProblems = "2" LogExpFunctions = "0.3.1" MCMCDiagnosticTools = "0.3" -Mooncake = "0.4" +Mooncake = "0.4.147" ReverseDiff = "1.4.2" StableRNGs = "1" Tracker = "0.2.11" diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index a217e66b..5c938502 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -31,36 +31,39 @@ end if @isdefined Mooncake rng = Xoshiro(123456) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - z; - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - 3; - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - UInt32(3); - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) + # TODO: Enable Mooncake.ForwardMode as well. + @testset "$mode" for mode in (Mooncake.ReverseMode,) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + z; + is_primitive=true, + perf_flag=:none, + mode=mode, + ) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + 3; + is_primitive=true, + perf_flag=:none, + mode=mode, + ) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + UInt32(3); + is_primitive=true, + perf_flag=:none, + mode=mode, + ) + end end test_rrule( diff --git a/test/runtests.jl b/test/runtests.jl index 4a1e4c0c..ca7131d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,7 +70,9 @@ end if GROUP == "All" || GROUP == "AD" include("ad/chainrules.jl") - include("ad/enzyme.jl") + if get(ENV, "AD", "All") in ("All", "Enzyme") + include("ad/enzyme.jl") + end include("ad/flows.jl") include("ad/pd.jl") include("ad/corr.jl")