From 03744a62e2ceccedb2776f1b23e3063ae21ada92 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 24 Oct 2020 15:20:06 +0200 Subject: [PATCH 01/17] allow PermutedDimsArray in gemm_strided_batched copied from https://github.com/JuliaGPU/CuArrays.jl/pull/664, needs https://github.com/FluxML/NNlib.jl/pull/191 --- Project.toml | 2 ++ lib/cublas/wrappers.jl | 25 +++++++++++++------------ src/nnlib.jl | 15 +++------------ test/nnlib.jl | 19 +++++++++++++++++++ 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index 132993f12a..66d2c5beb0 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -31,6 +32,7 @@ AbstractFFTs = "0.4, 0.5" Adapt = "2.2" BFloat16s = "0.1" CEnum = "0.2, 0.3, 0.4" +Compat = "3.9" DataStructures = "0.17, 0.18" ExprTools = "0.1" GPUArrays = "6.1.0" diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index b2682cae8a..d81faf83f1 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -923,15 +923,16 @@ for (fname, elty) in function gemm_strided_batched!(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}, + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}, beta::Number, - C::DenseCuArray{$elty, 3}) + C::AbstractArray{$elty, 3}) m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) - @assert size(A, 3) == size(B, 3) == size(C, 3) "Batch size mismatch" + @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C" + @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C" if m != size(C,1) || n != size(C,2) || k != size(B, transB == 'N' ? 1 : 2) throw(DimensionMismatch("")) @@ -940,10 +941,10 @@ for (fname, elty) in ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - strideA = stride(A, 3) - strideB = stride(B, 3) + strideA = size(A, 3) == 1 ? 0 : stride(A, 3) + strideB = size(B, 3) == 1 ? 0 : stride(B, 3) strideC = stride(C, 3) - batchCount = size(A, 3) + batchCount = size(C, 3) $fname(handle(), transA, transB, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount) C @@ -951,15 +952,15 @@ for (fname, elty) in function gemm_strided_batched(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) - C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), size(A, 3))) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) + C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C ) end function gemm_strided_batched(transA::Char, transB::Char, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) gemm_strided_batched(transA, transB, one($elty), A, B) end end diff --git a/src/nnlib.jl b/src/nnlib.jl index 9ffdba88ee..01d0b3a95b 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -23,16 +23,7 @@ end # Batched matrix multiplication +# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 -const batched_gemm_args = [ - (:(CuArray{T, 3}), 'N'), - (:(NNlib.BatchedTranspose{T, <:CuArray{T, 3}}), 'T'), - (:(NNlib.BatchedAdjoint{T, <:CuArray{T, 3}}), 'C') -] - -for (TA, transA) in batched_gemm_args, (TB, transB) in batched_gemm_args - @eval function NNlib.batched_mul!(C::CuArray{T, 3}, A::$TA, B::$TB) where {T<:CUBLAS.CublasFloat} - CUBLAS.gemm_strided_batched!($transA, $transB, one(T), NNlib._unbatch(A), NNlib._unbatch(B), zero(T), C) - C - end -end + NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = + CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) diff --git a/test/nnlib.jl b/test/nnlib.jl index 930b0a1b7d..df810b4361 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -16,6 +16,25 @@ using NNlib @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B))) end +@testset "NNlib storage_type etc." begin + using LinearAlgebra + using NNlib: is_strided, are_strided, storage_type + + M = cu(ones(10,10)) + + @test is_strided(M) + @test is_strided(view(M, 1:2:5,:)) + @test is_strided(PermutedDimsArray(M, (2,1))) + + @test !is_strided(reshape(view(M, 1:2:10,:), 10,:)) + @test !is_strided((M .+ im)') + @test !is_strided(Diagonal(cu(ones(3)))) + + @test storage_type(M) == CuArray{Float32,2,Nothing} + @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing} + +end + @testset "Broadcast Fix" begin if CUDA.has_cudnn() @test testf(x -> logσ.(x), rand(5)) From 0692e58d5c56daf4650ec1f8e78547e4a72b878b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 12 Nov 2020 10:37:56 +0100 Subject: [PATCH 02/17] work via _gemm_strided_batched --- lib/cublas/wrappers.jl | 20 +++++++++++++++----- src/nnlib.jl | 6 ++---- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index d81faf83f1..0a95050188 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -923,10 +923,20 @@ for (fname, elty) in function gemm_strided_batched!(transA::Char, transB::Char, alpha::Number, - A::AbstractArray{$elty, 3}, + A::DenseCuArray{$elty, 3}, + B::DenseCuArray{$elty, 3}, + beta::Number, + C::DenseCuArray{$elty, 3}) + _gemm_strided_batched(transA, transB, alpha, A, B, beta, C) + end + function _gemm_strided_batched!(transA::Char, + transB::Char, + alpha::Number, + A::AbstractArray{$elty, 3}, # allows PermutedDimsArray B::AbstractArray{$elty, 3}, beta::Number, C::AbstractArray{$elty, 3}) + m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) @@ -952,15 +962,15 @@ for (fname, elty) in function gemm_strided_batched(transA::Char, transB::Char, alpha::Number, - A::AbstractArray{$elty, 3}, - B::AbstractArray{$elty, 3}) + A::DenseCuArray{$elty, 3}, + B::DenseCuArray{$elty, 3}) C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C ) end function gemm_strided_batched(transA::Char, transB::Char, - A::AbstractArray{$elty, 3}, - B::AbstractArray{$elty, 3}) + A::DenseCuArray{$elty, 3}, + B::DenseCuArray{$elty, 3}) gemm_strided_batched(transA, transB, one($elty), A, B) end end diff --git a/src/nnlib.jl b/src/nnlib.jl index 01d0b3a95b..2da7149223 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -23,7 +23,5 @@ end # Batched matrix multiplication -# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 - - NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = - CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) +NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = + CUBLAS._gemm_strided_batched!(transA, transB, α, A, B, β, C) From 63e23ba7bd0204af9a2174133a40aa083cdce075 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 13 Nov 2020 19:26:05 +0100 Subject: [PATCH 03/17] Revert "work via _gemm_strided_batched" This reverts commit 0692e58d5c56daf4650ec1f8e78547e4a72b878b. --- lib/cublas/wrappers.jl | 20 +++++--------------- src/nnlib.jl | 6 ++++-- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index 0a95050188..d81faf83f1 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -923,20 +923,10 @@ for (fname, elty) in function gemm_strided_batched!(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}, - beta::Number, - C::DenseCuArray{$elty, 3}) - _gemm_strided_batched(transA, transB, alpha, A, B, beta, C) - end - function _gemm_strided_batched!(transA::Char, - transB::Char, - alpha::Number, - A::AbstractArray{$elty, 3}, # allows PermutedDimsArray + A::AbstractArray{$elty, 3}, B::AbstractArray{$elty, 3}, beta::Number, C::AbstractArray{$elty, 3}) - m = size(A, transA == 'N' ? 1 : 2) k = size(A, transA == 'N' ? 2 : 1) n = size(B, transB == 'N' ? 2 : 1) @@ -962,15 +952,15 @@ for (fname, elty) in function gemm_strided_batched(transA::Char, transB::Char, alpha::Number, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C ) end function gemm_strided_batched(transA::Char, transB::Char, - A::DenseCuArray{$elty, 3}, - B::DenseCuArray{$elty, 3}) + A::AbstractArray{$elty, 3}, + B::AbstractArray{$elty, 3}) gemm_strided_batched(transA, transB, one($elty), A, B) end end diff --git a/src/nnlib.jl b/src/nnlib.jl index 2da7149223..01d0b3a95b 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -23,5 +23,7 @@ end # Batched matrix multiplication -NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = - CUBLAS._gemm_strided_batched!(transA, transB, α, A, B, β, C) +# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 + + NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = + CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) From 3108086c5d2b9b8430ac9c22674273959f46ef6b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 21:54:55 +0100 Subject: [PATCH 04/17] tests including permuted cases --- test/nnlib.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/nnlib.jl b/test/nnlib.jl index df810b4361..cda3c736e6 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -1,7 +1,7 @@ using NNlib @testset "batched_mul" begin - using NNlib: batched_mul, batched_adjoint, batched_transpose + using NNlib: batched_mul, batched_mul!, batched_adjoint, batched_transpose A = randn(Float32, 3,3,2); B = randn(Float32, 3,3,2); @@ -14,6 +14,20 @@ using NNlib Ca = batched_mul(A, batched_adjoint(B)) @test CuArray(Ca) ≈ batched_mul(CuArray(A), batched_adjoint(CuArray(B))) + + # 5-arg batched_mul! + C .= pi + batched_mul!(C, A, B, 2f0, 3f0) + cuCpi = CuArray(similar(C)) .= pi + @test CuArray(C) ≈ batched_mul!(cuCpi, CuArray(A), CuArray(B), 2f0, 3f0) + + # PermutedDimsArray + @test CuArray(Ct) ≈ batched_mul(PermutedDimsArray(CuArray(A), (2,1,3)), CuArray(B)) + + D = permutedims(B, (1,3,2)) + Cp = batched_mul(batched_adjoint(A), B) + @test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2))) + end @testset "NNlib storage_type etc." begin From d7cc85bf51466a959550fac9e6d07aa7dd611208 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 21:56:54 +0100 Subject: [PATCH 05/17] add NNlib master --- Manifest.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 84230b42f5..584723fc73 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -105,8 +105,10 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] -deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "a8180fd1445e31c0b1add98dae8da694ac2c23fd" +deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "e4115ca3298c22cbab9c908e053e42ab52566852" +repo-rev = "master" +repo-url = "https://github.com/FluxML/NNlib.jl.git" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" version = "0.7.6" From e68a5d29a2ffc6523679620d27295bbf09a68a62 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:13:05 +0100 Subject: [PATCH 06/17] add a comment --- lib/cublas/wrappers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index d81faf83f1..68e8a2fd88 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -923,7 +923,7 @@ for (fname, elty) in function gemm_strided_batched!(transA::Char, transB::Char, alpha::Number, - A::AbstractArray{$elty, 3}, + A::AbstractArray{$elty, 3}, # allow PermutedDimsArray B::AbstractArray{$elty, 3}, beta::Number, C::AbstractArray{$elty, 3}) From 0cfb702b146da3990c6dae2410024470f1c0ceb8 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:13:18 +0100 Subject: [PATCH 07/17] CuPtr for PermutedDimsArrays --- src/nnlib.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/nnlib.jl b/src/nnlib.jl index 01d0b3a95b..64de724bf3 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -27,3 +27,14 @@ end NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) + + +# We need CuPtr for PermutedDimsArrays, +# recursive function will also handle e.g. NamedDimsArray +function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T} + if A === parent(A) + throw(MethodError(Base.unsafe_convert, Tuple{CUDAdrv.CuPtr{T}, typeof(A)})) + else + return Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A)) + end +end From 2bfe979098b96e414ace1544d2f80db09b7f4ffe Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:21:38 +0100 Subject: [PATCH 08/17] rm Compat.jl not needed, as CUDA only supports Julia 1.5 and up. --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 66d2c5beb0..132993f12a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" @@ -32,7 +31,6 @@ AbstractFFTs = "0.4, 0.5" Adapt = "2.2" BFloat16s = "0.1" CEnum = "0.2, 0.3, 0.4" -Compat = "3.9" DataStructures = "0.17, 0.18" ExprTools = "0.1" GPUArrays = "6.1.0" From 71238cee7e37634f92aa4ec73d9b6e4886e03da8 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:23:17 +0100 Subject: [PATCH 09/17] CUDAdrv by mistake --- src/nnlib.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/nnlib.jl b/src/nnlib.jl index 64de724bf3..e971a521b5 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -24,17 +24,16 @@ end # Batched matrix multiplication # Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 - - NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = +NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) # We need CuPtr for PermutedDimsArrays, # recursive function will also handle e.g. NamedDimsArray -function Base.unsafe_convert(::Type{CUDAdrv.CuPtr{T}}, A::AbstractArray) where {T} +function Base.unsafe_convert(::Type{CuPtr{T}}, A::AbstractArray) where {T} if A === parent(A) - throw(MethodError(Base.unsafe_convert, Tuple{CUDAdrv.CuPtr{T}, typeof(A)})) + throw(MethodError(Base.unsafe_convert, Tuple{CuPtr{T}, typeof(A)})) else - return Base.unsafe_convert(CUDAdrv.CuPtr{T}, parent(A)) + return Base.unsafe_convert(CuPtr{T}, parent(A)) end end From ece15c71fed0e417311506b92922a49e53676c75 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:39:53 +0100 Subject: [PATCH 10/17] narrower pointer --- src/nnlib.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/nnlib.jl b/src/nnlib.jl index e971a521b5..ad25edbffb 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -30,10 +30,13 @@ NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, # We need CuPtr for PermutedDimsArrays, # recursive function will also handle e.g. NamedDimsArray -function Base.unsafe_convert(::Type{CuPtr{T}}, A::AbstractArray) where {T} - if A === parent(A) - throw(MethodError(Base.unsafe_convert, Tuple{CuPtr{T}, typeof(A)})) - else - return Base.unsafe_convert(CuPtr{T}, parent(A)) - end -end +# function Base.unsafe_convert(::Type{CuPtr{T}}, A::AbstractArray) where {T} +# if A === parent(A) +# throw(MethodError(Base.unsafe_convert, Tuple{CuPtr{T}, typeof(A)})) +# else +# return Base.unsafe_convert(CuPtr{T}, parent(A)) +# end +# end +# Maybe much too broad, try: +Base.unsafe_convert.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} = + Base.unsafe_convert(CuPtr{T}, parent(A)) From c35137e666eedbabb77003f5c005539f198ef51b Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 22:42:27 +0100 Subject: [PATCH 11/17] oops --- src/nnlib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnlib.jl b/src/nnlib.jl index ad25edbffb..befee546d2 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -38,5 +38,5 @@ NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, # end # end # Maybe much too broad, try: -Base.unsafe_convert.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} = +Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} = Base.unsafe_convert(CuPtr{T}, parent(A)) From f7d46dea23136f9cdde625a00f6df32025834bd9 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 14 Nov 2020 23:03:04 +0100 Subject: [PATCH 12/17] rm Nothing --- test/nnlib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/nnlib.jl b/test/nnlib.jl index cda3c736e6..ea0eefb897 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -44,8 +44,8 @@ end @test !is_strided((M .+ im)') @test !is_strided(Diagonal(cu(ones(3)))) - @test storage_type(M) == CuArray{Float32,2,Nothing} - @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2,Nothing} + @test storage_type(M) == CuArray{Float32,2} + @test storage_type(reshape(view(M, 1:2:10,:), 10,:)) == CuArray{Float32,2} end From 2246bf901afb8105ca0b08792c97bbcabf06fe07 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 16 Nov 2020 09:21:11 +0100 Subject: [PATCH 13/17] move unsafe_convert(CuPtr... --- src/array.jl | 6 ++++++ src/nnlib.jl | 16 +--------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/array.jl b/src/array.jl index d63daf3c3f..f790545774 100644 --- a/src/array.jl +++ b/src/array.jl @@ -429,6 +429,12 @@ function Base.unsafe_convert(::Type{CuPtr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{ end +## PermutedDimsArray + +Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} = + Base.unsafe_convert(CuPtr{T}, parent(A)) + + ## reshape # optimize reshape to return a CuArray diff --git a/src/nnlib.jl b/src/nnlib.jl index befee546d2..1b37fe2b1d 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -23,20 +23,6 @@ end # Batched matrix multiplication -# Using storage_type from https://github.com/FluxML/NNlib.jl/pull/191 +# 1st argument is produced by NNlib.storage_type(A) NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) - - -# We need CuPtr for PermutedDimsArrays, -# recursive function will also handle e.g. NamedDimsArray -# function Base.unsafe_convert(::Type{CuPtr{T}}, A::AbstractArray) where {T} -# if A === parent(A) -# throw(MethodError(Base.unsafe_convert, Tuple{CuPtr{T}, typeof(A)})) -# else -# return Base.unsafe_convert(CuPtr{T}, parent(A)) -# end -# end -# Maybe much too broad, try: -Base.unsafe_convert(::Type{CuPtr{T}}, A::PermutedDimsArray) where {T} = - Base.unsafe_convert(CuPtr{T}, parent(A)) From ea67dda553a25a5f3f46eeee0d4dc149aba84406 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 16 Nov 2020 11:40:32 +0100 Subject: [PATCH 14/17] free NNlib --- Manifest.toml | 6 ++---- Project.toml | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 584723fc73..a6aa43b692 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -106,11 +106,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[NNlib]] deps = ["Compat", "Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "e4115ca3298c22cbab9c908e053e42ab52566852" -repo-rev = "master" -repo-url = "https://github.com/FluxML/NNlib.jl.git" +git-tree-sha1 = "1ae42464fea5258fd2ff49f1c4a40fc41cba3860" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.6" +version = "0.7.7" [[OrderedCollections]] git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" diff --git a/Project.toml b/Project.toml index 132993f12a..cf7ae821cd 100644 --- a/Project.toml +++ b/Project.toml @@ -37,7 +37,7 @@ GPUArrays = "6.1.0" GPUCompiler = "0.8.1" LLVM = "3" MacroTools = "0.5" -NNlib = "0.6.5, 0.7" +NNlib = "0.7.7" Reexport = "0.2" Requires = "0.5, 1.0" TimerOutputs = "0.5" From a22a68dc2a14073dad124651faa916224dd8eb65 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 16 Nov 2020 11:40:46 +0100 Subject: [PATCH 15/17] test some more NNlib methods --- test/nnlib.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/nnlib.jl b/test/nnlib.jl index ea0eefb897..3bdb6be58b 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -28,6 +28,14 @@ using NNlib Cp = batched_mul(batched_adjoint(A), B) @test CuArray(Cp) ≈ batched_mul(batched_adjoint(CuArray(A)), PermutedDimsArray(CuArray(D), (1,3,2))) + # Methods which reshape + M = randn(Float32, 3,3) + + Cm = batched_mul(A, M) + @test CuArray(Cm) ≈ batched_mul(CuArray(A), CuArray(M)) + + Cv = batched_vec(permutedims(A,(3,1,2)), M) + @test CuArray(Cv) ≈ batched_vec(PermutedDimsArray(CuArray(A),(3,1,2)), CuArray(M)) end @testset "NNlib storage_type etc." begin From e4c3dc2c65ea509b2060b0f47a7a28bf8d87bfdf Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Mon, 16 Nov 2020 11:50:41 +0100 Subject: [PATCH 16/17] mostly to re-trigger CI --- test/nnlib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nnlib.jl b/test/nnlib.jl index 3bdb6be58b..8f57e27f90 100644 --- a/test/nnlib.jl +++ b/test/nnlib.jl @@ -1,7 +1,7 @@ using NNlib @testset "batched_mul" begin - using NNlib: batched_mul, batched_mul!, batched_adjoint, batched_transpose + using NNlib: batched_mul, batched_mul!, batched_vec, batched_adjoint, batched_transpose A = randn(Float32, 3,3,2); B = randn(Float32, 3,3,2); From 0be7a0e0d9eb12a148ae51057634edbe1831ca61 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Tue, 17 Nov 2020 08:59:39 +0100 Subject: [PATCH 17/17] one more pointer --- src/nnlib.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/nnlib.jl b/src/nnlib.jl index 1b37fe2b1d..1cf65bdd26 100644 --- a/src/nnlib.jl +++ b/src/nnlib.jl @@ -26,3 +26,6 @@ end # 1st argument is produced by NNlib.storage_type(A) NNlib._batched_gemm!(::Type{<:CuArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C) = CUBLAS.gemm_strided_batched!(transA, transB, α, A, B, β, C) + +Base.unsafe_convert(::Type{CuPtr{T}}, A::NNlib.BatchedAdjOrTrans{T}) where {T} = + Base.unsafe_convert(CuPtr{T}, parent(A))