Closed
Labels
ProjectTo: related to the projection functionality
Structural Tangent: related to the `Tangent` type for structured (composite) values
inplace accumulation: for things relating to inplace accumulation of gradients
Description
Because `ProjectTo(::Tuple)` recurses into the tuple's elements, and `unthunk` does not, this gives an error:
julia> x = ([1,2], [3,4]);
julia> g = map(y -> rrule(sum, y)[2](1)[2], x) # same as Diffractor.gradient(x -> sum(sum, x), x)[1]
(InplaceableThunk(ChainRules.var"#1489#1492"{Int64, Colon}(1, Colon()), Thunk(ChainRules.var"#1490#1493"{Int64, Colon, Vector{Int64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1, Colon(), [1, 2], ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),))))), InplaceableThunk(ChainRules.var"#1489#1492"{Int64, Colon}(1, Colon()), Thunk(ChainRules.var"#1490#1493"{Int64, Colon, Vector{Int64}, ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}(1, Colon(), [3, 4], ProjectTo{AbstractArray}(element = ProjectTo{Float64}(), axes = (Base.OneTo(2),))))))
julia> ProjectTo(x)(unthunk(g))
ERROR: MethodError: no method matching (::ProjectTo{AbstractArray, ...}}}}})(::InplaceableThunk{Thunk{...})
julia> ProjectTo(x)(Tangent{typeof(x), typeof(g)}(g)) # same
ERROR: MethodError: no method matching (::ProjectTo{AbstractArray{...
I think we should make projection work with `InplaceableThunk`, either by unthunking, or else by inserting itself into the out-of-place version.
In this example, `g` above is constructed by hand because Zygote un-thunks, and Diffractor has a pirate rule to deal with this. But in the real case, my nested `g` is being constructed by an `rrule`, and the error occurs even with Zygote.
julia> @which ProjectTo(x[1])(g[1])
(project::ProjectTo{<:AbstractArray})(th::InplaceableThunk) in Diffractor at /Users/me/.julia/dev/Diffractor/src/extra_rules.jl:258
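For illustration, the "unthunk first" option can be done by hand today: unthunk each element of the tuple before projecting, so the inner array projectors never receive a thunk. This is only a sketch of a workaround, not a proposed fix. It assumes just ChainRulesCore; `make_thunk` is a hypothetical helper that builds stand-in thunks by hand (not the ones `rrule(sum, y)` returns), and its in-place closure is a placeholder.

```julia
using ChainRulesCore

x = ([1, 2], [3, 4])

# Hypothetical helper: a hand-built InplaceableThunk standing in for the
# cotangent that an rrule might return for each inner array.
make_thunk(v) = InplaceableThunk(dx -> dx .+= 1.0, @thunk(ones(length(v))))

g = map(make_thunk, x)        # a Tuple of InplaceableThunks, like `g` above

# `unthunk(g)` leaves the tuple untouched, so unthunk element by element.
g_plain = map(unthunk, g)     # a Tuple of plain Vector{Float64}

# Projection now only ever sees ordinary arrays inside the tuple.
dx = ProjectTo(x)(g_plain)
```

This gives up the in-place accumulation that `InplaceableThunk` exists for, which is why handling it inside the projector itself still seems preferable.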