Skip to content

Commit 8e949d6

Browse files
authored
specialize copyto! and multiplication by numbers for Q from qr (#39533)
* specialize copyto! and multiplication by numbers for Q from qr. This fixes two performance bugs reported in https://github.com/JuliaLang/julia/issues/38972 and https://github.com/JuliaLang/julia/issues/38972 (multiplication of `Q` from `qr` by a `Diagonal` or `UniformScaling`). In particular, it improves the performance of generating random orthogonal matrices as described in https://discourse.julialang.org/t/random-orthogonal-matrices/9779/7.
* fix typo in new qr tests
* resolve method ambiguity of copyto!
1 parent 3230aef commit 8e949d6

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

stdlib/LinearAlgebra/src/qr.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,31 @@ function getindex(Q::AbstractQ, i::Integer, j::Integer)
533533
return dot(x, lmul!(Q, y))
534534
end
535535

536+
# specialization avoiding the fallback using slow `getindex`
537+
function copyto!(dest::AbstractMatrix, src::AbstractQ)
    # Seed `dest` with the identity and then apply `src` to it from the left
    # in place, so `dest` ends up holding `src * I`. The value returned is
    # whatever `lmul!` returns.
    copyto!(dest, I)
    return lmul!(src, dest)
end
541+
# needed to resolve method ambiguities
542+
function copyto!(dest::PermutedDimsArray{T,2,perm}, src::AbstractQ) where {T,perm}
    target = parent(dest)
    # Identity permutation: the parent has the same layout as `dest`.
    if perm == (1, 2)
        copyto!(target, src)
        return dest
    end
    @assert perm == (2, 1) # there are no other permutations of two indices
    # `dest` is the transposed view of `target`, so `target` must receive
    # the transpose of `src`.
    if T <: Real
        # For real element types the adjoint is the transpose.
        copyto!(target, I)
        lmul!(src', target)
    else
        # LAPACK does not offer inplace lmul!(transpose(Q), B) for complex Q,
        # so build `I * src` in a scratch matrix and transpose it into place.
        scratch = similar(target)
        copyto!(scratch, I)
        rmul!(scratch, src)
        permutedims!(target, scratch, (2, 1))
    end
    return dest
end
560+
536561
## Multiplication by Q
537562
### QB
538563
lmul!(A::QRCompactWYQ{T,S}, B::StridedVecOrMat{T}) where {T<:BlasFloat, S<:StridedMatrix} =
@@ -590,6 +615,13 @@ function (*)(A::AbstractQ, B::StridedMatrix)
590615
lmul!(Anew, Bnew)
591616
end
592617

618+
function (*)(A::AbstractQ, b::Number)
    # Compute `A * b` without going through element-wise `getindex` on `A`:
    # fill a matrix with `b * I` and apply `A` to it from the left in place.
    TAb = promote_type(eltype(A), typeof(b))
    dest = similar(A, TAb)
    copyto!(dest, b*I)
    # `dest` has the promoted element type, so bring `A` to the same element
    # type before the in-place multiplication (as the `StridedMatrix * Q`
    # method does); otherwise e.g. a real `A` times a complex `b` would pair
    # a real Q with a complex destination.
    return lmul!(convert(AbstractMatrix{TAb}, A), dest)
end
624+
593625
### QcB
594626
# In-place `adjA * B` for the adjoint of a real compact-WY Q: forwards to
# `LAPACK.gemqrt!` with side 'L' and trans 'T' on the parent's `factors` and `T`.
function lmul!(adjA::Adjoint{<:Any,<:QRCompactWYQ{T,S}}, B::StridedVecOrMat{T}) where {T<:BlasReal,S<:StridedMatrix}
    A = adjA.parent
    return LAPACK.gemqrt!('L', 'T', A.factors, A.T, B)
end
@@ -683,6 +715,13 @@ function (*)(A::StridedMatrix, Q::AbstractQ)
683715
return rmul!(copy_oftype(A, TAQ), convert(AbstractMatrix{TAQ}, Q))
684716
end
685717

718+
function (*)(a::Number, B::AbstractQ)
    # Compute `a * B` without going through element-wise `getindex` on `B`:
    # fill a matrix with `a * I` and apply `B` to it from the right in place.
    TaB = promote_type(typeof(a), eltype(B))
    dest = similar(B, TaB)
    copyto!(dest, a*I)
    # `dest` has the promoted element type, so bring `B` to the same element
    # type before the in-place multiplication (as the `StridedMatrix * Q`
    # method does); otherwise e.g. a complex `a` times a real `B` would pair
    # a complex destination with a real Q.
    return rmul!(dest, convert(AbstractMatrix{TaB}, B))
end
724+
686725
### AQc
687726
# In-place `A * adjB` for the adjoint of a real compact-WY Q: forwards to
# `LAPACK.gemqrt!` with side 'R' and trans 'T' on the parent's `factors` and `T`.
function rmul!(A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:QRCompactWYQ{T}}) where {T<:BlasReal}
    B = adjB.parent
    return LAPACK.gemqrt!('R', 'T', B.factors, B.T, A)
end

stdlib/LinearAlgebra/test/qr.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,4 +322,53 @@ end
322322
end
323323
end
324324

325+
# Factorizing an already-orthogonal/unitary Q should reproduce Q and give R ≈ I.
@testset "QR factorization of Q" begin
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        Q1, R1 = qr(randn(T,5,5))
        Q2, R2 = qr(Q1)
        @test Q1 ≈ Q2
        @test R2 ≈ I
    end
end
333+
334+
# Q scaled by the signs of diag(R) must still be orthogonal.
@testset "Generation of orthogonal matrices" begin
    for T in (Float32, Float64)
        n = 5
        Q, R = qr(randn(T,n,n))
        O = Q * Diagonal(sign.(diag(R)))
        @test O' * O ≈ I
    end
end
342+
343+
# Products of Q with `Diagonal` and `UniformScaling` must agree with the same
# products computed on the dense `Matrix(Q)`.
@testset "Multiplication of Q by special matrices" begin
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        n = 5
        Q, R = qr(randn(T,n,n))
        Qmat = Matrix(Q)
        D = Diagonal(randn(T,n))
        @test Q * D ≈ Qmat * D
        @test D * Q ≈ D * Qmat
        J = 2*I
        @test Q * J ≈ Qmat * J
        @test J * Q ≈ J * Qmat
    end
end
356+
357+
# `copyto!` from Q must fill a plain matrix and both 2-d `PermutedDimsArray`
# layouts with the same values as `Matrix(Q)`.
@testset "copyto! for Q" begin
    for T in (Float32, Float64, ComplexF32, ComplexF64)
        n = 5
        Q, R = qr(randn(T,n,n))
        Qmat = Matrix(Q)
        dest1 = similar(Q)
        copyto!(dest1, Q)
        @test dest1 ≈ Qmat
        dest2 = PermutedDimsArray(similar(Q), (1, 2))
        copyto!(dest2, Q)
        @test dest2 ≈ Qmat
        dest3 = PermutedDimsArray(similar(Q), (2, 1))
        copyto!(dest3, Q)
        @test dest3 ≈ Qmat
    end
end
373+
325374
end # module TestQR

0 commit comments

Comments
 (0)