
Conversation

@mcabbott (Member):
This provides more complete rules for cat and friends. They should do better at preserving dimensions. I have lost my list of edge cases (I wrote this a while ago), but a few which currently go wrong are:

julia> Zygote.gradient(x -> sum(cat(x,x,dims=3)), rand(2,2))  # should be a matrix, trivial dim
(2×2×1 Fill{Float64}: entries equal to 2.0,)

julia> Zygote.gradient(x -> sum(hcat(x,x)),rand(2,1,2))  # should be a 3-array
(2×2 Fill{Float64}: entries equal to 2.0,)

julia> ChainRules.rrule(hcat, rand(2,1,2), rand(2,1,2))[2](ones(2,2,2))[2]
2×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0

They should also almost always be type-stable. This does not seem to lead to major performance changes:

julia> @btime gradient(x -> sum(hcat(x,x',x)), $(rand(10,10)));
  1.546 μs (29 allocations: 4.53 KiB)  # with Zygote's rules
  1.442 μs (25 allocations: 4.34 KiB)  # using this PR instead

julia> @btime gradient(xs -> sum(reduce(hcat, xs)), $([rand(10) for _ in 1:100]));
  2.125 μs (18 allocations: 10.11 KiB)
  2.125 μs (18 allocations: 10.11 KiB)  # this PR

julia> @btime gradient(x -> sum(vcat(x,3,x,4)), $(rand(10)));
  3.042 μs (62 allocations: 2.34 KiB)
  2.704 μs (38 allocations: 1.19 KiB)  # this PR

julia> @btime gradient(x -> sum(cat(x,3,x,4, dims=(1,2))), $(rand(10,10)));
  11.458 μs (111 allocations: 8.00 KiB)
  7.646 μs (82 allocations: 7.09 KiB)  # this PR

I know of two possible issues. One is that Zygote has special paths to allow the gradient of vcat(1, cu(rand(3))) to make a scalar, disabling CUDA's scalar indexing complaint.

The second is that vcat(fill(1), fill(2)) will have numbers as its gradient, not zero-dimensional arrays. That seems better than vectors (the present behaviour), and it is consistent with how broadcasting behaves:

julia> gradient(x -> sum(vcat(x,x)), fill(3.14))[1] isa AbstractVector
true

julia> ChainRules.rrule(vcat, fill(1), fill(2))[2](ones(2))[2]
1-element Vector{Float64}:
 1.0
 
julia> gradient(x -> x .+ 1, fill(1))
(1,)

d > ndimsX ? 1 : (:)
end
end
dY[ind...] # no thunk as Xs may have 1 arg but 1 thunk is disallowed,
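The index construction in this hunk can be sketched in isolation. Below is a minimal stand-alone version of the slicing idea (the names `hcat_pullback_sketch`, `pre`, and `ncols` are hypothetical, not the PR's actual code): along the concatenated dimension each argument gets a range (or a scalar, for vectors), its own dimensions get a `Colon`, and any trailing dimension it lacks is dropped by indexing with a literal `1`, so each gradient slice has the same `ndims` as its input.

```julia
# Hedged sketch of the slicing idea behind this hunk (hypothetical names;
# not the PR's actual implementation), specialised to hcat (dims = 2).
function hcat_pullback_sketch(dY, Xs...)
    grads = Any[]
    pre = 0  # running offset along dimension 2 of dY
    for X in Xs
        ncols = size(X, 2)  # 1 when X is a vector
        ind = ntuple(ndims(dY)) do d
            if d == 2
                # scalar index when X has no second dimension, so the
                # trivial dim is dropped; a range otherwise
                d > ndims(X) ? pre + 1 : (pre+1 : pre+ncols)
            elseif d > ndims(X)
                1        # drop trailing dimensions X does not have
            else
                Colon()  # keep X's own dimensions whole
            end
        end
        push!(grads, dY[ind...])
        pre += ncols
    end
    return grads
end

# hcat(rand(2), rand(2, 3)) has a 2×4 output; the slices recover the
# input shapes: a 2-vector and a 2×3 matrix.
g = hcat_pullback_sketch(ones(2, 4), rand(2), rand(2, 3))
size.(g)  # sizes match the inputs: (2,) and (2, 3)
```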
@mcabbott (Member Author), Jun 19, 2021:
Maybe it should make thunks? Perhaps only for arguments with ndimsX > 1? The complaint in the comment was fixed in JuliaDiff/ChainRulesTestUtils.jl#175.

@mcabbott (Member Author):
Trying that causes inference failures:

  Got exception outside of a @test
  return type Tuple{NoTangent,InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,Int64},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,Int64},Array{Float64,2}}},InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}}} does not match inferred return type Tuple{NoTangent,InplaceableThunk{_1,_2} where _2 where _1,InplaceableThunk{_1,_2} where _2 where _1,InplaceableThunk{_1,_2} where _2 where _1}

Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
sizes = map(size, Xs) # this avoids closing over Xs
function 🐈_pullback(dY)
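The capture strategy in the hunk above can be illustrated with a hedged stand-alone sketch (the names `capture_demo` and `demo_pullback` are hypothetical): the pullback closes over `sizes` (small `Int` tuples) and `Val(ndims(Y))` rather than over the arrays `Xs` and `Y` themselves, so the forward arrays can be freed, and `ntuple` gets a length that is a type parameter rather than a runtime `Int`, which keeps the closure inferable.

```julia
# Hedged sketch of the capture pattern (hypothetical names, not the PR's code).
function capture_demo(Xs...)
    Y = hcat(Xs...)
    ndimsY = Val(ndims(Y))  # Val lifts the tuple length into the type domain
    sizes = map(size, Xs)   # tuples of Int, cheap to keep alive in the closure
    function demo_pullback(dY)
        # ntuple with a Val length returns a tuple whose length is known to
        # the compiler, unlike ntuple(f, ndims(Y)) with a runtime Int
        ntuple(d -> size(dY, d), ndimsY)
    end
    return Y, sizes, demo_pullback
end

Y, sizes, pb = capture_demo(rand(2, 2), rand(2, 3))
pb(ones(size(Y)))  # (2, 5)
```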
(Member):
👍 solid improvement

(Member):
pls no 😂

Comment on lines 46 to 50
if ndimsX > 0
InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
else
dY[ind...]
end
(Member):
Aren't these two branches the same?

@mcabbott (Member Author):
I guess for vcat(1,2,3) this frees dY.

Not sure whether @thunk is close enough to zero cost for scalar arguments? Haven't tried to time that.

Also, for vcat(1, cu([2,3])) this ought to be getindex_allowing_scalar(dY, ind...). Haven't thought more about where that ought to live.

@mcabbott (Member Author):
Also, any idea why this should fail to infer? As here #451 (comment), it did work before the InplaceableThunk(@thunk(....

@mcabbott (Member Author), Jun 22, 2021:
CUDA attempts:

julia> v = cu(rand(3));

julia> CUDA.@allowscalar v[2]  # the official way
0.08611474f0

julia> @macroexpand CUDA.@allowscalar v[2]
quote
    GPUArrays.task_local_storage(:ScalarIndexing, GPUArrays.ScalarAllowed) do 
        v[2]
    end
end

julia> sum(view(v,2:2))  # another idea, not easy in N dims?
0.08611474f0

julia> sum(view(v,2))  # another idea
ERROR: MethodError: no method matching ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}, Nothing, typeof(identity), Tuple{CuArray{Float32, 0}}})
Closest candidates are:
...
julia> Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}}) = 0

# some GPU times:

julia> @btime CUDA.@sync sum(view($v,2))
  44.976 μs (67 allocations: 1.44 KiB)
0.08611474f0

julia> @btime CUDA.@sync CUDA.@allowscalar $v[2]
  16.437 μs (26 allocations: 496 bytes)
0.08611474f0

# and CPU

julia> @btime ($(rand(100)))[2];
  3.569 ns (0 allocations: 0 bytes)

julia> @btime sum(view($(rand(100)),2));
  2.988 ns (0 allocations: 0 bytes)

julia> @btime sum(view($(rand(100)),2:2));
  3.571 ns (0 allocations: 0 bytes)

(Member):
I know nothing about GPU; we can probably poke vchuravy for help with that, if we need to.

Re: inference: it might be worth just not using thunks (or InplaceableThunks) for now, and we can add them back in later.

@mcabbott (Member Author):
I made JuliaGPU/GPUArrays.jl#363 after asking around. Maybe Base should also specialise...

@mcabbott (Member Author):
Inference now only fails on 1.0. Somehow I thought that got auto-skipped?

@oxinabox (Member), Jun 23, 2021:
I think it gets auto-skipped on a case-by-case basis.
I am fine to skip it in general, though.
We could now set the global setting that @simeonschaub added to skip all of these on 1.0, but I think that would be a separate PR.

(Member):

It only gets skipped for PkgEval, our CI always runs these. I agree that it might make sense to always disable them on 1.0 though.

@mcabbott (Member Author):

OK, thanks. Disabled by hand for 1.0 now; hope I got the right ones!

@codecov-commenter commented Jun 22, 2021:

Codecov Report

Merging #451 (8060bd5) into master (52a0eea) will increase coverage by 0.05%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #451      +/-   ##
==========================================
+ Coverage   98.46%   98.51%   +0.05%     
==========================================
  Files          21       21              
  Lines        2024     2094      +70     
==========================================
+ Hits         1993     2063      +70     
  Misses         31       31              
Impacted Files Coverage Δ
src/rulesets/Base/array.jl 100.00% <100.00%> (ø)


@oxinabox oxinabox merged commit 830e97d into JuliaDiff:master Jun 24, 2021
@mcabbott mcabbott deleted the cat branch July 28, 2021 12:29