Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ end
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
set_reactant_abi,
set_reactant_abi, #=handler=#
)
end
else
Expand All @@ -117,7 +117,7 @@ else
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
set_reactant_abi,
set_reactant_abi, #=handler=#
)
end
end
22 changes: 20 additions & 2 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,34 @@
end

# Enzyme.jl overlays
const WITHIN_AUTODIFF = Ref(false)

@reactant_overlay @noinline function Enzyme.within_autodiff()
return WITHIN_AUTODIFF[]
end

@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
original_within_autodiff = WITHIN_AUTODIFF[]
try
WITHIN_AUTODIFF[] = true
return overload_autodiff(rmode, f, rt, args...)
finally
WITHIN_AUTODIFF[] = original_within_autodiff
end
end

@reactant_overlay @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
original_within_autodiff = WITHIN_AUTODIFF[]
try
WITHIN_AUTODIFF[] = true
return overload_autodiff(rmode, f, rt, args...)
finally
WITHIN_AUTODIFF[] = original_within_autodiff
end
end

@reactant_overlay function EnzymeCore.ignore_derivatives(args...)
Expand Down
1 change: 1 addition & 0 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ function make_mlir_fn(
args_in_result::Symbol=:all,
construct_function_without_args::Bool=false,
do_transpose=true,
within_autodiff=false,
input_shardings=nothing, # This is not meant to be used by the user.
output_shardings=nothing, # This is not meant to be used by the user.
runtime=nothing,
Expand Down
15 changes: 15 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ end
@test res1[1] ≈ fill(2.0, 3, 2)
end

function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end

fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)

@testset "within_autodiff" begin
@test_throws ErrorException error_not_within_autodiff()
@test fwd_within_autodiff(Forward, Const) == ()

@test_throws ErrorException @jit error_not_within_autodiff()
@test (@jit fwd_within_autodiff(Forward, Const)) == ()
end

function gw(z)
return Enzyme.gradient(Forward, sum, z; chunk=Val(1))
end
Expand Down
Loading