Skip to content

Commit a8bffd3

Browse files
authored
Special case for Zeros(5)'*b and dot(::AbstractFill,::AbstractFill) (#160)
* Special case for Zeros(5)'*b and dot(::AbstractFill,::AbstractFill) * more overloads * add tests
1 parent 1f25ae1 commit a8bffd3

File tree

4 files changed

+159
-74
lines changed

4 files changed

+159
-74
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.12.5"
3+
version = "0.12.6"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
99
show, view, in, mapreduce
1010

1111
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
12-
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec,
13-
issymmetric, ishermitian
12+
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec, TransposeAbsVec,
13+
issymmetric, ishermitian, AdjOrTransAbsVec
1414

1515
import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1616

src/fillalgebra.jl

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ end
8383
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
8484
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
8585

86+
*(a::Zeros{<:Any,1}, b::AdjOrTransAbsVec) = mult_zeros(a, b)
87+
8688
*(a::Zeros{<:Any,1}, b::Diagonal) = mult_zeros(a, b)
8789
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
8890
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
@@ -117,16 +119,29 @@ function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
117119
fill!(fB, b.value)
118120
return a*fB
119121
end
120-
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
122+
function _adjvec_mul_zeros(a, b)
121123
la, lb = length(a), length(b)
122124
if la lb
123125
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
124126
end
125-
return zero(Base.promote_op(*, T, S))
127+
return zero(Base.promote_op(*, eltype(a), eltype(b)))
126128
end
127129

130+
*(a::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, b::AbstractMatrix) = (b' * a')'
131+
*(a::AdjointAbsVec{<:Any,<:Zeros{<:Any,1}}, b::Zeros{<:Any,2}) = (b' * a')'
132+
*(a::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, b::AbstractMatrix) = transpose(transpose(b) * transpose(a))
133+
*(a::TransposeAbsVec{<:Any,<:Zeros{<:Any,1}}, b::Zeros{<:Any,2}) = transpose(transpose(b) * transpose(a))
134+
135+
*(a::AbstractVector, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b))
136+
*(a::AbstractMatrix, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b))
137+
*(a::Zeros{<:Any,1}, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b))
138+
*(a::Zeros{<:Any,2}, b::AdjOrTransAbsVec{<:Any,<:Zeros{<:Any,1}}) = a * permutedims(parent(b))
139+
128140
*(a::AdjointAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b)
129141
*(a::AdjointAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b)
142+
*(a::TransposeAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b)
143+
*(a::TransposeAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b)
144+
130145
*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::Zeros{<:Any, 1}) = mult_zeros(a, b)
131146

132147
function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real
@@ -138,6 +153,39 @@ function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real
138153
end
139154
*(a::Transpose{T, <:AbstractMatrix{T}}, b::Zeros{T, 1}) where T<:Real = mult_zeros(a, b)
140155

156+
# treat zero separately to support ∞-vectors
157+
function _zero_dot(a, b)
158+
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
159+
zero(promote_type(eltype(a),eltype(b)))
160+
end
161+
162+
_fill_dot(a::Zeros, b::Zeros) = _zero_dot(a, b)
163+
_fill_dot(a::Zeros, b) = _zero_dot(a, b)
164+
_fill_dot(a, b::Zeros) = _zero_dot(a, b)
165+
_fill_dot(a::Zeros, b::AbstractFill) = _zero_dot(a, b)
166+
_fill_dot(a::AbstractFill, b::Zeros) = _zero_dot(a, b)
167+
168+
function _fill_dot(a::AbstractFill, b::AbstractFill)
169+
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
170+
getindex_value(a)getindex_value(b)*length(b)
171+
end
172+
173+
# support types with fast sum
174+
function _fill_dot(a::AbstractFill, b)
175+
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
176+
getindex_value(a)sum(b)
177+
end
178+
179+
function _fill_dot(a, b::AbstractFill)
180+
axes(a) == axes(b) || throw(DimensionMismatch("dot product arguments have lengths $(length(a)) and $(length(b))"))
181+
sum(a)getindex_value(b)
182+
end
183+
184+
185+
dot(a::AbstractFill{<:Any,1}, b::AbstractFill{<:Any,1}) = _fill_dot(a, b)
186+
dot(a::AbstractFill{<:Any,1}, b::AbstractVector) = _fill_dot(a, b)
187+
dot(a::AbstractVector, b::AbstractFill{<:Any,1}) = _fill_dot(a, b)
188+
141189
function dot(u::AbstractVector, E::Eye, v::AbstractVector)
142190
length(u) == size(E,1) && length(v) == size(E,2) ||
143191
throw(DimensionMismatch("dot product arguments have dimensions $(length(u))×$(size(E))×$(length(v))"))

