Skip to content

Conversation

@jumerckx
Copy link
Collaborator

@jumerckx jumerckx commented Jan 7, 2025

fixes #442
needs Enzyme.jl: EnzymeAD/Enzyme.jl#2254

I had to introduce a new function call_with_reactant_within_autodiff to smuggle the within_autodiff in the call_with_reactant_generator through the self argument.
I also tried doing things through set_reactant_abi but that didn't seem to suffice (first commit).
Perhaps the extra code in set_reactant_abi isn't strictly necessary now so I can try removing it again if wanted.

@avik-pal avik-pal force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from dd14b85 to 8731c7f Compare September 1, 2025 23:13
@avik-pal avik-pal requested a review from wsmoses September 1, 2025 23:41

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
if within_autodiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I got rid of the second interpreter like you described

@avik-pal avik-pal force-pushed the jm/deferred_within_autodiff branch from a3115b5 to 2a4d0f2 Compare September 2, 2025 00:18
@codecov
Copy link

codecov bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 78.57143% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 42.56%. Comparing base (71b744e) to head (2a4d0f2).

Files with missing lines Patch % Lines
src/utils.jl 50.00% 2 Missing ⚠️
src/TracedUtils.jl 83.33% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #490      +/-   ##
==========================================
+ Coverage   42.55%   42.56%   +0.01%     
==========================================
  Files         123      123              
  Lines       21816    21826      +10     
==========================================
+ Hits         9283     9290       +7     
- Misses      12533    12536       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@avik-pal
Copy link
Collaborator

julia> using Enzyme

julia> function error_not_within_autodiff()
           !Enzyme.within_autodiff() && error("Not within autodiff")
           return nothing
       end
error_not_within_autodiff (generic function with 1 method)

julia> fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)
fwd_within_autodiff (generic function with 1 method)

julia> error_not_within_autodiff()
ERROR: Not within autodiff
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] error_not_within_autodiff()
   @ Main ./REPL[5]:2
 [3] top-level scope
   @ REPL[7]:1
 [4] top-level scope
   @ none:1

julia> fwd_within_autodiff(Forward, Const)
()

julia> error_not_within_autodiff()

julia> Enzyme.within_autodiff()
false

I am extremely confused why is the 2nd call not throw an error here. Only happens if I call fwd_within_autodiff in between. cc @wsmoses this is in isolation from Reactant

@wsmoses
Copy link
Member

wsmoses commented Sep 13, 2025

@vchuravy er wat

