From 19a1cdcce28eec5548dd28959fc665885d393a61 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 10:11:00 +0100 Subject: [PATCH 1/8] unroll loop to account for https://github.com/chalk-lab/Mooncake.jl/issues/692 --- Project.toml | 2 +- ext/BijectorsMooncakeExt.jl | 6 +++--- src/chainrules.jl | 1 - 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 91a2ee3e..db72e8a8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.8" +version = "0.15.9" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 0c2d8903..761f3a1c 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -4,9 +4,9 @@ using Mooncake: @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore -for P in [Float16, Float32, Float64] - @from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) -end +@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float16,Float16,Float16}) +@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float32,Float32,Float32}) +@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float64,Float64,Float64}) # The final argument could be an Integer of some kind. This should be fine provided that # it has tangent type equal to `NoTangent`, which means that it's non-differentiable and diff --git a/src/chainrules.jl b/src/chainrules.jl index cfacfdcc..5a749e02 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -292,4 +292,3 @@ end # Fixes AD issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) - From 1cc4283adb5d965046c069d0839f4cc393d64f4e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 10:24:02 +0100 Subject: [PATCH 2/8] fix test --- test/ad/chainrules.jl | 62 ++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index a217e66b..5545f93d 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -31,36 +31,38 @@ end if @isdefined Mooncake rng = Xoshiro(123456) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - z; - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - 3; - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) - Mooncake.TestUtils.test_rule( - rng, - Bijectors.find_alpha, - x, - y, - UInt32(3); - is_primitive=true, - perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(), - ) + @testset "$mode" for mode in (Mooncake.ForwardMode(), Mooncake.ReverseMode()) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + z; + is_primitive=true, + perf_flag=:none, + interp=Mooncake.MooncakeInterpreter(mode), + ) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + 3; + is_primitive=true, + perf_flag=:none, + interp=Mooncake.MooncakeInterpreter(mode), + ) + Mooncake.TestUtils.test_rule( + rng, + Bijectors.find_alpha, + x, + y, + UInt32(3); + is_primitive=true, + perf_flag=:none, + interp=Mooncake.MooncakeInterpreter(mode), + ) + end end test_rrule( From e8938b50567d946bf73ea55f5a6d4a912d6a8964 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 10:37:51 +0100 Subject: [PATCH 3/8] it's the type not the value --- test/Project.toml | 2 +- test/ad/chainrules.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index c6026495..447c36dd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -45,7 +45,7 @@ LazyArrays = "1, 2" LogDensityProblems = "2" LogExpFunctions = "0.3.1" MCMCDiagnosticTools = "0.3" -Mooncake = "0.4" +Mooncake = "0.4.147" ReverseDiff = "1.4.2" StableRNGs = "1" Tracker = "0.2.11" diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 5545f93d..9838a509 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -31,7 +31,7 @@ end if @isdefined Mooncake rng = Xoshiro(123456) - @testset "$mode" for mode in (Mooncake.ForwardMode(), Mooncake.ReverseMode()) + @testset "$mode" for mode in (Mooncake.ForwardMode, Mooncake.ReverseMode) Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, From 369183c96955adb46249c4f47ac7b7d8f45ec3a8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 10:48:53 +0100 Subject: [PATCH 4/8] more fixes? --- test/ad/chainrules.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 9838a509..678854b5 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -40,7 +40,7 @@ end z; is_primitive=true, perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(mode), + mode=mode, ) Mooncake.TestUtils.test_rule( rng, @@ -50,7 +50,7 @@ end 3; is_primitive=true, perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(mode), + mode=mode, ) Mooncake.TestUtils.test_rule( rng, @@ -60,7 +60,7 @@ end UInt32(3); is_primitive=true, perf_flag=:none, - interp=Mooncake.MooncakeInterpreter(mode), + mode=mode, ) end end From 5fe59ae4d96faddd81c6705dff494e0f5200105f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 10:55:06 +0100 Subject: [PATCH 5/8] from_rrule -> from_chainrules --- ext/BijectorsMooncakeExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 761f3a1c..80265828 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -4,9 +4,9 @@ using Mooncake: @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule using Bijectors: find_alpha, ChainRulesCore -@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float16,Float16,Float16}) -@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float32,Float32,Float32}) -@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),Float64,Float64,Float64}) +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float16,Float16,Float16}) +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float32,Float32,Float32}) +@from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float64,Float64,Float64}) # The final argument could be an Integer of some kind. This should be fine provided that # it has tangent type equal to `NoTangent`, which means that it's non-differentiable and From a0df9b79ccea9c20cf3ba7178f5c2dfd55be11f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 11:02:34 +0100 Subject: [PATCH 6/8] fix import --- ext/BijectorsMooncakeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 80265828..11a3fbf5 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -1,7 +1,7 @@ module BijectorsMooncakeExt using Mooncake: - @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule + @is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_chainrules using Bijectors: find_alpha, ChainRulesCore @from_chainrules(MinimalCtx, Tuple{typeof(find_alpha),Float16,Float16,Float16}) From 1ef4d9656c786d610bac9006d2eef7b74cfd614c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 11:19:04 +0100 Subject: [PATCH 7/8] disable forward mode testing --- ext/BijectorsMooncakeExt.jl | 1 + test/ad/chainrules.jl | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/BijectorsMooncakeExt.jl b/ext/BijectorsMooncakeExt.jl index 11a3fbf5..28c7e850 100644 --- a/ext/BijectorsMooncakeExt.jl +++ b/ext/BijectorsMooncakeExt.jl @@ -15,6 +15,7 @@ using Bijectors: find_alpha, ChainRulesCore # unusual Integer type is encountered. @is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) +# TODO: This needs a corresponding frule!! as well for it to work on forward-mode Mooncake. function Mooncake.rrule!!( ::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} ) where {P<:Base.IEEEFloat,I<:Integer} diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index 678854b5..5c938502 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -31,7 +31,8 @@ end if @isdefined Mooncake rng = Xoshiro(123456) - @testset "$mode" for mode in (Mooncake.ForwardMode, Mooncake.ReverseMode) + # TODO: Enable Mooncake.ForwardMode as well. + @testset "$mode" for mode in (Mooncake.ReverseMode,) Mooncake.TestUtils.test_rule( rng, Bijectors.find_alpha, From 439a77064d96dad23e62a2ebfb4f97036786c302 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 12:03:02 +0100 Subject: [PATCH 8/8] don't test enzyme if it's not enzyme --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4a1e4c0c..ca7131d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,7 +70,9 @@ end if GROUP == "All" || GROUP == "AD" include("ad/chainrules.jl") - include("ad/enzyme.jl") + if get(ENV, "AD", "All") in ("All", "Enzyme") + include("ad/enzyme.jl") + end include("ad/flows.jl") include("ad/pd.jl") include("ad/corr.jl")