|
1 | | -# batch-wise matrix multiplication |
2 | | -# wrapper for batched_gemm! |
| 1 | + |
3 | 2 | export batched_mul, batched_transpose, batched_adjoint |
4 | 3 |
|
5 | 4 | include("./batchedadjtrans.jl") |
6 | 5 |
|
| 6 | +using LinearAlgebra: BlasFloat, Transpose, Adjoint |
| 7 | + |
| 8 | +_unbatch(A) = A |
| 9 | +_unbatch(A::BatchedAdjOrTrans) = parent(A) |
| 10 | + |
7 | 11 | """ |
8 | 12 | batched_mul(A, B) -> C |
9 | 13 |
|
10 | 14 | Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`. |
| 15 | +If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`. |
| 16 | +
|
| 17 | +To transpose each matrix apply `batched_transpose` to the array, |
| 18 | +and similarly `batched_adjoint`. Other permutations are also handled by BLAS, |
| 19 | +provided that the batch index `k` is not the first dimension of the underlying array. |
| 20 | +Thus `PermutedDimsArray(::Array, (1,3,2))` and `PermutedDimsArray(::Array, (3,1,2))` are fine. |
| 21 | +
|
| 22 | +However `A = PermutedDimsArray(::Array, (3,2,1))` is not acceptable to BLAS, |
| 23 | +since `stride(A,3) == 1`. This will be copied, as doing so is faster than `batched_mul_generic!`. |
| 24 | +
|
| 25 | +Both this `copy` and `batched_mul_generic!` produce `@debug` messages, |
| 26 | +and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them. |
11 | 27 | """ |
12 | 28 | function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} |
13 | | - axes(A, 3) == axes(B, 3) || throw(DimensionMismatch("batch size mismatch")) |
14 | | - T = promote_type(T1, T2) |
15 | | - C = similar(A, T, (axes(A, 1), axes(B, 2), axes(A, 3))) |
| 29 | + size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 || |
| 30 | + throw(DimensionMismatch("batch size mismatch: A != B")) |
| 31 | + _batched_mul(storage_typejoin(A, B), A, B) |
| 32 | +end |
| 33 | + |
| 34 | +function _batched_mul(::Type, A, B) |
| 35 | + T = promote_type(eltype(A), eltype(B)) |
| 36 | + C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)))) |
16 | 37 | batched_mul!(C, A, B) |
| 38 | + C |
| 39 | +end |
| 40 | +function _batched_mul(::Type{<:DenseArray{T}}, A, B) where {T<:BlasFloat} |
| 41 | + C = similar(A, T, (size(A, 1), size(B, 2), max(size(A, 3), size(B, 3)))) |
| 42 | + batched_mul!(C, _copy_if_faster(A), _copy_if_faster(B)) |
| 43 | + C |
| 44 | +end |
| 45 | + |
| 46 | +function _copy_if_faster(X::AbstractArray{<:Number, 3}) |
| 47 | + is_strided(X) || return X |
| 48 | + if Base.stride(X, 3) == 1 && Base.stride(X, 1) != 1 |
| 49 | + @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(X) |
| 50 | + return copy(X) |
| 51 | + end |
| 52 | + X |
| 53 | +end |
| 54 | +function _copy_if_faster(X::BatchedAdjoint{<:Complex}) |
| 55 | + Xbase = _unbatch(X) |
| 56 | + is_strided(Xbase) || return X |
| 57 | + if Base.stride(Xbase, 1) != 1 |
| 58 | + @debug "copying to avoid batched_mul_generic!" typeof(X) size(X) strides(_unbatch(X)) |
| 59 | + return copy(X) # or batched_adjoint(copy(Xbase)), may be better on GPU? |
| 60 | + end |
| 61 | + X |
17 | 62 | end |
18 | 63 |
|
19 | 64 | """ |
20 | 65 | batched_mul!(C, A, B) -> C |
| 66 | + batched_mul!(C, A, B, α=1, β=0) |
| 67 | +
|
| 68 | +In-place batched matrix multiplication, equivalent to |
| 69 | +`mul!(C[:,:,k], A[:,:,k], B[:,:,k], α, β)` for all `k`. |
| 70 | +If `size(B,3) == 1` then every batch uses `B[:,:,1]` instead. |
21 | 71 |
|
22 | | -In-place batched matrix multiplication, |
23 | | -equivalent to `mul!(C[:,:,k], A[:,:,k], B[:,:,k])` for all `k`. |
| 72 | +This will call `batched_gemm!` whenever possible. For real arrays this means that, |
| 73 | +for `X ∈ [A,B,C]`, either `stride(X,1)==1` or `stride(X,2)==1`, the latter may |
| 74 | +be caused by `batched_transpose` or by for instance `PermutedDimsArray(::Array, (3,1,2))`. |
| 75 | +Unlike `batched_mul` this will never make a copy. |
| 76 | +
|
| 77 | +For complex arrays, the wrapper made by `batched_adjoint` must be outermost to be seen. |
| 78 | +In this case the strides accepted by BLAS are more restricted, if `stride(C,1)==1` then |
| 79 | +only `stride(AorB::BatchedAdjoint,2) == 1` is accepted. |
24 | 80 | """ |
25 | | -function batched_mul! end |
| 81 | +function batched_mul!(C::AbstractArray{T,3}, A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3}, |
| 82 | + α::Number=one(T), β::Number=zero(T)) where {T} |
| 83 | + _batched_mul!(storage_typejoin(C,A,B), C, A, B, α, β) |
| 84 | + C |
| 85 | +end |
26 | 86 |
|
27 | | -_unbatch(A) = A |
28 | | -_unbatch(A::BatchedAdjOrTrans) = A.parent |
| 87 | +_batched_mul!(::Type, C, A, B, α::Number, β::Number) = batched_mul_generic!(C, A, B, α, β) |
29 | 88 |
|
30 | | -# batched_gemm! |
| 89 | +_batched_mul!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} = |
| 90 | + _batched_try_gemm!(DT, C, A, B, α, β) |
31 | 91 |
|
32 | | -const _GemmFloat = Union{Float64, Float32, ComplexF64, ComplexF32} |
| 92 | +function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {DT<:DenseArray{T}} where {T<:BlasFloat} |
33 | 93 |
|
34 | | -_BATCHED_GEMM_LIST = [ |
35 | | - (:(StridedArray{T, 3}), 'N'), |
36 | | - (:(BatchedTranspose{T, <:StridedArray{T, 3}}), 'T'), |
37 | | - (:(BatchedAdjoint{T, <:StridedArray{T, 3}}), 'C') |
38 | | -] |
| 94 | + alpha, beta = promote(α, β, zero(T)) |
| 95 | + alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β) |
39 | 96 |
|
40 | | -for (TA, transA) in _BATCHED_GEMM_LIST, (TB, transB) in _BATCHED_GEMM_LIST |
41 | | - @eval function batched_mul!(C::StridedArray{T, 3}, A::$TA, B::$TB) where {T<:_GemmFloat} |
42 | | - batched_gemm!($transA, $transB, one(T), _unbatch(A), _unbatch(B), zero(T), C) |
43 | | - C |
| 97 | + are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β) |
| 98 | + |
| 99 | + if Base.stride(C,1) == 1 |
| 100 | + elseif Base.stride(C,2) == 1 |
| 101 | + @debug "transforming C = A * B into C' = B' * A'" size(C) strides(C) |
| 102 | + return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β) |
| 103 | + else |
| 104 | + return batched_mul_generic!(C, A, B, α, β) |
| 105 | + end |
| 106 | + |
| 107 | + blasA, transA = if A isa BatchedAdjoint && T <: Complex |
| 108 | + Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β) |
| 109 | + parent(A), 'C' |
| 110 | + elseif Base.stride(A,1) == 1 |
| 111 | + A, 'N' |
| 112 | + elseif Base.stride(A,2) == 1 |
| 113 | + batched_transpose(A), 'T' |
| 114 | + else |
| 115 | + return batched_mul_generic!(C, A, B, α, β) |
44 | 116 | end |
| 117 | + |
| 118 | + blasB, transB = if B isa BatchedAdjoint && T <: Complex |
| 119 | + Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β) |
| 120 | + parent(B), 'C' |
| 121 | + elseif Base.stride(B,1) == 1 |
| 122 | + B, 'N' |
| 123 | + elseif Base.stride(B,2) == 1 |
| 124 | + batched_transpose(B), 'T' |
| 125 | + else |
| 126 | + return batched_mul_generic!(C, A, B, α, β) |
| 127 | + end |
| 128 | + |
| 129 | + _batched_gemm!(DT, transA, transB, alpha, blasA, blasB, beta, C) |
| 130 | + C |
45 | 131 | end |
46 | 132 |
|
47 | | -# fallback |
| 133 | +_batched_gemm!(::Type{<:Array}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = |
| 134 | + batched_gemm!(transA, transB, α, A, B, β, C) |
48 | 135 |
|
49 | 136 | _BATCHED_LIST = [ |
50 | 137 | (:(AbstractArray{<:Any, 3}), :identity), |
51 | | - (:(BatchedTranspose{<:Any, <:AbstractArray{<:Any, 3}}), :transpose), |
52 | | - (:(BatchedAdjoint{<:Any, <:AbstractArray{<:Any, 3}}), :adjoint) |
| 138 | + (:BatchedTranspose, :transpose), |
| 139 | + (:BatchedAdjoint, :adjoint), |
53 | 140 | ] |
54 | 141 | for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST |
55 | | - @eval function batched_mul!(C::AbstractArray{<:Any, 3}, A::$TA, B::$TB) |
56 | | - axes(A, 3) == axes(B, 3) == axes(C, 3) || throw(DimensionMismatch("batch size mismatch")) |
| 142 | + |
| 143 | + @eval function batched_mul_generic!(C::AbstractArray{T, 3}, A::$TA, B::$TB, |
| 144 | + α::Number=one(T), β::Number=zero(T)) where {T} |
| 145 | + |
| 146 | + size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C")) |
| 147 | + size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C")) |
57 | 148 | @debug "calling fallback method for batched_mul!" typeof(A) typeof(B) typeof(C) |
58 | | - A′, B′ = _unbatch(A), _unbatch(B) |
59 | | - @inbounds for k in axes(C, 3) |
60 | | - @views mul!(C[:,:,k], $fA(A′[:,:,k]), $fB(B′[:,:,k])) |
| 149 | + |
| 150 | + Abase, Bbase = _unbatch(A), _unbatch(B) |
| 151 | + sA, oA = size(A,3) == 1 ? (0,1) : (1,0) |
| 152 | + sB, oB = size(B,3) == 1 ? (0,1) : (1,0) |
| 153 | + |
| 154 | + @inbounds for k in 1:size(C,3) |
| 155 | + @views mul!(C[:,:,k], $fA(Abase[:,:,k*sA+oA]), $fB(Bbase[:,:,k*sB+oB]), α, β) |
61 | 156 | end |
62 | 157 | C |
63 | 158 | end |
| 159 | + |
| 160 | +end |
| 161 | + |
| 162 | +""" |
| 163 | + storage_type(A) -> Type |
| 164 | +
|
| 165 | +Removes all wrappers to return the `Array` or `CuArray` (or whatever) type within. |
| 166 | +``` |
| 167 | +julia> view(reshape(ones(10)',2,5),:, 3:4) |> storage_type |
| 168 | +Array{Float64,1} |
| 169 | +
|
| 170 | +julia> reshape(sparse(rand(10)), 5,2) |> storage_type |
| 171 | +SparseVector{Float64,Int64} |
| 172 | +``` |
| 173 | +""" |
| 174 | +function storage_type(A::AbstractArray) |
| 175 | + P = parent(A) |
| 176 | + typeof(A) === typeof(P) ? typeof(A) : storage_type(P) |
64 | 177 | end |
| 178 | +storage_type(A) = typeof(A) |
| 179 | + |
| 180 | +""" |
| 181 | + storage_typejoin(A, B, C, ...) -> Type |
| 182 | +
|
| 183 | +Reduces with `Base.promote_typejoin`, in order that this conveys useful information |
| 184 | +for dispatching to BLAS. It does not tell you what container to allocate: |
| 185 | +``` |
| 186 | +julia> storage_typejoin(rand(2), rand(Float32, 2)) |
| 187 | +Array{T,1} where T |
| 188 | +
|
| 189 | +julia> eltype(ans) <: LinearAlgebra.BlasFloat |
| 190 | +false |
| 191 | +
|
| 192 | +julia> storage_typejoin(rand(2), rand(2,3), rand(2,3,4)) |
| 193 | +Array{Float64,N} where N |
| 194 | +``` |
| 195 | +""" |
| 196 | +storage_typejoin(A, Bs...) = Base.promote_typejoin(storage_type(A), storage_typejoin(Bs...)) |
| 197 | +storage_typejoin(A) = storage_type(A) |
| 198 | + |
| 199 | +""" |
| 200 | + is_strided(A::AbstractArray) -> Bool |
| 201 | +
|
| 202 | +This generalises `A isa StridedArray` to treat wrappers like `A::PermutedDimsArray`, |
| 203 | +for which it returns `is_strided(parent(A))`. |
| 204 | +
|
| 205 | +Other wrappers (defined outside Base, LinearAlgebra) are assumed not to break |
| 206 | +strided-ness, and hence also return `is_strided(parent(A))`. |
| 207 | +This correctly handles things like `NamedDimsArray` which don't alter indexing. |
| 208 | +However, it's a little pessimistic in that e.g. a `view` of such a container will return |
| 209 | +`false`, even in cases where the same `view` of `parent(A)` would be a `StridedArray`. |
| 210 | +""" |
| 211 | +is_strided(A::StridedArray) = true |
| 212 | +is_strided(A) = false |
| 213 | +function is_strided(A::AbstractArray) |
| 214 | + M = parentmodule(typeof(A)) |
| 215 | + if parent(A) === A # SparseMatrix, StaticArray, etc |
| 216 | + false |
| 217 | + elseif M === Base || M === Core || M ===LinearAlgebra |
| 218 | + # bad reshapes, etc, plus Diagonal, UpperTriangular, etc. |
| 219 | + false |
| 220 | + else |
| 221 | + is_strided(parent(A)) # PermutedDimsArray, NamedDimsArray |
| 222 | + end |
| 223 | +end |
| 224 | + |
| 225 | +is_strided(A::BatchedAdjoint) = eltype(A) <: Real && is_strided(parent(A)) |
| 226 | +is_strided(A::BatchedTranspose) = is_strided(parent(A)) |
| 227 | + |
| 228 | +is_strided(A::LinearAlgebra.Transpose) = is_strided(parent(A)) |
| 229 | +is_strided(A::LinearAlgebra.Adjoint) = eltype(A) <: Real && is_strided(parent(A)) |
| 230 | + |
| 231 | +are_strided(As...) = mapfoldl(is_strided, &, As; init=true) |
0 commit comments