test/runtests.jl

Lines changed: 106 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -451,29 +451,49 @@ end
451451
@test [1,2,3]*Zeros(1,3) Zeros(3,3)
452452
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either
453453

454-
# Check multiplication by Adjoint vectors works as expected.
455-
@test randn(4, 3)' * Zeros(4) === Zeros(3)
456-
@test randn(4)' * Zeros(4) === zero(Float64)
457-
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
458-
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
459-
@test_throws DimensionMismatch randn(4)' * Zeros(3)
460-
461-
# Check multiplication by Transpose-d vectors works as expected.
462-
@test transpose(randn(4, 3)) * Zeros(4) === Zeros(3)
463-
@test transpose(randn(4)) * Zeros(4) === zero(Float64)
464-
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
465-
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
466-
467-
@test +(Zeros{Float64}(3, 5)) === Zeros{Float64}(3, 5)
468-
@test -(Zeros{Float32}(5, 2)) === Zeros{Float32}(5, 2)
469-
470-
# `Zeros` are closed under addition and subtraction (both unary and binary).
454+
@testset "Check multiplication by Adjoint vectors works as expected." begin
455+
@test randn(4, 3)' * Zeros(4) === Zeros(3)
456+
@test randn(4)' * Zeros(4) === zero(Float64)
457+
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
458+
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
459+
@test_throws DimensionMismatch randn(4)' * Zeros(3)
460+
@test Zeros(5)' * randn(5,3) Zeros(5)'*Zeros(5,3) Zeros(5)'*Ones(5,3) Zeros(3)'
461+
@test Zeros(5)' * randn(5) Zeros(5)' * Zeros(5) Zeros(5)' * Ones(5) 0.0
462+
@test Zeros(5) * Zeros(6)' Zeros(5,1) * Zeros(6)' Zeros(5,6)
463+
@test randn(5) * Zeros(6)' randn(5,1) * Zeros(6)' Zeros(5,6)
464+
@test Zeros(5) * randn(6)' Zeros(5,6)
465+
466+
@test ([[1,2]])' * Zeros{SVector{2,Int}}(1) 0
467+
@test_broken ([[1,2,3]])' * Zeros{SVector{2,Int}}(1)
468+
end
469+
470+
@testset "Check multiplication by Transpose-d vectors works as expected." begin
471+
@test transpose(randn(4, 3)) * Zeros(4) === Zeros(3)
472+
@test transpose(randn(4)) * Zeros(4) === zero(Float64)
473+
@test transpose([1, 2, 3]) * Zeros{Int}(3) === zero(Int)
474+
@test_throws DimensionMismatch transpose(randn(4)) * Zeros(3)
475+
@test transpose(Zeros(5)) * randn(5,3) transpose(Zeros(5))*Zeros(5,3) transpose(Zeros(5))*Ones(5,3) transpose(Zeros(3))
476+
@test transpose(Zeros(5)) * randn(5) transpose(Zeros(5)) * Zeros(5) transpose(Zeros(5)) * Ones(5) 0.0
477+
@test randn(5) * transpose(Zeros(6)) randn(5,1) * transpose(Zeros(6)) Zeros(5,6)
478+
@test Zeros(5) * transpose(randn(6)) Zeros(5,6)
479+
@test transpose(randn(5)) * Zeros(5) 0.0
480+
481+
@test transpose([[1,2]]) * Zeros{SVector{2,Int}}(1) 0
482+
@test_broken transpose([[1,2,3]]) * Zeros{SVector{2,Int}}(1)
483+
end
484+
471485
z1, z2 = Zeros{Float64}(4), Zeros{Int}(4)
472-
@test +(z1) === z1
473-
@test -(z1) === z1
474486

475-
test_addition_and_subtraction([z1, z2], [z1, z2], Zeros)
476-
test_addition_and_subtraction_dim_mismatch(z1, Zeros{Float64}(4, 2))
487+
@testset "`Zeros` are closed under addition and subtraction (both unary and binary)." begin
488+
@test +(Zeros{Float64}(3, 5)) === Zeros{Float64}(3, 5)
489+
@test -(Zeros{Float32}(5, 2)) === Zeros{Float32}(5, 2)
490+
491+
@test +(z1) === z1
492+
@test -(z1) === z1
493+
494+
test_addition_and_subtraction([z1, z2], [z1, z2], Zeros)
495+
test_addition_and_subtraction_dim_mismatch(z1, Zeros{Float64}(4, 2))
496+
end
477497

