diff --git a/Project.toml b/Project.toml index 021d3a1..b948a26 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ArrayLayouts" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" authors = ["Sheehan Olver "] -version = "1.3.0" +version = "1.3.1" [deps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" diff --git a/src/mul.jl b/src/mul.jl index 0afe23a..0b167fd 100644 --- a/src/mul.jl +++ b/src/mul.jl @@ -321,10 +321,17 @@ end # mul! for subarray of layout matrix const SubLayoutMatrix = Union{SubArray{<:Any,2,<:LayoutMatrix}, SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutMatrix}}} +*(A::Diagonal, B::SubLayoutMatrix) = mul(A, B) +*(A::SubLayoutMatrix, B::Diagonal) = mul(A, B) + LinearAlgebra.mul!(C::AbstractMatrix, A::SubLayoutMatrix, B::AbstractMatrix, α::Number, β::Number) = ArrayLayouts.mul!(C, A, B, α, β) LinearAlgebra.mul!(C::AbstractMatrix, A::AbstractMatrix, B::SubLayoutMatrix, α::Number, β::Number) = ArrayLayouts.mul!(C, A, B, α, β) +LinearAlgebra.mul!(C::AbstractMatrix, A::Diagonal, B::SubLayoutMatrix, α::Number, β::Number) = + ArrayLayouts.mul!(C, A, B, α, β) +LinearAlgebra.mul!(C::AbstractMatrix, A::SubLayoutMatrix, B::Diagonal, α::Number, β::Number) = + ArrayLayouts.mul!(C, A, B, α, β) LinearAlgebra.mul!(C::AbstractMatrix, A::SubLayoutMatrix, B::LayoutMatrix, α::Number, β::Number) = ArrayLayouts.mul!(C, A, B, α, β) LinearAlgebra.mul!(C::AbstractMatrix, A::LayoutMatrix, B::SubLayoutMatrix, α::Number, β::Number) = diff --git a/test/test_layoutarray.jl b/test/test_layoutarray.jl index 196957a..bbac943 100644 --- a/test/test_layoutarray.jl +++ b/test/test_layoutarray.jl @@ -508,6 +508,8 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() V = view(A, 1:3, 1:3) B = randn(3,3) x = randn(3) + D = Diagonal(1:3) + @test mul!(similar(B), V, B) ≈ A * B @test mul!(similar(B), B, V) ≈ B * A @test mul!(similar(B), V, V) ≈ A^2 @@ -524,6 +526,10 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor() @test mul!(MyMatrix(copy(B)), A, V, 2.0, 3.0) ≈ 2A * A + 3B @test mul!(copy(x), V, x, 2.0, 3.0) ≈ 2A * x + 3x + @test D * V == D * A == D * A.A + @test V * D == A * D == A.A * D + @test mul!(copy(B), D, A, 2.0, 3.0) ≈ mul!(copy(B), D, V, 2.0, 3.0) ≈ 2D * A + 3B + @test mul!(copy(B), A, D, 2.0, 3.0) ≈ mul!(copy(B), V, D, 2.0, 3.0) ≈ 2A * D + 3B end end