-
Notifications
You must be signed in to change notification settings - Fork 95
Improved rules for cats
#451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
b25c00a
4a5a0e5
bddd0d9
347bb01
219aa90
44627a1
c72a0e0
b3b2a26
c660230
8060bd5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,66 +24,175 @@ end | |
| ##### `hcat` (🐈) | ||
| ##### | ||
|
|
||
| function rrule(::typeof(hcat), A::AbstractArray, Bs::AbstractArray...) | ||
| function hcat_pullback(Ȳ) | ||
| Xs = (A, Bs...) | ||
| ntuple(length(Bs) + 2) do full_i | ||
| full_i == 1 && return NoTangent() | ||
|
|
||
| i = full_i - 1 | ||
| l = mapreduce(j->size(Xs[j], 2), Base.add_sum, 1:i-1; init=0) | ||
| u = l + size(Xs[i], 2) | ||
| dim = u > l + 1 ? (l+1:u) : u | ||
| # NOTE: The copy here is defensive, since `selectdim` returns a view which we can | ||
| # materialize with `copy` | ||
| copy(selectdim(Ȳ, 2, dim)) | ||
| using Compat # get here needs Compat 3.31 | ||
|
|
||
| function rrule(::typeof(hcat), Xs...) | ||
mcabbott marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray | ||
| ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability | ||
| sizes = map(size, Xs) # this avoids closing over Xs | ||
| function 🐈_pullback(dY) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 solid improvement
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pls no 😂 |
||
| hi = Ref(0) # Ref avoids hi::Core.Box | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| dXs = map(sizes) do sizeX | ||
| ndimsX = length(sizeX) | ||
| lo = hi[] + 1 | ||
| hi[] += get(sizeX, 2, 1) | ||
| ind = ntuple(ndimsY) do d | ||
| if d==2 | ||
| d > ndimsX ? lo : lo:hi[] | ||
| else | ||
| d > ndimsX ? 1 : (:) | ||
| end | ||
| end | ||
| if ndimsX > 0 | ||
| InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...)) | ||
| else | ||
| dY[ind...] | ||
| end | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't these two branches the same?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess for Not sure whether Also, for
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, any idea why this should fail to infer? As here #451 (comment), it did work before the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CUDA attempts: julia> v = cu(rand(3));
julia> CUDA.@allowscalar v[2] # the official way
0.08611474f0
julia> @macroexpand CUDA.@allowscalar v[2]
quote
GPUArrays.task_local_storage(:ScalarIndexing, GPUArrays.ScalarAllowed) do
v[2]
end
end
julia> sum(view(v,2:2)) # another idea, not easy in N dims?
0.08611474f0
julia> sum(view(v,2)) # another idea
ERROR: MethodError: no method matching ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}, Nothing, typeof(identity), Tuple{CuArray{Float32, 0}}})
Closest candidates are:
...
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know nothing about GPU, we can probably poke vchuravy, if we need to for help with that. Re: inference:
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made JuliaGPU/GPUArrays.jl#363 after asking around. Maybe Base should also specialise...
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inference now only fails on 1.0. Somehow I thought that got auto-skipped?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it gets auto-skipped on a case by case basis.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It only gets skipped for PkgEval, our CI always runs these. I agree that it might make sense to always disable them on 1.0 though.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, thanks. Disabled by hand now, for 1.0, hope I got the right ones! |
||
| end | ||
| return (NoTangent(), dXs...) | ||
| end | ||
| return hcat(A, Bs...), hcat_pullback | ||
| return Y, 🐈_pullback | ||
| end | ||
|
|
||
| function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) | ||
| function reduce_hcat_pullback(ΔY) | ||
| sizes = size.(As, 2) | ||
| cumsizes = cumsum(sizes) | ||
| ∂As = map(cumsizes, sizes) do post, diff | ||
| pre = post - diff + 1 | ||
| return ΔY[:, pre:post] | ||
| widths = map(A -> size(A,2), As) | ||
| function reduce_hcat_pullback_2(dY) | ||
| hi = Ref(0) | ||
| dAs = map(widths) do w | ||
| lo = hi[]+1 | ||
| hi[] += w | ||
| dY[:, lo:hi[]] | ||
| end | ||
| return (NoTangent(), NoTangent(), ∂As) | ||
| return (NoTangent(), NoTangent(), dAs) | ||
| end | ||
| return reduce(hcat, As), reduce_hcat_pullback | ||
| return reduce(hcat, As), reduce_hcat_pullback_2 | ||
| end | ||
|
|
||
| function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVector}) | ||
| axe = axes(As,1) | ||
| function reduce_hcat_pullback_1(dY) | ||
| hi = Ref(0) | ||
| dAs = map(_ -> dY[:, hi[]+=1], axe) | ||
| return (NoTangent(), NoTangent(), dAs) | ||
| end | ||
| return reduce(hcat, As), reduce_hcat_pullback_1 | ||
| end | ||
|
|
||
| ##### | ||
| ##### `vcat` | ||
| ##### | ||
|
|
||
| function rrule(::typeof(vcat), A::AbstractArray, Bs::AbstractArray...) | ||
| function vcat_pullback(Ȳ) | ||
| n = size(A, 1) | ||
| ∂A = copy(selectdim(Ȳ, 1, 1:n)) | ||
| ∂Bs = ntuple(length(Bs)) do i | ||
| l = n + mapreduce(j->size(Bs[j], 1), Base.add_sum, 1:i-1; init=0) | ||
| u = l + size(Bs[i], 1) | ||
| copy(selectdim(Ȳ, 1, l+1:u)) | ||
| function rrule(::typeof(vcat), Xs...) | ||
| Y = vcat(Xs...) | ||
| ndimsY = Val(ndims(Y)) | ||
| sizes = map(size, Xs) | ||
| function vcat_pullback(dY) | ||
| hi = Ref(0) | ||
| dXs = map(sizes) do sizeX | ||
| ndimsX = length(sizeX) | ||
| lo = hi[] + 1 | ||
| hi[] += get(sizeX, 1, 1) | ||
| ind = ntuple(ndimsY) do d | ||
| if d==1 | ||
| d > ndimsX ? lo : lo:hi[] | ||
| else | ||
| d > ndimsX ? 1 : (:) | ||
| end | ||
| end | ||
| if ndimsX > 0 | ||
| InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...)) | ||
| else | ||
| dY[ind...] | ||
| end | ||
mcabbott marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| end | ||
| return (NoTangent(), ∂A, ∂Bs...) | ||
| return (NoTangent(), dXs...) | ||
| end | ||
| return vcat(A, Bs...), vcat_pullback | ||
| return Y, vcat_pullback | ||
| end | ||
|
|
||
| function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) | ||
| function reduce_vcat_pullback(ΔY) | ||
| sizes = size.(As, 1) | ||
| cumsizes = cumsum(sizes) | ||
| ∂As = map(cumsizes, sizes) do post, diff | ||
| pre = post - diff + 1 | ||
| return ΔY[pre:post, :] | ||
| Y = reduce(vcat, As) | ||
| ndimsY = Val(ndims(Y)) | ||
| heights = map(A -> size(A,1), As) | ||
| function reduce_vcat_pullback(dY) | ||
| hi = Ref(0) | ||
| dAs = map(heights) do z | ||
| lo = hi[]+1 | ||
| hi[] += z | ||
| ind = ntuple(d -> d==1 ? (lo:hi[]) : (:), ndimsY) | ||
| dY[ind...] | ||
| end | ||
| return (NoTangent(), NoTangent(), dAs) | ||
| end | ||
| return Y, reduce_vcat_pullback | ||
| end | ||
|
|
||
| ##### | ||
| ##### `cat` | ||
| ##### | ||
|
|
||
| _val(::Val{x}) where {x} = x | ||
|
|
||
| function rrule(::typeof(cat), Xs...; dims) | ||
| Y = cat(Xs...; dims=dims) | ||
| cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims) | ||
| ndimsY = Val(ndims(Y)) | ||
| sizes = map(size, Xs) | ||
| function cat_pullback(dY) | ||
| prev = fill(0, _val(ndimsY)) # note that Y always has 1-based indexing, even if X isa OffsetArray | ||
| dXs = map(sizes) do sizeX | ||
| ndimsX = length(sizeX) | ||
| index = ntuple(ndimsY) do d | ||
| if d in cdims | ||
| d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d]) | ||
| else | ||
| d > ndimsX ? 1 : (:) | ||
| end | ||
| end | ||
| for d in cdims | ||
| prev[d] += get(sizeX, d, 1) | ||
| end | ||
| if ndimsX > 0 | ||
| InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...)) | ||
| else | ||
| dY[index...] | ||
| end | ||
| end | ||
| return (NoTangent(), dXs...) | ||
| end | ||
| return Y, cat_pullback | ||
| end | ||
|
|
||
| ##### | ||
| ##### `hvcat` | ||
| ##### | ||
|
|
||
| function rrule(::typeof(hvcat), rows, values...) | ||
| Y = hvcat(rows, values...) | ||
| cols = size(Y,2) | ||
| ndimsY = Val(ndims(Y)) | ||
| sizes = map(size, values) | ||
| function hvcat_pullback(dY) | ||
| prev = fill(0, 2) | ||
| dXs = map(sizes) do sizeX | ||
| ndimsX = length(sizeX) | ||
| index = ntuple(ndimsY) do d | ||
| if d in (1, 2) | ||
| d > ndimsX ? (prev[d]+1) : (prev[d]+1:prev[d]+sizeX[d]) | ||
| else | ||
| d > ndimsX ? 1 : (:) | ||
| end | ||
| end | ||
| prev[2] += get(sizeX, 2, 1) | ||
| if prev[2] == cols | ||
| prev[2] = 0 | ||
| prev[1] += get(sizeX, 1, 1) | ||
| end | ||
| dY[index...] | ||
| end | ||
| return (NoTangent(), NoTangent(), ∂As) | ||
| return (NoTangent(), NoTangent(), dXs...) | ||
| end | ||
| return reduce(vcat, As), reduce_vcat_pullback | ||
| return Y, hvcat_pullback | ||
| end | ||
|
|
||
| ##### | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.