478498
# `Zeros` +/- `Fill`s should yield `Fills`.
479499
fill1, fill2 = Fill(5.0, 4), Fill(5, 4)
@@ -502,36 +522,41 @@ end
502522
@test op(Zeros{Float64}(4, 5), Zeros{Int}(4, 5)) === Zeros{Float64}(4, 5)
503523
end
504524

505-
# Zeros +/- dense where + / - have different results.
506-
@test +(Zeros(3, 5), X) == X && +(X, Zeros(3, 5)) == X
507-
@test !(Zeros(3, 5) + X === X) && !(X + Zeros(3, 5) === X)
508-
@test -(Zeros(3, 5), X) == -X
509-
510-
# Addition with different eltypes.
511-
@test +(Zeros{Float32}(3, 5), X) isa Matrix{Float64}
512-
@test !(+(Zeros{Float32}(3, 5), X) === X)
513-
@test +(Zeros{Float32}(3, 5), X) == X
514-
@test !(+(Zeros{ComplexF64}(3, 5), X) === X)
515-
@test +(Zeros{ComplexF64}(3, 5), X) == X
516-
517-
# Subtraction with different eltypes.
518-
@test -(Zeros{Float32}(3, 5), X) isa Matrix{Float64}
519-
@test -(Zeros{Float32}(3, 5), X) == -X
520-
@test -(Zeros{ComplexF64}(3, 5), X) == -X
521-
522-
# Tests for ranges.
523-
X = randn(5)
524-
@test !(Zeros(5) + X === X)
525-
@test Zeros{Int}(5) + (1:5) === (1:5) && (1:5) + Zeros{Int}(5) === (1:5)
526-
@test Zeros(5) + (1:5) === (1.0:1.0:5.0) && (1:5) + Zeros(5) === (1.0:1.0:5.0)
527-
@test (1:5) - Zeros{Int}(5) === (1:5)
528-
@test Zeros{Int}(5) - (1:5) === -1:-1:-5
529-
@test Zeros(5) - (1:5) === -1.0:-1.0:-5.0
530-
531-
# test Base.zero
532-
@test zero(Zeros(10)) == Zeros(10)
533-
@test zero(Ones(10,10)) == Zeros(10,10)
534-
@test zero(Fill(0.5, 10, 10)) == Zeros(10,10)
525+
@testset "Zeros +/- dense where + / - have different results." begin
526+
@test +(Zeros(3, 5), X) == X && +(X, Zeros(3, 5)) == X
527+
@test !(Zeros(3, 5) + X === X) && !(X + Zeros(3, 5) === X)
528+
@test -(Zeros(3, 5), X) == -X
529+
end
530+
531+
@testset "Addition with different eltypes." begin
532+
@test +(Zeros{Float32}(3, 5), X) isa Matrix{Float64}
533+
@test !(+(Zeros{Float32}(3, 5), X) === X)
534+
@test +(Zeros{Float32}(3, 5), X) == X
535+
@test !(+(Zeros{ComplexF64}(3, 5), X) === X)
536+
@test +(Zeros{ComplexF64}(3, 5), X) == X
537+
end
538+
539+
@testset "Subtraction with different eltypes." begin
540+
@test -(Zeros{Float32}(3, 5), X) isa Matrix{Float64}
541+
@test -(Zeros{Float32}(3, 5), X) == -X
542+
@test -(Zeros{ComplexF64}(3, 5), X) == -X
543+
end
544+
545+
@testset "Tests for ranges." begin
546+
X = randn(5)
547+
@test !(Zeros(5) + X === X)
548+
@test Zeros{Int}(5) + (1:5) === (1:5) && (1:5) + Zeros{Int}(5) === (1:5)
549+
@test Zeros(5) + (1:5) === (1.0:1.0:5.0) && (1:5) + Zeros(5) === (1.0:1.0:5.0)
550+
@test (1:5) - Zeros{Int}(5) === (1:5)
551+
@test Zeros{Int}(5) - (1:5) === -1:-1:-5
552+
@test Zeros(5) - (1:5) === -1.0:-1.0:-5.0
553+
end
554+
555+
@testset "test Base.zero" begin
556+
@test zero(Zeros(10)) == Zeros(10)
557+
@test zero(Ones(10,10)) == Zeros(10,10)
558+
@test zero(Fill(0.5, 10, 10)) == Zeros(10,10)
559+
end
535560
end
536561

