Improved rules for cats
#451
Conversation
src/rulesets/Base/array.jl
Outdated
        d > ndimsX ? 1 : (:)
    end
end
dY[ind...] # no thunk as Xs may have 1 arg but 1 thunk is disallowed,
Maybe it should make thunks? Perhaps only for arguments with ndimsX>1? Complaint in comment was fixed in JuliaDiff/ChainRulesTestUtils.jl#175
Trying that causes inference failures:
Got exception outside of a @test
return type Tuple{NoTangent,InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,Int64},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,Int64},Array{Float64,2}}},InplaceableThunk{Thunk{getfield(ChainRules, Symbol("##1109#1114")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}},getfield(ChainRules, Symbol("##1110#1115")){Tuple{Colon,UnitRange{Int64}},Array{Float64,2}}}} does not match inferred return type Tuple{NoTangent,InplaceableThunk{_1,_2} where _2 where _1,InplaceableThunk{_1,_2} where _2 where _1,InplaceableThunk{_1,_2} where _2 where _1}
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
sizes = map(size, Xs) # this avoids closing over Xs
function 🐈_pullback(dY)
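The closure-avoidance pattern in the hunk above (capturing `Val(ndims(Y))` and the argument sizes rather than `Y` or `Xs` themselves) can be sketched in isolation. This is an illustrative toy, not the rule in this PR: it assumes every argument is an `AbstractMatrix`, the names are made up, and the captured-and-mutated `hi` means it is not tuned for inference the way the real rule is.

```julia
using ChainRulesCore  # for NoTangent

# Toy version of the pattern: capture only what the pullback needs.
function hcat_rrule_sketch(Xs::AbstractMatrix...)
    Y = hcat(Xs...)
    widths = map(X -> size(X, 2), Xs)  # column counts only; avoids closing over Xs
    function hcat_pullback_sketch(dY)
        hi = 0
        dXs = map(widths) do w          # one slice of dY per original argument
            lo = hi + 1
            hi += w
            dY[:, lo:hi]
        end
        return (NoTangent(), dXs...)
    end
    return Y, hcat_pullback_sketch
end
```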
👍 solid improvement
pls no 😂
if ndimsX > 0
    InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
else
    dY[ind...]
end
Aren't these two branches the same?
I guess for vcat(1,2,3) this frees dY.
Not sure whether @thunk is close enough to zero cost for scalar arguments? Haven't tried to time that.
Also, for vcat(1, cu([2,3])) this ought to be getindex_allowing_scalar(dY, ind...). Haven't thought more about where that ought to live.
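A hypothetical `getindex_allowing_scalar` (the name comes from the comment above; it is not an existing API) could be a plain `getindex` wrapper for ordinary arrays, with a GPU package adding a method that opts out of the scalar-indexing check:

```julia
# Hypothetical helper: ordinary arrays just index as usual.
getindex_allowing_scalar(A::AbstractArray, inds...) = A[inds...]

# A GPU package could then add something along these lines
# (sketch only; depends on GPUArrays/CUDA internals):
#   getindex_allowing_scalar(A::CuArray, inds...) = CUDA.@allowscalar A[inds...]
```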
Also, any idea why this should fail to infer? As here #451 (comment), it did work before the InplaceableThunk(@thunk(....
CUDA attempts:
julia> v = cu(rand(3));
julia> CUDA.@allowscalar v[2] # the official way
0.08611474f0
julia> @macroexpand CUDA.@allowscalar v[2]
quote
GPUArrays.task_local_storage(:ScalarIndexing, GPUArrays.ScalarAllowed) do
v[2]
end
end
julia> sum(view(v,2:2)) # another idea, not easy in N dims?
0.08611474f0
julia> sum(view(v,2)) # another idea
ERROR: MethodError: no method matching ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}, Nothing, typeof(identity), Tuple{CuArray{Float32, 0}}})
Closest candidates are:
...

julia> Base.ndims(::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{0}}) = 0
# some GPU times:
julia> @btime CUDA.@sync sum(view($v,2))
44.976 μs (67 allocations: 1.44 KiB)
0.08611474f0
julia> @btime CUDA.@sync CUDA.@allowscalar $v[2]
16.437 μs (26 allocations: 496 bytes)
0.08611474f0
# and CPU
julia> @btime ($(rand(100)))[2];
3.569 ns (0 allocations: 0 bytes)
julia> @btime sum(view($(rand(100)),2));
2.988 ns (0 allocations: 0 bytes)
julia> @btime sum(view($(rand(100)),2:2));
3.571 ns (0 allocations: 0 bytes)
I know nothing about GPU; we can probably poke vchuravy for help with that, if we need to.
Re: inference:
It might be worth just not using Thunks (or InplaceableThunks) for now,
and we can add them back in later
I made JuliaGPU/GPUArrays.jl#363 after asking around. Maybe Base should also specialise...
Inference now only fails on 1.0. Somehow I thought that got auto-skipped?
I think it gets auto-skipped on a case-by-case basis.
I am fine to skip it in general though.
We could set the global setting that @simeonschaub added, so this skips everything on 1.0.
But I think that would be a separate PR.
It only gets skipped for PkgEval, our CI always runs these. I agree that it might make sense to always disable them on 1.0 though.
OK, thanks. Disabled by hand now, for 1.0, hope I got the right ones!
Codecov Report
@@            Coverage Diff             @@
##           master     #451      +/-   ##
==========================================
+ Coverage   98.46%   98.51%   +0.05%
==========================================
  Files          21       21
  Lines        2024     2094      +70
==========================================
+ Hits         1993     2063      +70
  Misses         31       31
Continue to review full report at Codecov.
This provides more complete rules for cat and friends. They should do better at preserving dimensions; I have lost my list of edge cases (as I wrote this a while ago), though a few currently go wrong. They should also almost always be type-stable. This does not seem to lead to major performance changes.

I know of two possible issues. One is that Zygote has special paths to allow the gradient of vcat(1, cu(rand(3))) to make a scalar, disabling CUDA's scalar indexing complaint.

The second is that vcat(fill(1), fill(2)) will have numbers as its gradient, not zero-arrays. That seems better than vectors (the present behaviour), and consistent with how broadcasting behaves:
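For reference, here is how zero-dimensional arrays behave in the forward pass (checked by hand on a recent Julia; this only illustrates the convention the scalar gradients are meant to match, it is not a gradient test):

```julia
x = fill(1)              # 0-dimensional Array{Int,0}
x .+ 1                   # broadcasting keeps the 0-dim container: fill(2)
vcat(fill(1), fill(2))   # the forward pass drops the 0-dim shape, giving the Vector [1, 2]
sum(x)                   # reductions already produce a plain number: 1
```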