-
Notifications
You must be signed in to change notification settings - Fork 39
Properly set within_autodiff (#442) #490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
dd14b85 to
8731c7f
Compare
src/TracedUtils.jl
Outdated
|
|
||
| if isempty(kwargs) | ||
| Reactant.call_with_reactant(f, traced_args...) | ||
| if within_autodiff |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like a cleaner way to do this, is not to have a second interpreter. But instead we can create a new global ref set to false, and overlay within_autodiff to lookup that var, and during autodiff set that to true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm okay with this though, but if we were to do it in this form, I would probably change call_with_reactant to take a config type var, which stores the current state of whether in autodiff or not (and also we can extend to other things down the line as well)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I got rid of the second interpreter like you described
a3115b5 to
2a4d0f2
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #490 +/- ##
==========================================
+ Coverage 42.55% 42.56% +0.01%
==========================================
Files 123 123
Lines 21816 21826 +10
==========================================
+ Hits 9283 9290 +7
- Misses 12533 12536 +3 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
julia> using Enzyme
julia> function error_not_within_autodiff()
!Enzyme.within_autodiff() && error("Not within autodiff")
return nothing
end
error_not_within_autodiff (generic function with 1 method)
julia> fwd_within_autodiff(Mode, RT) = Enzyme.autodiff(Mode, error_not_within_autodiff, RT)
fwd_within_autodiff (generic function with 1 method)
julia> error_not_within_autodiff()
ERROR: Not within autodiff
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] error_not_within_autodiff()
@ Main ./REPL[5]:2
[3] top-level scope
@ REPL[7]:1
[4] top-level scope
@ none:1
julia> fwd_within_autodiff(Forward, Const)
()
julia> error_not_within_autodiff()
julia> Enzyme.within_autodiff()
falseI am extremely confused why is the 2nd call not throw an error here. Only happens if I call fwd_within_autodiff in between. cc @wsmoses this is in isolation from Reactant |
|
@vchuravy er wat |
… === overload_autodiff`. This doesn't work for some reason, the function within overload autodiff uses the original interpreter (?)
…mlir_fn. In order to pass this information from make_mlir_fn to call_with_reactant_generator, I introduced a new function `call_with_reactant_within_autodiff` which allows detection by looking at `self`.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
15eaa8f to
159837e
Compare
work with global ref instead This has the unfortunate downside of introducing a try finally block around `overload_autodiff`.
14864f3 to
bb72f52
Compare
src/Interpreter.jl
Outdated
| #=forward_rules=#false, | ||
| #=reverse_rules=#false, | ||
| #=inactive_rules=#false, | ||
| #=broadcast_rewrite=#false, | ||
| #=within_autodiff_rewrite=#set_reactant_abi, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| #=forward_rules=#false, | |
| #=reverse_rules=#false, | |
| #=inactive_rules=#false, | |
| #=broadcast_rewrite=#false, | |
| #=within_autodiff_rewrite=#set_reactant_abi, | |
| false, #=forward_rules=# | |
| false, #=reverse_rules=# | |
| false, #=inactive_rules=# | |
| false, #=broadcast_rewrite=# | |
| set_reactant_abi, #=within_autodiff_rewrite=# |
src/Interpreter.jl
Outdated
| #=forward_rules=#false, | ||
| #=reverse_rules=#false, | ||
| #=inactive_rules=#false, | ||
| #=broadcast_rewrite=#false, | ||
| #=within_autodiff_rewrite=#set_reactant_abi, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
| #=forward_rules=#false, | |
| #=reverse_rules=#false, | |
| #=inactive_rules=#false, | |
| #=broadcast_rewrite=#false, | |
| #=within_autodiff_rewrite=#set_reactant_abi, | |
| false, #=forward_rules=# | |
| false, #=reverse_rules=# | |
| false, #=inactive_rules=# | |
| false, #=broadcast_rewrite=# | |
| set_reactant_abi, #=within_autodiff_rewrite=# |
src/Interpreter.jl
Outdated
| false, #=broadcast_rewrite=# | ||
| false, #=within_autodiff_rewrite=# | ||
| set_reactant_abi, | ||
| set_reactant_abi, #=within_autodiff_rewrite=# |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment seems off
src/Enzyme.jl
Outdated
| ::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false) | ||
| )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} | ||
| if haskey(seen, prev) | ||
| return seen[prev] | ||
| end | ||
| if Enzyme.Compiler.guaranteed_const(eltype(RT)) | ||
| return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev | ||
| end | ||
| res = zero(prev) | ||
| seen[prev] = res | ||
| return res | ||
| end | ||
|
|
||
| function Enzyme.onehot(x::TracedRArray{T,N}) where {T,N} | ||
| onehot_matrix = promote_to(TracedRArray{T,2}, LinearAlgebra.I(length(x))) | ||
| return Tuple( | ||
| materialize_traced_array(reshape(y, size(x))) for y in eachcol(onehot_matrix) | ||
| ) | ||
| end | ||
|
|
||
| function EnzymeRules.inactive_noinl(::typeof(XLA.buffer_on_cpu), args...) | ||
| return nothing | ||
| end | ||
|
|
||
| function EnzymeRules.inactive_noinl(::typeof(XLA.addressable_devices), args...) | ||
| return nothing | ||
| end | ||
|
|
||
| function EnzymeRules.noalias(::typeof(Base.similar), a::ConcretePJRTArray, ::Type, args...) | ||
| return nothing | ||
| end | ||
|
|
||
| function EnzymeRules.noalias(::typeof(Base.similar), a::ConcreteIFRTArray, ::Type, args...) | ||
| return nothing | ||
| end | ||
|
|
||
| function EnzymeRules.augmented_primal( | ||
| config, | ||
| ofn::Const{typeof(Base.similar)}, | ||
| ::Type{RT}, | ||
| uval::Annotation{<:ConcretePJRTArray}, | ||
| T::Const{<:Type}, | ||
| args..., | ||
| ) where {RT} | ||
| primargs = ntuple(Val(length(args))) do i | ||
| Base.@_inline_meta | ||
| args[i].val | ||
| return args[i].val | ||
| end | ||
|
|
||
| primal = if EnzymeCore.needs_primal(config) | ||
| ofn.val(uval.val, T.val, primargs...) | ||
| else | ||
| nothing | ||
| end | ||
|
|
||
| shadow = if EnzymeRules.needs_shadow(config) | ||
| if EnzymeRules.width(config) == 1 | ||
| ConcretePJRTArray( | ||
| zeros(T.val, primargs...); | ||
| client=XLA.client(uval.val), | ||
| device=XLA.device(uval.val), | ||
| uval.val.sharding, | ||
| ) | ||
| else | ||
| ntuple(Val(EnzymeRules.width(config))) do i | ||
| Base.@_inline_meta | ||
| ConcretePJRTArray( | ||
| return ConcretePJRTArray( | ||
| zeros(T.val, primargs...); | ||
| client=XLA.client(uval.val), | ||
| device=XLA.device(uval.val), | ||
| uval.val.sharding, | ||
| ) | ||
| end | ||
| end | ||
| else | ||
| nothing | ||
| end | ||
|
|
||
| return EnzymeRules.AugmentedReturn{ | ||
| EnzymeRules.primal_type(config, RT),EnzymeRules.shadow_type(config, RT),Nothing | ||
| }( | ||
| primal, shadow, nothing | ||
| ) | ||
| end | ||
|
|
||
| function EnzymeRules.reverse( | ||
| config, | ||
| ofn::Const{typeof(Base.similar)}, | ||
| ::Type{RT}, | ||
| tape, | ||
| uval::Annotation{<:ConcretePJRTArray}, | ||
| T::Const{<:Type}, | ||
| args::Vararg{Annotation,N}, | ||
| ) where {RT,N} | ||
| ntuple(Val(N + 2)) do i | ||
| Base.@_inline_meta | ||
| nothing | ||
| return nothing | ||
| end | ||
| end | ||
|
|
||
| @inline function act_from_type(::A, reverse, needs_primal=true) where {A<:Annotation} | ||
| return act_from_type(A, reverse, needs_primal) | ||
| end | ||
|
|
||
| @inline function act_from_type(::Type{<:Active}, reverse, needs_primal) | ||
| return needs_primal ? enzyme_out : enzyme_outnoneed | ||
| end | ||
| @inline function act_from_type(::Type{<:Const}, reverse, needs_primal) | ||
| return needs_primal ? enzyme_const : enzyme_constnoneed | ||
| end | ||
|
|
||
| @inline function act_from_type(::Type{<:Duplicated}, reverse, needs_primal) | ||
| if reverse | ||
| return needs_primal ? enzyme_out : enzyme_outnoneed | ||
| else | ||
| return needs_primal ? enzyme_dup : enzyme_dupnoneed | ||
| end | ||
| end | ||
| @inline function act_from_type( | ||
| ::Type{<:Union{BatchDuplicated,StackedBatchDuplicated}}, reverse, needs_primal | ||
| ) | ||
| return act_from_type(Duplicated, reverse, needs_primal) | ||
| end | ||
|
|
||
| @inline function act_from_type(::Type{<:DuplicatedNoNeed}, reverse, needs_primal) | ||
| return reverse ? enzyme_out : enzyme_dupnoneed | ||
| end | ||
| @inline function act_from_type( | ||
| ::Type{<:Union{BatchDuplicatedNoNeed,StackedBatchDuplicatedNoNeed}}, | ||
| reverse, | ||
| needs_primal, | ||
| ) | ||
| return act_from_type(DuplicatedNoNeed, reverse, needs_primal) | ||
| end | ||
|
|
||
| function push_acts!(ad_inputs, x::Union{Const,Active}, path, reverse) | ||
| TracedUtils.push_val!(ad_inputs, x.val, path) | ||
| return nothing | ||
| end | ||
|
|
||
| function push_acts!(ad_inputs, x::Union{Duplicated,DuplicatedNoNeed}, path, reverse) | ||
| TracedUtils.push_val!(ad_inputs, x.val, path) | ||
| if !reverse | ||
| TracedUtils.push_val!(ad_inputs, x.dval, path) | ||
| end | ||
| end | ||
|
|
||
| function push_acts!( | ||
| ad_inputs, x::Union{BatchDuplicated,BatchDuplicatedNoNeed}, path, reverse | ||
| ) | ||
| TracedUtils.push_val!(ad_inputs, x.val, path) | ||
| if !reverse | ||
| TracedUtils.push_val!(ad_inputs, call_with_reactant(stack, x.dval), path) | ||
| end | ||
| end | ||
|
|
||
| function push_acts!( | ||
| ad_inputs, x::Union{StackedBatchDuplicated,StackedBatchDuplicatedNoNeed}, path, reverse | ||
| ) | ||
| TracedUtils.push_val!(ad_inputs, x.val, path) | ||
| if !reverse | ||
| TracedUtils.push_val!(ad_inputs, x.dval, path) | ||
| end | ||
| end | ||
|
|
||
| function set_act!(inp, path, reverse, tostore; emptypath=false, width=1) | ||
| x = if inp isa Active | ||
| inp.val | ||
| else | ||
| inp.dval | ||
| end | ||
|
|
||
| for p in path | ||
| x = traced_getfield(x, p) | ||
| end | ||
|
|
||
| if width == 1 | ||
| TracedUtils.set_mlir_data!(x, tostore) | ||
| elseif x isa AbstractArray | ||
| TracedUtils.set_mlir_data!(x, tostore) | ||
| else | ||
| tostore_traced = TracedRArray(tostore) | ||
| @assert length(x) == size(tostore_traced, ndims(tostore_traced)) | ||
| for (i, sl) in enumerate(eachslice(tostore_traced; dims=ndims(tostore_traced))) | ||
| TracedUtils.set_mlir_data!(x[i], TracedUtils.get_mlir_data(sl)) | ||
| end | ||
| end | ||
|
|
||
| emptypath && TracedUtils.set_paths!(x, ()) | ||
| return nothing | ||
| end | ||
|
|
||
| function act_attr(val) | ||
| val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet( | ||
| MLIR.IR.context()::MLIR.API.MlirContext, val::Int32 | ||
| )::MLIR.API.MlirAttribute | ||
| return MLIR.IR.Attribute(val) | ||
| end | ||
|
|
||
| function overload_autodiff( | ||
| ::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs} | ||
| ) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs} | ||
| reverse = CMode <: ReverseMode | ||
|
|
||
| width = Enzyme.same_or_one(1, args...) | ||
| if width == 0 | ||
| throw(ErrorException("Cannot differentiate with a batch size of 0")) | ||
| end | ||
|
|
||
| primf = f.val | ||
| primargs = ((v.val for v in args)...,) | ||
|
|
||
| argprefix::Symbol = gensym("autodiffarg") | ||
| resprefix::Symbol = gensym("autodiffresult") | ||
| resargprefix::Symbol = gensym("autodiffresarg") | ||
|
|
||
| mlir_fn_res = TracedUtils.make_mlir_fn( | ||
| primf, | ||
| primargs, | ||
| (), | ||
| string(f) * "_autodiff", | ||
| false; | ||
| argprefix, | ||
| resprefix, | ||
| resargprefix, | ||
| ) | ||
| (; result, linear_args, in_tys, linear_results) = mlir_fn_res | ||
| fnwrap = mlir_fn_res.fnwrapped | ||
| func2 = mlir_fn_res.f | ||
|
|
||
| activity = Int32[] | ||
| ad_inputs = MLIR.IR.Value[] | ||
|
|
||
| for a in linear_args | ||
| idx, path = TracedUtils.get_argidx(a, argprefix) | ||
| arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] | ||
| push!(activity, act_from_type(arg, reverse)) | ||
| push_acts!(ad_inputs, arg, path[3:end], reverse) | ||
| end | ||
|
|
||
| outtys = MLIR.IR.Type[] | ||
| ret_activity = Int32[] | ||
|
|
||
| for a in linear_results | ||
| if TracedUtils.has_idx(a, resprefix) | ||
| if EnzymeCore.needs_primal(CMode) | ||
| push!( | ||
| outtys, | ||
| TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))), | ||
| ) | ||
| end | ||
|
|
||
| if CMode <: ForwardMode && !(A <: Const) | ||
| push!( | ||
| outtys, | ||
| TracedUtils.batch_ty( | ||
| width, | ||
| TracedUtils.transpose_ty( | ||
| MLIR.IR.type(TracedUtils.get_mlir_data(a)) | ||
| ), | ||
| ), | ||
| ) | ||
| end | ||
|
|
||
| act = act_from_type(A, reverse, EnzymeCore.needs_primal(CMode)) | ||
| push!(ret_activity, act) | ||
| if act == enzyme_out || act == enzyme_outnoneed | ||
| if width == 1 | ||
| cst = @opcall fill(one(unwrapped_eltype(a)), size(a)) | ||
| else | ||
| cst = @opcall fill(one(unwrapped_eltype(a)), (size(a)..., width)) | ||
| end | ||
| push!(ad_inputs, cst.mlir_data) | ||
| end | ||
| else | ||
| if TracedUtils.has_idx(a, argprefix) | ||
| idx, path = TracedUtils.get_argidx(a, argprefix) | ||
| arg = idx == 1 && fnwrap ? f : args[idx - fnwrap] | ||
|
|
||
| act = act_from_type(arg, reverse, true) | ||
| push!(ret_activity, act) | ||
|
|
||
| if act == enzyme_out || act == enzyme_outnoneed | ||
| if width == 1 | ||
| TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) | ||
| elseif arg.dval isa AbstractArray | ||
| TracedUtils.push_val!(ad_inputs, arg.dval, path[3:end]) | ||
| else | ||
| TracedUtils.push_val!( | ||
| ad_inputs, call_with_reactant(stack, arg.dval), path[3:end] | ||
| ) | ||
| end | ||
| end | ||
| else | ||
| act = act_from_type(Const, reverse, true) | ||
| push!(ret_activity, act) | ||
| end | ||
|
|
||
| push!( | ||
| outtys, TracedUtils.transpose_ty(MLIR.IR.type(TracedUtils.get_mlir_data(a))) | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| for (i, act) in enumerate(activity) | ||
| if act == enzyme_out || act == enzyme_dup || act == enzyme_dupnoneed | ||
| push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) | ||
| end | ||
| end | ||
|
|
||
| fname = TracedUtils.get_attribute_by_name(func2, "sym_name") | ||
| fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) | ||
| res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( | ||
| [TracedUtils.transpose_val(v) for v in ad_inputs]; | ||
| outputs=outtys, | ||
| fn=fname, | ||
| width, | ||
| strong_zero=EnzymeCore.strong_zero(CMode), | ||
| activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), | ||
| ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]), | ||
| ) | ||
|
|
||
| residx = 1 | ||
|
|
||
| dresult = if CMode <: ForwardMode && !(A <: Const) | ||
| if width == 1 | ||
| deepcopy(result) | ||
| else | ||
| ntuple(Val(width)) do i | ||
| Base.@_inline_meta | ||
| deepcopy(result) | ||
| return deepcopy(result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert these changes. possible originated from juliaformatter v2?
e0a228f to
106375a
Compare
Co-authored-by: avik Co-authored-by: Avik Pal <[email protected]>
106375a to
232b01f
Compare
Co-authored-by: Avik Pal <[email protected]>
232b01f to
626e909
Compare
|
@jumerckx is this good to go? |
|
Yes! Will merge after CI has finished |
fixes #442
needs Enzyme.jl: EnzymeAD/Enzyme.jl#2254
I had to introduce a new function
call_with_reactant_within_autodiffto smuggle thewithin_autodiffin thecall_with_reactant_generatorthrough theselfargument.I also tried doing things through
set_reactant_abibut that didn't seem to suffice (first commit).Perhaps the extra code in
set_reactant_abiisn't strictly necessary now so I can try removing it again if wanted.