537562
@testset "maximum/minimum/svd/sort" begin
@@ -1135,25 +1160,7 @@ end
11351160
@test E*(1:5) 1.0:5.0
11361161
@test (1:5)'E == (1.0:5)'
11371162
@test E*E E
1138-
end
1139-
1140-
@testset "count" begin
1141-
@test count(Ones{Bool}(10)) == count(Fill(true,10)) == 10
1142-
@test count(Zeros{Bool}(10)) == count(Fill(false,10)) == 0
1143-
@test count(x -> 1  x < 2, Fill(1.3,10)) == 10
1144-
@test count(x -> 1  x < 2, Fill(2.0,10)) == 0
1145-
end
1146-
1147-
@testset "norm" begin
1148-
for a in (Zeros{Int}(5), Zeros(5,3), Zeros(2,3,3),
1149-
Ones{Int}(5), Ones(5,3), Ones(2,3,3),
1150-
Fill(2.3,5), Fill([2.3,4.2],5), Fill(4)),
1151-
p in (-Inf, 0, 0.1, 1, 2, 3, Inf)
1152-
@test norm(a,p) norm(Array(a),p)
1153-
end
1154-
end
11551163

1156-
@testset "multiplication" begin
11571164
for T in (Float64, ComplexF64)
11581165
fv = T == Float64 ? Float64(1.6) : ComplexF64(1.6, 1.3)
11591166
n = 10
@@ -1172,6 +1179,22 @@ end
11721179
end
11731180
end
11741181

1182+
@testset "count" begin
1183+
@test count(Ones{Bool}(10)) == count(Fill(true,10)) == 10
1184+
@test count(Zeros{Bool}(10)) == count(Fill(false,10)) == 0
1185+
@test count(x -> 1  x < 2, Fill(1.3,10)) == 10
1186+
@test count(x -> 1  x < 2, Fill(2.0,10)) == 0
1187+
end
1188+
1189+
@testset "norm" begin
1190+
for a in (Zeros{Int}(5), Zeros(5,3), Zeros(2,3,3),
1191+
Ones{Int}(5), Ones(5,3), Ones(2,3,3),
1192+
Fill(2.3,5), Fill([2.3,4.2],5), Fill(4)),
1193+
p in (-Inf, 0, 0.1, 1, 2, 3, Inf)
1194+
@test norm(a,p) norm(Array(a),p)
1195+
end
1196+
end
1197+
11751198
@testset "dot products" begin
11761199
n = 15
11771200
o = Ones(1:n)
@@ -1187,6 +1210,17 @@ end
11871210
@test dot(u, 2D, v) == 2dot(u, v)
11881211
@test dot(u, Z, v) == 0
11891212

1213+
@test dot(Zeros(5), Zeros{ComplexF16}(5)) zero(ComplexF64)
1214+
@test dot(Zeros(5), Ones{ComplexF16}(5)) zero(ComplexF64)
1215+
@test dot(Ones{ComplexF16}(5), Zeros(5)) zero(ComplexF64)
1216+
@test dot(randn(5), Zeros{ComplexF16}(5)) dot(Zeros{ComplexF16}(5), randn(5)) zero(ComplexF64)
1217+
1218+
@test dot(Fill(1,5), Fill(2.0,5)) 10.0
1219+
1220+
let N = 2^big(1000) # fast dot for fast sum
1221+
@test dot(Fill(2,N),1:N) == dot(Fill(2,N),1:N) == dot(1:N,Fill(2,N)) == 2*sum(1:N)
1222+
end
1223+
11901224
@test_throws DimensionMismatch dot(u[1:end-1], D, v)
11911225
@test_throws DimensionMismatch dot(u[1:end-1], D, v[1:end-1])
11921226

@@ -1195,6 +1229,9 @@ end
11951229

11961230
@test_throws DimensionMismatch dot(u, Z, v[1:end-1])
11971231
@test_throws DimensionMismatch dot(u, Z, v[1:end-1])
1232+
1233+
@test_throws DimensionMismatch dot(Zeros(5), Zeros(6))
1234+
@test_throws DimensionMismatch dot(Zeros(5), randn(6))
11981235
end
11991236

12001237
@testset "print" begin

0 commit comments

Comments
 (0)