Previously we were using a CuArray type that could represent a view, reshape, reinterpret, etc. For the sake of simplicity, I switched to a simpler CuArray type while reusing Base.SubArray, Base.ReshapedArray, Base.ReinterpretArray, etc. That requires type unions to represent, e.g., all dense or strided CuArrays:
Lines 146 to 164 in 75f7d30:

```julia
ContiguousSubCuArray{T,N,A<:CuArray} = Base.FastContiguousSubArray{T,N,A}

# dense arrays: stored contiguously in memory
DenseReinterpretCuArray{T,N,A<:Union{CuArray,ContiguousSubCuArray}} = Base.ReinterpretArray{T,N,S,A} where S
DenseReshapedCuArray{T,N,A<:Union{CuArray,ContiguousSubCuArray,DenseReinterpretCuArray}} = Base.ReshapedArray{T,N,A}
DenseSubCuArray{T,N,A<:Union{CuArray,DenseReshapedCuArray,DenseReinterpretCuArray}} = Base.FastContiguousSubArray{T,N,A}
DenseCuArray{T,N} = Union{CuArray{T,N}, DenseSubCuArray{T,N}, DenseReshapedCuArray{T,N}, DenseReinterpretCuArray{T,N}}
DenseCuVector{T} = DenseCuArray{T,1}
DenseCuMatrix{T} = DenseCuArray{T,2}
DenseCuVecOrMat{T} = Union{DenseCuVector{T}, DenseCuMatrix{T}}

# strided arrays
StridedSubCuArray{T,N,A<:Union{CuArray,DenseReshapedCuArray,DenseReinterpretCuArray},
                  I<:Tuple{Vararg{Union{Base.RangeIndex, Base.ReshapedUnitRange,
                                        Base.AbstractCartesianIndex}}}} = SubArray{T,N,A,I}
StridedCuArray{T,N} = Union{CuArray{T,N}, StridedSubCuArray{T,N}, DenseReshapedCuArray{T,N}, DenseReinterpretCuArray{T,N}}
StridedCuVector{T} = StridedCuArray{T,1}
StridedCuMatrix{T} = StridedCuArray{T,2}
StridedCuVecOrMat{T} = Union{StridedCuVector{T}, StridedCuMatrix{T}}
```
These definitions are almost identical to how Base defines StridedArray. However, using them significantly regresses load time. For example, #450 adds them to a bunch of LinearAlgebra.mul! methods, which regresses the time of `using CUDA` by about 25%: https://speed.juliagpu.org/timeline/#/?exe=4&ben=latency/import&env=1&revs=50&base=3+96&equid=off&quarts=on&extr=on
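For context, the kind of method #450 adds looks roughly like this (an illustrative sketch, not the actual diff; it assumes the low-level CUBLAS.gemm! wrapper is also extended to accept strided inputs):

```julia
using LinearAlgebra, CUDA
using CUDA.CUBLAS

# Sketch of the pattern from #450: dispatch mul! on the StridedCuMatrix union so
# that strided views/reshapes of a CuArray also hit the CUBLAS path, instead of
# falling back to the generic (scalar-indexing) matmul.
function LinearAlgebra.mul!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T},
                            B::StridedCuMatrix{T}) where {T<:Union{Float32,Float64}}
    return CUBLAS.gemm!('N', 'N', one(T), A, B, zero(T), C)
end
```

It is these signatures over the union that show up in the import latency numbers above.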
In a similar vein, Adapt.jl defines a union that captures all array wrappers that can be used on the GPU (i.e. not necessarily dense or strided, e.g. an Adjoint or a PermutedDimsArray): https://github.com/JuliaGPU/Adapt.jl/blob/11d96a531cb70359e88ed2ad0d0a13a85727a204/src/wrappers.jl#L73-L92
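For reference, the way that union gets instantiated for CuArrays looks roughly like this (a sketch of the pattern; the exact aliases in CUDA.jl may differ):

```julia
using CUDA, Adapt

# Sketch: AnyCuArray covers a plain CuArray plus every Base wrapper from
# Adapt.WrappedArray whose underlying storage is a CuArray.
const AnyCuArray{T,N} = Union{CuArray{T,N},
                              Adapt.WrappedArray{T,N,CuArray,CuArray{T,N}}}
const AnyCuVector{T} = AnyCuArray{T,1}
const AnyCuMatrix{T} = AnyCuArray{T,2}
```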
Using these unions makes load time explode: e.g., with mul!(::CuArray, ::AnyCuArray...) (where AnyCuArray uses the Adapt.WrappedArray union), load time goes from 5s to 25s.
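To make that concrete, the problematic kind of signature looks something like the following (a hypothetical standalone helper, not the actual method; the point is the signature, which forces inference to consider every wrapper in the WrappedArray union):

```julia
using LinearAlgebra, CUDA

# Hypothetical helper (name made up) illustrating a method over the AnyCuArray
# union: materialize the wrappers into dense buffers on the GPU via broadcast,
# then call the dense mul! path.
function gpu_matmul!(C::CuMatrix, A::AnyCuMatrix, B::AnyCuMatrix)
    A_dense = CuMatrix{eltype(A)}(undef, size(A)); A_dense .= A
    B_dense = CuMatrix{eltype(B)}(undef, size(B)); B_dense .= B
    return LinearAlgebra.mul!(C, A_dense, B_dense)
end
```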
I can understand how the large union from Adapt.jl is needlessly taxing on inference, and I guess we may need something like an AbstractWrappedArray here (JuliaLang/julia#31563). However, with StridedCuArray I had not expected these regressions, since Base uses similar patterns for StridedArray. Am I doing anything especially bad here? I'd like to start using StridedCuArray much more, in order to cover APIs that take strided inputs (of which there are quite a few).
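For reference, the kind of coverage I have in mind is letting stride-aware routines accept views and reshapes of a CuArray, since strides are well defined for them just like for Base's StridedArray (hypothetical example, function name made up):

```julia
using CUDA

# Hypothetical illustration: dispatch on StridedCuVecOrMat instead of only on
# CuArray, so non-contiguous but strided views are accepted as well.
function leading_stride(A::StridedCuVecOrMat)
    return stride(A, 1)   # element stride along the first dimension
end

A = CUDA.rand(Float32, 10, 10)
leading_stride(view(A, 1:2:9, :))  # a strided, non-contiguous view dispatches fine
```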