Skip to content

Commit 8734371

Browse files
authored
Fix OneElement multiplication with array elements (#335)
* Fix OneElement multiplication with array elements * Fix matmul for array elements in OneElMat * StridedMat
1 parent b0ee65f commit 8734371

File tree

2 files changed

+115
-39
lines changed

2 files changed

+115
-39
lines changed

src/oneelement.jl

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,12 @@ function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha
145145
end
146146

147147
@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta)
148-
αx = alpha * x.val
148+
= Ref(x.val * alpha)
149149
ind1 = x.ind[1]
150150
if iszero(beta)
151-
y .= αx .* view(A, :, ind1)
151+
y .= view(A, :, ind1) .*
152152
else
153-
y .= αx .* view(A, :, ind1) .+ beta .* y
153+
y .= view(A, :, ind1) .*.+ y .* beta
154154
end
155155
return y
156156
end
@@ -171,13 +171,14 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha,
171171
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
172172
return C
173173
end
174+
nzrow, nzcol = B.ind
174175
if iszero(beta)
175-
C .= zero(eltype(C))
176+
C .= Ref(zero(eltype(C)))
176177
else
177-
view(C, :, 1:B.ind[2]-1) .*= beta
178-
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
178+
view(C, :, 1:nzcol-1) .*= beta
179+
view(C, :, nzcol+1:size(C,2)) .*= beta
179180
end
180-
y = view(C, :, B.ind[2])
181+
y = view(C, :, nzcol)
181182
__mul!(y, A, B, alpha, beta)
182183
C
183184
end
@@ -187,17 +188,14 @@ function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta)
187188
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
188189
return C
189190
end
190-
if iszero(beta)
191-
C .= zero(eltype(C))
192-
else
193-
view(C, :, 1:B.ind[2]-1) .*= beta
194-
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
195-
end
196-
ABα = A * B * alpha
197191
nzrow, nzcol = B.ind
192+
ABα = A * B * alpha
198193
if iszero(beta)
199-
C[B.ind...] = ABα[B.ind...]
194+
C .= Ref(zero(eltype(C)))
195+
C[nzrow, nzcol] = ABα[nzrow, nzcol]
200196
else
197+
view(C, :, 1:nzcol-1) .*= beta
198+
view(C, :, nzcol+1:size(C,2)) .*= beta
201199
y = view(C, :, nzcol)
202200
y .= view(ABα, :, nzcol) .+ y .* beta
203201
end
@@ -210,19 +208,16 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha,
210208
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
211209
return C
212210
end
213-
if iszero(beta)
214-
C .= zero(eltype(C))
215-
else
216-
view(C, 1:A.ind[1]-1, :) .*= beta
217-
view(C, A.ind[1]+1:size(C,1), :) .*= beta
218-
end
219-
y = view(C, A.ind[1], :)
220-
ind2 = A.ind[2]
211+
nzrow, nzcol = A.ind
212+
y = view(C, nzrow, :)
221213
Aval = A.val
222214
if iszero(beta)
223-
y .= Aval .* view(B, ind2, :) .* alpha
215+
C .= Ref(zero(eltype(C)))
216+
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha
224217
else
225-
y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta
218+
view(C, 1:nzrow-1, :) .*= beta
219+
view(C, nzrow+1:size(C,1), :) .*= beta
220+
y .= Ref(Aval) .* view(B, nzcol, :) .* alpha .+ y .* beta
226221
end
227222
C
228223
end
@@ -232,17 +227,14 @@ function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta)
232227
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
233228
return C
234229
end
235-
if iszero(beta)
236-
C .= zero(eltype(C))
237-
else
238-
view(C, 1:A.ind[1]-1, :) .*= beta
239-
view(C, A.ind[1]+1:size(C,1), :) .*= beta
240-
end
241-
ABα = A * B * alpha
242230
nzrow, nzcol = A.ind
231+
ABα = A * B * alpha
243232
if iszero(beta)
244-
C[A.ind...] = ABα[A.ind...]
233+
C .= Ref(zero(eltype(C)))
234+
C[nzrow, nzcol] = ABα[nzrow, nzcol]
245235
else
236+
view(C, 1:nzrow-1, :) .*= beta
237+
view(C, nzrow+1:size(C,1), :) .*= beta
246238
y = view(C, nzrow, :)
247239
y .= view(ABα, nzrow, :) .+ y .* beta
248240
end
@@ -256,16 +248,13 @@ function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha,
256248
return C
257249
end
258250
nzrow, nzcol = A.ind
259-
if iszero(beta)
260-
C .= zero(eltype(C))
261-
else
262-
view(C, 1:nzrow-1) .*= beta
263-
view(C, nzrow+1:size(C,1)) .*= beta
264-
end
265251
Aval = A.val
266252
if iszero(beta)
253+
C .= Ref(zero(eltype(C)))
267254
C[nzrow] = Aval * B[nzcol] * alpha
268255
else
256+
view(C, 1:nzrow-1) .*= beta
257+
view(C, nzrow+1:size(C,1)) .*= beta
269258
C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
270259
end
271260
C