jumerckx and others added 5 commits November 17, 2025 14:17
… === overload_autodiff`.

This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?)
…mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from 15eaa8f to 159837e Compare November 18, 2025 16:19
work with global ref instead

This has the unfortunate downside of introducing a try finally block 
around `overload_autodiff`.
put try-finally outside of overload_autodiff call
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch 2 times, most recently from 14864f3 to bb72f52 Compare November 18, 2025 16:21
Comment on lines 98 to 102
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi, #=within_autodiff_rewrite=#

Comment on lines 116 to 120
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#false,
#=within_autodiff_rewrite=#set_reactant_abi,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi, #=within_autodiff_rewrite=#

@jumerckx jumerckx requested a review from wsmoses November 20, 2025 16:53
false, #=broadcast_rewrite=#
false, #=within_autodiff_rewrite=#
set_reactant_abi,
set_reactant_abi, #=within_autodiff_rewrite=#
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment seems off

src/Enzyme.jl Outdated
Comment on lines 99 to 429
::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false)
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
end
if Enzyme.Compiler.guaranteed_const(eltype(RT))
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
res = zero(prev)
seen[prev] = res
return res
end

function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N}
onehot_matrix = promote_to(TracedRArray{T,2}, LinearAlgebra.I(length(x)))
return Tuple(
materialize_traced_array(reshape(y, size(x))) for y in eachcol(onehot_matrix)
)
end

function EnzymeRules.inactive_noinl(::typeof(XLA.buffer_on_cpu), args...)
return nothing
end

function EnzymeRules.inactive_noinl(::typeof(XLA.addressable_devices), args...)
return nothing
end

function EnzymeRules.noalias(::typeof(Base.similar), a::ConcretePJRTArray, ::Type, args...)
return nothing
end

function EnzymeRules.noalias(::typeof(Base.similar), a::ConcreteIFRTArray, ::Type, args...)
return nothing
end

function EnzymeRules.augmented_primal(
config,
ofn::Const{typeof(Base.similar)},
::Type{RT},
uval::Annotation{<:ConcretePJRTArray},
T::Const{<:Type},
args...,
) where {RT}
primargs = ntuple(Val(length(args))) do i
Base.@_inline_meta
args[i].val
return args[i].val
end

primal = if EnzymeCore.needs_primal(config)
ofn.val(uval.val, T.val, primargs...)
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
ConcretePJRTArray(
zeros(T.val, primargs...);
client=XLA.client(uval.val),
device=XLA.device(uval.val),
uval.val.sharding,
)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
ConcretePJRTArray(
return ConcretePJRTArray(
zeros(T.val, primargs...);
client=XLA.client(uval.val),
device=XLA.device(uval.val),
uval.val.sharding,
)
end
end
else
nothing
end

return EnzymeRules.AugmentedReturn{
EnzymeRules.primal_type(config, RT),EnzymeRules.shadow_type(config, RT),Nothing
}(
primal, shadow, nothing
)
end

function EnzymeRules.reverse(
config,
ofn::Const{typeof(Base.similar)},
::Type{RT},
tape,
uval::Annotation{<:ConcretePJRTArray},
T::Const{<:Type},
args::Vararg{Annotation,N},
) where {RT,N}
ntuple(Val(N + 2)) do i
Base.@_inline_meta
nothing
return nothing
end
end

@inline function act_from_type(::A, reverse, needs_primal=true) where {A<:Annotation}
return act_from_type(A, reverse, needs_primal)
end

@inline function act_from_type(::Type{<:Active}, reverse, needs_primal)
return needs_primal ? enzyme_out : enzyme_outnoneed
end
@inline function act_from_type(::Type{<:Const}, reverse, needs_primal)
return needs_primal ? enzyme_const : enzyme_constnoneed
end

@inline function act_from_type(::Type{<:Duplicated}, reverse, needs_primal)
if reverse
return needs_primal ? enzyme_out : enzyme_outnoneed
else
return needs_primal ? enzyme_dup : enzyme_dupnoneed
end
end
@inline function act_from_type(
::Type{<:Union{BatchDuplicated,StackedBatchDuplicated}}, reverse, needs_primal
)
return act_from_type(Duplicated, reverse, needs_primal)
end

@inline function act_from_type(::Type{<:DuplicatedNoNeed}, reverse, needs_primal)
return reverse ? enzyme_out : enzyme_dupnoneed
end
@inline function act_from_type(
::Type{<:Union{BatchDuplicatedNoNeed,StackedBatchDuplicatedNoNeed}},
reverse,
needs_primal,
)
return act_from_type(DuplicatedNoNeed, reverse, needs_primal)
end

function push_acts!(ad_inputs, x::Union{Const,Active}, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
return nothing
end

function push_acts!(ad_inputs, x::Union{Duplicated,DuplicatedNoNeed}, path, reverse)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
TracedUtils.push_val!(ad_inputs, x.dval, path)
end
end

function push_acts!(
ad_inputs, x::Union{BatchDuplicated,BatchDuplicatedNoNeed}, path, reverse
)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
TracedUtils.push_val!(ad_inputs, call_with_reactant(stack, x.dval), path)
end
end

function push_acts!(
ad_inputs, x::Union{StackedBatchDuplicated,StackedBatchDuplicatedNoNeed}, path, reverse
)
TracedUtils.push_val!(ad_inputs, x.val, path)
if !reverse
TracedUtils.push_val!(ad_inputs, x.dval, path)
end
end

function set_act!(inp, path, reverse, tostore; emptypath=false, width=1)
x = if inp isa Active
inp.val
else
inp.dval
end

for p in path
x = traced_getfield(x, p)
end

if width == 1
TracedUtils.set_mlir_data!(x, tostore)
elseif x isa AbstractArray
TracedUtils.set_mlir_data!(x, tostore)
else
tostore_traced = TracedRArray(tostore)
@assert length(x) == size(tostore_traced, ndims(tostore_traced))
for (i, sl) in enumerate(eachslice(tostore_traced; dims=ndims(tostore_traced)))
TracedUtils.set_mlir_data!(x[i], TracedUtils.get_mlir_data(sl))
end
end

emptypath && TracedUtils.set_paths!(x, ())
return nothing
end

function act_attr(val)
val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet(
MLIR.IR.context()::MLIR.API.MlirContext, val::Int32
)::MLIR.API.MlirAttribute
return MLIR.IR.Attribute(val)
end

function overload_autodiff(
::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}
) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs}
reverse = CMode <: ReverseMode

width = Enzyme.same_or_one(1, args...)
if width == 0
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end

primf = f.val
primargs = ((v.val for v in args)...,)

argprefix::Symbol = gensym("autodiffarg")
resprefix::Symbol = gensym("autodiffresult")
resargprefix::Symbol = gensym("autodiffresarg")

mlir_fn_res = TracedUtils.make_mlir_fn(
primf,
primargs,
(),
string(f) * "_autodiff",
false;
argprefix,
resprefix,
resargprefix,
)
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
fnwrap = mlir_fn_res.fnwrapped
func2 = mlir_fn_res.f

activity = Int32[]
ad_inputs = MLIR.IR.Value[]

for a in linear_args
idx, path = TracedUtils.get_argidx(a, argprefix)
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]
push!(activity, act_from_type(arg, reverse))
push_acts!(ad_inputs, arg, path[3:end], reverse)
end

outtys = MLIR.IR.Type[]
ret_activity = Int32[]

for a in linear_results
if TracedUtils.has_idx(a, resprefix)
if EnzymeCore.needs_primal(CMode)
push!(
outtys,
TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))),
)
end

if CMode <: ForwardMode && !(A <: Const)
push!(
outtys,
TracedUtils.batch_ty(
width,
TracedUtils.transpose_ty(
MLIR.IR.type(TracedUtils.get_mlir_data(a))
),
),
)
end

act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode))
push!(ret_activity, act)
if act == enzyme_out || act == enzyme_outnoneed
if width == 1
cst = @opcall fill(one(unwrapped_eltype(a)), size(a))
else
cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width))
end
push!(ad_inputs, cst.mlir_data)
end
else
if TracedUtils.has_idx(a, argprefix)
idx, path = TracedUtils.get_argidx(a, argprefix)
arg = idx == 1 && fnwrap ? f : args[idx - fnwrap]

act = act_from_type(arg, reverse, true)
push!(ret_activity, act)

if act == enzyme_out || act == enzyme_outnoneed
if width == 1
TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end])
elseif arg.dval isa AbstractArray
TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end])
else
TracedUtils.push_val!(
ad_inputs, call_with_reactant(stack, arg.dval), path[3:end]
)
end
end
else
act = act_from_type(Const, reverse, true)
push!(ret_activity, act)
end

push!(
outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a)))
)
end
end

for (i, act) in enumerate(activity)
if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed
push!(outtys, TracedUtils.batch_ty(width, in_tys[i]))
end
end

fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)(
[TracedUtils.transpose_val(v) for v in ad_inputs];
outputs=outtys,
fn=fname,
width,
strong_zero=EnzymeCore.strong_zero(CMode),
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
)

residx = 1

dresult = if CMode <: ForwardMode && !(A <: Const)
if width == 1
deepcopy(result)
else
ntuple(Val(width)) do i
Base.@_inline_meta
deepcopy(result)
return deepcopy(result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert these changes. possible originated from juliaformatter v2?

@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from e0a228f to 106375a Compare December 6, 2025 05:52
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from 106375a to 232b01f Compare December 6, 2025 05:53
Co-authored-by: Avik Pal <[email protected]>
@jumerckx jumerckx force-pushed the jm/deferred_within_autodiff branch from 232b01f to 626e909 Compare December 6, 2025 05:55
@avik-pal
Copy link
Collaborator

@jumerckx is this good to go?

@jumerckx
Copy link
Collaborator Author

Yes! Will merge after CI has finished

@jumerckx jumerckx merged commit 4f8b904 into main Dec 14, 2025
70 checks passed
@jumerckx jumerckx deleted the jm/deferred_within_autodiff branch December 14, 2025 18:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Enzyme.within_autodiff returns true inside compile

4 participants