From 71146f07b132741bb0e4bf5143edee67754f69a6 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 12 Sep 2023 14:11:53 +0100 Subject: [PATCH 1/2] Fix Vec * Mat --- Project.toml | 2 +- src/muladd.jl | 28 ++++++++++++++++++++++------ test/test_muladd.jl | 4 ++++ 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 1bc458d..ec05d16 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "1.4" +version = "1.4.1" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/muladd.jl b/src/muladd.jl index 2873d21..6dba71c 100644 --- a/src/muladd.jl +++ b/src/muladd.jl @@ -189,6 +189,23 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr C end +function default_blasmul!(α, A::AbstractVector, B::AbstractMatrix, β, C::AbstractMatrix) + mA, = size(A) + mB, nB = size(B) + 1 == mB || throw(DimensionMismatch("Dimensions must match")) + size(C) == (mA, nB) || throw(DimensionMismatch("Dimensions must match")) + + lmul!(β, C) + + (iszero(mA) || iszero(nB)) && return C + + for k in colsupport(A), j in rowsupport(B) + _default_blasmul_loop!(α, A, B, β, C, k, j) + end + C +end + + function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector) mA, nA = size(A) mB = length(B) @@ -266,6 +283,11 @@ function materialize!(M::MatMulVecAdd) default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C) end +function materialize!(M::VecMulMatAdd) + α, A, B, β, C = M.α, M.A, M.B, M.β, M.C + default_blasmul!(α, unalias(C,A), unalias(C,B), iszero(β) ? false : β, C) +end + @inline _gemv!(tA, α, A, x, β, y) = BLAS.gemv!(tA, α, unalias(y,A), unalias(y,x), β, y) @inline _gemm!(tA, tB, α, A, B, β, C) = BLAS.gemm!(tA, tB, α, unalias(C,A), unalias(C,B), β, C) @@ -424,12 +446,6 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh trans(similar(trans(M.A), T, y)) end -function similar(M::MulAdd{<:Any,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T - @assert length(x) == 1 - trans = transtype(M.B) - trans(similar(trans(M.B), T, y)) -end - const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}} copy(M::MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts}) = M.C copy(M::MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts}) = M.C diff --git a/test/test_muladd.jl b/test/test_muladd.jl index 30cd243..e98ce67 100644 --- a/test/test_muladd.jl +++ b/test/test_muladd.jl @@ -731,4 +731,8 @@ Random.seed!(0) Y = randn(rng, 8, 2) @test mul(Y',X) ≈ Y'X end + + @testset "Vec * Adj" begin + @test ArrayLayouts.mul(1:5, (1:4)') == (1:5) * (1:4)' + end end From b4dc21638d22a2e57393663aa901f57f22a69a43 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Tue, 12 Sep 2023 14:28:30 +0100 Subject: [PATCH 2/2] Update muladd.jl --- src/muladd.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/muladd.jl b/src/muladd.jl index 6dba71c..83d9b70 100644 --- a/src/muladd.jl +++ b/src/muladd.jl @@ -446,7 +446,13 @@ function similar(M::MulAdd{<:DualLayout,<:Any,ZerosLayout}, ::Type{T}, (x,y)) wh trans(similar(trans(M.A), T, y)) end +function similar(M::MulAdd{ScalarLayout,<:DualLayout,ZerosLayout}, ::Type{T}, (x,y)) where T + trans = transtype(M.B) + trans(similar(trans(M.B), T, y)) +end + const ZerosLayouts = Union{ZerosLayout,DualLayout{ZerosLayout}} copy(M::MulAdd{<:ZerosLayouts, <:ZerosLayouts, <:ZerosLayouts}) = M.C copy(M::MulAdd{<:ZerosLayouts, <:Any, <:ZerosLayouts}) = M.C copy(M::MulAdd{<:Any, <:ZerosLayouts, <:ZerosLayouts}) = M.C +