diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 8b073c47..c36ce2dd 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -56,13 +56,13 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if !(:EnzymeForwardCrash in broken) if forward_broken @test_broken( - Enzyme.gradient(Enzyme.Forward, f, x)[1] ≈ finitediff, + Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] ≈ finitediff, rtol = rtol, atol = atol ) else @test( - Enzyme.gradient(Enzyme.Forward, f, x)[1] ≈ finitediff, + Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] ≈ finitediff, rtol = rtol, atol = atol ) @@ -72,13 +72,15 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) if !(:EnzymeReverseCrash in broken) if reverse_broken @test_broken( - Enzyme.gradient(Enzyme.Reverse, f, x)[1] ≈ finitediff, + Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1] ≈ + finitediff, rtol = rtol, atol = atol ) else @test( - Enzyme.gradient(Enzyme.Reverse, f, x)[1] ≈ finitediff, + Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1] ≈ + finitediff, rtol = rtol, atol = atol )