Skip to content

Commit 5e584fb

Browse files
carstenbauerandreasnoack
authored andcommitted
alg keyword for svd and svd! (#31057)
* alg keyword for LinearAlgebra.svd * SVDAlgorithms -> Algorithms * default_svd_alg * refined docstring Co-Authored-By: Andreas Noack <[email protected]> * rename to QRIteration; _svd! dispatch * compat annotation
1 parent 0eabe22 commit 5e584fb

File tree

6 files changed

+68
-16
lines changed

6 files changed

+68
-16
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ Standard library changes
7474
* The BLAS submodule no longer exports `dot`, which conflicts with that in LinearAlgebra ([#31838]).
7575
* `diagm` and `spdiagm` now accept optional `m,n` initial arguments to specify a size ([#31654]).
7676
* `Hessenberg` factorizations `H` now support efficient shifted solves `(H+µI) \ b` and determinants, and use a specialized tridiagonal factorization for Hermitian matrices. There is also a new `UpperHessenberg` matrix type ([#31853]).
77+
* Added keyword argument `alg` to `svd` and `svd!` that allows one to switch between different SVD algorithms ([#31057]).
7778
* Five-argument `mul!(C, A, B, α, β)` now implements inplace multiplication fused with addition _C = A B α + C β_ ([#23919]).
7879

7980
#### SparseArrays

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ else
156156
const BlasInt = Int32
157157
end
158158

159+
160+
abstract type Algorithm end
161+
struct DivideAndConquer <: Algorithm end
162+
struct QRIteration <: Algorithm end
163+
164+
159165
# Check that stride of matrix/vector is 1
160166
# Writing like this to avoid splatting penalty when called with multiple arguments,
161167
# see PR 16416

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ function svd!(M::Bidiagonal{<:BlasReal}; full::Bool = false)
200200
d, e, U, Vt, Q, iQ = LAPACK.bdsdc!(M.uplo, 'I', M.dv, M.ev)
201201
SVD(U, d, Vt)
202202
end
203-
function svd(M::Bidiagonal; full::Bool = false)
204-
svd!(copy(M), full = full)
203+
function svd(M::Bidiagonal; kw...)
204+
svd!(copy(M), kw...)
205205
end
206206

207207
####################

stdlib/LinearAlgebra/src/svd.jl

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,19 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
5454
convert(AbstractArray{T}, Vt))
5555
end
5656

57+
5758
# iteration for destructuring into components
5859
Base.iterate(S::SVD) = (S.U, Val(:S))
5960
Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V))
6061
Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done))
6162
Base.iterate(S::SVD, ::Val{:done}) = nothing
6263

64+
65+
default_svd_alg(A) = DivideAndConquer()
66+
67+
6368
"""
64-
svd!(A; full::Bool = false) -> SVD
69+
svd!(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
6570
6671
`svd!` is the same as [`svd`](@ref), but saves space by
6772
overwriting the input `A`, instead of creating a copy.
@@ -92,18 +97,28 @@ julia> A
9297
0.0 0.0 -2.0 0.0 0.0
9398
```
9499
"""
95-
function svd!(A::StridedMatrix{T}; full::Bool = false) where T<:BlasFloat
100+
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T<:BlasFloat
96101
m,n = size(A)
97102
if m == 0 || n == 0
98103
u,s,vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
99104
else
100-
u,s,vt = LAPACK.gesdd!(full ? 'A' : 'S', A)
105+
u,s,vt = _svd!(A,full,alg)
101106
end
102107
SVD(u,s,vt)
103108
end
104109

110+
111+
_svd!(A::StridedMatrix{T}, full::Bool, alg::Algorithm) where T<:BlasFloat = throw(ArgumentError("Unsupported value for `alg` keyword."))
112+
_svd!(A::StridedMatrix{T}, full::Bool, alg::DivideAndConquer) where T<:BlasFloat = LAPACK.gesdd!(full ? 'A' : 'S', A)
113+
function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where T<:BlasFloat
114+
c = full ? 'A' : 'S'
115+
u,s,vt = LAPACK.gesvd!(c, c, A)
116+
end
117+
118+
119+
105120
"""
106-
svd(A; full::Bool = false) -> SVD
121+
svd(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
107122
108123
Compute the singular value decomposition (SVD) of `A` and return an `SVD` object.
109124
@@ -120,6 +135,12 @@ and `V` is `N \\times N`, while in the thin factorization `U` is `M
120135
\\times K` and `V` is `N \\times K`, where `K = \\min(M,N)` is the
121136
number of singular values.
122137
138+
If `alg = DivideAndConquer()` a divide-and-conquer algorithm is used to calculate the SVD.
139+
Another (typically slower but more accurate) option is `alg = QRIteration()`.
140+
141+
!!! compat "Julia 1.3"
142+
The `alg` keyword argument requires Julia 1.3 or later.
143+
123144
# Examples
124145
```jldoctest
125146
julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.]
@@ -144,21 +165,21 @@ julia> u == F.U && s == F.S && v == F.V
144165
true
145166
```
146167
"""
147-
function svd(A::StridedVecOrMat{T}; full::Bool = false) where T
148-
svd!(copy_oftype(A, eigtype(T)), full = full)
168+
function svd(A::StridedVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where T
169+
svd!(copy_oftype(A, eigtype(T)), full = full, alg = alg)
149170
end
150-
function svd(x::Number; full::Bool = false)
171+
function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x))
151172
SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
152173
end
153-
function svd(x::Integer; full::Bool = false)
154-
svd(float(x), full = full)
174+
function svd(x::Integer; full::Bool = false, alg::Algorithm = default_svd_alg(x))
175+
svd(float(x), full = full, alg = alg)
155176
end
156-
function svd(A::Adjoint; full::Bool = false)
157-
s = svd(A.parent, full = full)
177+
function svd(A::Adjoint; full::Bool = false, alg::Algorithm = default_svd_alg(A))
178+
s = svd(A.parent, full = full, alg = alg)
158179
return SVD(s.Vt', s.S, s.U')
159180
end
160-
function svd(A::Transpose; full::Bool = false)
161-
s = svd(A.parent, full = full)
181+
function svd(A::Transpose; full::Bool = false, alg::Algorithm = default_svd_alg(A))
182+
s = svd(A.parent, full = full, alg = alg)
162183
return SVD(transpose(s.Vt), s.S, transpose(s.U))
163184
end
164185

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2532,7 +2532,7 @@ eigen(A::AbstractTriangular) = Eigen(eigvals(A), eigvecs(A))
25322532
# Generic singular systems
25332533
for func in (:svd, :svd!, :svdvals)
25342534
@eval begin
2535-
($func)(A::AbstractTriangular) = ($func)(copyto!(similar(parent(A)), A))
2535+
($func)(A::AbstractTriangular; kwargs...) = ($func)(copyto!(similar(parent(A)), A); kwargs...)
25362536
end
25372537
end
25382538

stdlib/LinearAlgebra/test/svd.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,28 @@ aimg = randn(n,n)/2
143143
end
144144
end
145145

146+
147+
148+
@testset "SVD Algorithms" begin
149+
(x,y) = isapprox(x,y,rtol=1e-15)
150+
151+
x = [0.1 0.2; 0.3 0.4]
152+
153+
for alg in [LinearAlgebra.QRIteration(), LinearAlgebra.DivideAndConquer()]
154+
sx1 = svd(x, alg = alg)
155+
@test sx1.U * Diagonal(sx1.S) * sx1.Vt x
156+
@test sx1.V * sx1.Vt I
157+
@test sx1.U * sx1.U' I
158+
@test all(sx1.S .≥ 0)
159+
160+
sx2 = svd!(copy(x), alg = alg)
161+
@test sx2.U * Diagonal(sx2.S) * sx2.Vt x
162+
@test sx2.V * sx2.Vt I
163+
@test sx2.U * sx2.U' I
164+
@test all(sx2.S .≥ 0)
165+
end
166+
end
167+
168+
169+
146170
end # module TestSVD

0 commit comments

Comments
 (0)