From fd5338000324e67581ec33138ba6847bc745db68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefanos=20Carlstr=C3=B6m?= Date: Fri, 19 Feb 2021 10:42:04 +0100 Subject: [PATCH] Specialized dot(u, D::Diagonal{<:Any,<:Union{Ones,Fill}}, v) --- Project.toml | 2 +- src/FillArrays.jl | 2 +- src/fillalgebra.jl | 17 +++++++++++++++++ test/runtests.jl | 29 +++++++++++++++++++++++++++-- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 55933f0c..5729f109 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.3" +version = "0.11.4" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index cac058d7..542ac898 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -9,7 +9,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, show, view, in import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, - norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec + dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 8bcea244..148db495 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -138,6 +138,23 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real end *(a::Transpose{T, <:AbstractMatrix{T}}, b::Zeros{T, 1}) where T<:Real = mult_zeros(a, b) +function dot(u::AbstractVector, E::Eye, v::AbstractVector) + length(u) == size(E,1) && length(v) == size(E,2) || + throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(E))×$(length(v))")) + dot(u, v) +end + +function dot(u::AbstractVector, D::Diagonal{<:Any,<:Fill}, v::AbstractVector) + length(u) == size(D,1) && length(v) == size(D,2) || + throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(D))×$(length(v))")) + D.diag.value*dot(u, v) +end + +function dot(u::AbstractVector{T}, D::Diagonal{U,<:Zeros}, v::AbstractVector{V}) where {T,U,V} + length(u) == size(D,1) && length(v) == size(D,2) || + throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(D))×$(length(v))")) + zero(promote_type(T,U,V)) +end +(a::Zeros) = a -(a::Zeros) = a diff --git a/test/runtests.jl b/test/runtests.jl index f048f655..7ba5c52b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -244,7 +244,7 @@ end @test Z[:,1] ≡ Z[1:5,1] ≡ Zeros(5) @test Z[1,:] ≡ Z[1,1:6] ≡ Zeros(6) @test Z[:,:] ≡ Z[1:5,1:6] ≡ Z[1:5,:] ≡ Z[:,1:6] ≡ Z - + A = Fill(2.0,5,6,7) Z = Zeros(5,6,7) @test A[:,1,1] ≡ A[1:5,1,1] ≡ Fill(2.0,5) @@ -1098,6 +1098,31 @@ end end end +@testset "dot products" begin + n = 15 + o = Ones(1:n) + z = Zeros(1:n) + D = Diagonal(o) + Z = Diagonal(z) + + Random.seed!(5) + u = rand(n) + v = rand(n) + + @test dot(u, D, v) == dot(u, v) + @test dot(u, 2D, v) == 2dot(u, v) + @test dot(u, Z, v) == 0 + + @test_throws DimensionMismatch dot(u[1:end-1], D, v) + @test_throws DimensionMismatch dot(u[1:end-1], D, v[1:end-1]) + + @test_throws DimensionMismatch dot(u, 2D, v[1:end-1]) + @test_throws DimensionMismatch dot(u, 2D, v[1:end-1]) + + @test_throws DimensionMismatch dot(u, Z, v[1:end-1]) + @test_throws DimensionMismatch dot(u, Z, v[1:end-1]) +end + if VERSION ≥ v"1.5" @testset "print" begin @test stringmime("text/plain", Zeros(3)) == "3-element Zeros{Float64}" @@ -1203,4 +1228,4 @@ end @test FillArrays.getindex_value(transpose(a)) == FillArrays.unique_value(transpose(a)) == 2.0 @test convert(Fill, transpose(a)) ≡ Fill(2.0,1,5) end -end \ No newline at end of file +end