test/runtests.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,6 +2318,93 @@ end
23182318
@test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2
23192319
end
23202320
end
2321+
@testset "array elements" begin
2322+
A = [SMatrix{2,3}(1:6)*(i+j) for i in 1:3, j in 1:2]
2323+
@testset "StridedMatrix * OneElementMatrix" begin
2324+
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),2), (size(A,2),4))
2325+
C = [SMatrix{2,2}(1:4) for i in axes(A,1), j in axes(B,2)]
2326+
@test mul!(copy(C), A, B) == A * B
2327+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2328+
end
2329+
@testset "StridedMatrix * OneElementVector" begin
2330+
B = OneElement(SMatrix{3,2}(1:6), (size(A,2),), (size(A,2),))
2331+
C = [SMatrix{2,2}(1:4) for i in axes(A,1)]
2332+
@test mul!(copy(C), A, B) == A * B
2333+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2334+
end
2335+
2336+
A = OneElement(SMatrix{3,2}(1:6), (3,2), (5,4))
2337+
@testset "OneElementMatrix * StridedMatrix" begin
2338+
B = [SMatrix{2,3}(1:6)*(i+j) for i in axes(A,2), j in 1:2]
2339+
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
2340+
@test mul!(copy(C), A, B) == A * B
2341+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2342+
end
2343+
@testset "OneElementMatrix * StridedVector" begin
2344+
B = [SMatrix{2,3}(1:6)*i for i in axes(A,2)]
2345+
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
2346+
@test mul!(copy(C), A, B) == A * B
2347+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2348+
end
2349+
@testset "OneElementMatrix * OneElementMatrix" begin
2350+
B = OneElement(SMatrix{2,3}(1:6), (2,4), (size(A,2), 3))
2351+
C = [SMatrix{3,3}(1:9) for i in axes(A,1), j in axes(B,2)]
2352+
@test mul!(copy(C), A, B) == A * B
2353+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2354+
end
2355+
@testset "OneElementMatrix * OneElementVector" begin
2356+
B = OneElement(SMatrix{2,3}(1:6), 2, size(A,2))
2357+
C = [SMatrix{3,3}(1:9) for i in axes(A,1)]
2358+
@test mul!(copy(C), A, B) == A * B
2359+
@test mul!(copy(C), A, B, 2, 2) == 2 * A * B + 2 * C
2360+
end
2361+
end
2362+
@testset "non-commutative" begin
2363+
A = OneElement(quat(rand(4)...), (2,3), (3,4))
2364+
for (B,C) in (
2365+
# OneElementMatrix * OneElementVector
2366+
(OneElement(quat(rand(4)...), 3, size(A,2)),
2367+
[quat(rand(4)...) for i in axes(A,1)]),
2368+
2369+
# OneElementMatrix * OneElementMatrix
2370+
(OneElement(quat(rand(4)...), (3,2), (size(A,2), 4)),
2371+
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
2372+
)
2373+
@test mul!(copy(C), A, B) A * B
2374+
α, β = quat(0,0,1,0), quat(1,0,1,0)
2375+
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
2376+
end
2377+
2378+
A = [quat(rand(4)...)*(i+j) for i in 1:2, j in 1:3]
2379+
for (B,C) in (
2380+
# StridedMatrix * OneElementVector
2381+
(OneElement(quat(rand(4)...), 1, size(A,2)),
2382+
[quat(rand(4)...) for i in axes(A,1)]),
2383+
2384+
# StridedMatrix * OneElementMatrix
2385+
(OneElement(quat(rand(4)...), (2,2), (size(A,2), 4)),
2386+
[quat(rand(4)...) for i in axes(A,1), j in 1:4]),
2387+
)
2388+
@test mul!(copy(C), A, B) A * B
2389+
α, β = quat(0,0,1,0), quat(1,0,1,0)
2390+
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
2391+
end
2392+
2393+
A = OneElement(quat(rand(4)...), (2,2), (3, 4))
2394+
for (B,C) in (
2395+
# OneElementMatrix * StridedMatrix
2396+
([quat(rand(4)...) for i in axes(A,2), j in 1:3],
2397+
[quat(rand(4)...) for i in axes(A,1), j in 1:3]),
2398+
2399+
# OneElementMatrix * StridedVector
2400+
([quat(rand(4)...) for i in axes(A,2)],
2401+
[quat(rand(4)...) for i in axes(A,1)]),
2402+
)
2403+
@test mul!(copy(C), A, B) A * B
2404+
α, β = quat(0,0,1,0), quat(1,0,1,0)
2405+
@test mul!(copy(C), A, B, α, β) mul!(copy(C), A, Array(B), α, β) A * B * α + C * β
2406+
end
2407+
end
23212408
end
23222409

23232410
@testset "multiplication/division by a number" begin

0 commit comments

Comments
 (0)