diff --git a/base/linalg/bunchkaufman.jl b/base/linalg/bunchkaufman.jl index 7f753306f3b7d..12763756fca5b 100644 --- a/base/linalg/bunchkaufman.jl +++ b/base/linalg/bunchkaufman.jl @@ -69,11 +69,135 @@ size(B::BunchKaufman, d::Integer) = size(B.LD, d) issymmetric(B::BunchKaufman) = B.symmetric ishermitian(B::BunchKaufman) = !B.symmetric +function _ipiv2perm_bk(v::AbstractVector{T}, maxi::Integer, uplo::Char) where T + p = T[1:maxi;] + uploL = uplo == 'L' + i = uploL ? 1 : maxi + # if uplo == 'U' we construct the permutation backwards + @inbounds while 1 <= i <= length(v) + vi = v[i] + if vi > 0 # the 1x1 blocks + p[i], p[vi] = p[vi], p[i] + i += uploL ? 1 : -1 + else # the 2x2 blocks + if uploL + p[i + 1], p[-vi] = p[-vi], p[i + 1] + i += 2 + else # 'U' + p[i - 1], p[-vi] = p[-vi], p[i - 1] + i -= 2 + end + end + end + return p +end + +""" + getindex(B::BunchKaufman, d::Symbol) + +Extract the factors of the Bunch-Kaufman factorization `B`. The factorization can take the +two forms `L*D*L'` or `U*D*U'` (or `.'` in the complex symmetric case) where `L` is a +`UnitLowerTriangular` matrix, `U` is a `UnitUpperTriangular`, and `D` is a block diagonal +symmetric or Hermitian matrix with 1x1 or 2x2 blocks. The argument `d` can be +- `:D`: the block diagonal matrix +- `:U`: the upper triangular factor (if factorization is `U*D*U'`) +- `:L`: the lower triangular factor (if factorization is `L*D*L'`) +- `:p`: permutation vector +- `:P`: permutation matrix + +# Examples +```jldoctest +julia> A = [1 2 3; 2 1 2; 3 2 1] +3×3 Array{Int64,2}: + 1 2 3 + 2 1 2 + 3 2 1 + +julia> F = bkfact(Symmetric(A, :L)) +Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}} +D factor: +3×3 Tridiagonal{Float64}: + 1.0 3.0 ⋅ + 3.0 1.0 0.0 + ⋅ 0.0 -1.0 +L factor: +3×3 Base.LinAlg.UnitLowerTriangular{Float64,Array{Float64,2}}: + 1.0 0.0 0.0 + 0.0 1.0 0.0 + 0.5 0.5 1.0 +permutation: +3-element Array{Int64,1}: + 1 + 3 + 2 +successful: true + +julia> F[:L]*F[:D]*F[:L].' - A[F[:p], F[:p]] +3×3 Array{Float64,2}: + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 + +julia> F = bkfact(Symmetric(A)); + +julia> F[:U]*F[:D]*F[:U].' - F[:P]*A*F[:P]' +3×3 Array{Float64,2}: + 0.0 0.0 0.0 + 0.0 0.0 0.0 + 0.0 0.0 0.0 +``` +""" +function getindex(B::BunchKaufman{T}, d::Symbol) where {T<:BlasFloat} + n = size(B, 1) + if d == :p + return _ipiv2perm_bk(B.ipiv, n, B.uplo) + elseif d == :P + return eye(T, n)[:,invperm(B[:p])] + elseif d == :L || d == :U || d == :D + if B.rook + # syconvf_rook just added to LAPACK 3.7.0. Uncomment and remove error when we distribute LAPACK 3.7.0 + # LUD, od = LAPACK.syconvf_rook!(B.uplo, 'C', copy(B.LD), B.ipiv) + throw(ArgumentError("reconstruction rook pivoted Bunch-Kaufman factorization not implemented yet")) + else + LUD, od = LAPACK.syconv!(B.uplo, copy(B.LD), B.ipiv) + end + if d == :D + if B.uplo == 'L' + odl = od[1:n - 1] + return Tridiagonal(odl, diag(LUD), B.symmetric ? odl : conj.(odl)) + else # 'U' + odu = od[2:n] + return Tridiagonal(B.symmetric ? odu : conj.(odu), diag(LUD), odu) + end + elseif d == :L + if B.uplo == 'L' + return UnitLowerTriangular(LUD) + else + throw(ArgumentError("factorization is U*D*U.' but you requested L")) + end + else # :U + if B.uplo == 'U' + return UnitUpperTriangular(LUD) + else + throw(ArgumentError("factorization is L*D*L.' but you requested U")) + end + end + else + throw(KeyError(d)) + end +end + issuccess(B::BunchKaufman) = B.info == 0 function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, B::BunchKaufman) println(io, summary(B)) - print(io, "successful: $(issuccess(B))") + println(io, "D factor:") + show(io, mime, B[:D]) + println(io, "\n$(B.uplo) factor:") + show(io, mime, B[Symbol(B.uplo)]) + println(io, "\npermutation:") + show(io, mime, B[:p]) + print(io, "\nsuccessful: $(issuccess(B))") end function inv(B::BunchKaufman{<:BlasReal}) diff --git a/base/linalg/lapack.jl b/base/linalg/lapack.jl index 75e1807ff2219..3a70e5a865d8a 100644 --- a/base/linalg/lapack.jl +++ b/base/linalg/lapack.jl @@ -3990,9 +3990,9 @@ for (syconv, sysv, sytrf, sytri, sytrs, elty) in end # Rook-pivoting variants of symmetric-matrix algorithms -for (sysv, sytrf, sytri, sytrs, elty) in - ((:dsysv_rook_,:dsytrf_rook_,:dsytri_rook_,:dsytrs_rook_,:Float64), - (:ssysv_rook_,:ssytrf_rook_,:ssytri_rook_,:ssytrs_rook_,:Float32)) +for (sysv, sytrf, sytri, sytrs, syconvf, elty) in + ((:dsysv_rook_,:dsytrf_rook_,:dsytri_rook_,:dsytrs_rook_,:dsyconvf_rook_,:Float64), + (:ssysv_rook_,:ssytrf_rook_,:ssytri_rook_,:ssytrs_rook_,:ssyconvf_rook_,:Float32)) @eval begin # SUBROUTINE DSYSV_ROOK(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, # LWORK, INFO ) @@ -4107,6 +4107,45 @@ for (sysv, sytrf, sytri, sytrs, elty) in chklapackerror(info[]) B end + + # SUBROUTINE DSYCONVF_ROOK( UPLO, WAY, N, A, LDA, IPIV, E, INFO ) + # + # .. Scalar Arguments .. + # CHARACTER UPLO, WAY + # INTEGER INFO, LDA, N + # .. + # .. Array Arguments .. + # INTEGER IPIV( * ) + # DOUBLE PRECISION A( LDA, * ), E( * ) + function syconvf_rook!(uplo::Char, way::Char, A::StridedMatrix{$elty}, + ipiv::StridedVector{BlasInt}, e::StridedVector{$elty}) + # extract + n = checksquare(A) + + # check + chkuplo(uplo) + if way != :C && way != :R + throw(ArgumentError("way must be :C or :R")) + end + if length(ipiv) != n + throw(ArgumentError("length of pivot vector was $(length(ipiv)) but should have been $n")) + end + if length(e) != n + throw(ArgumentError("length of e vector was $(length(ipiv)) but should have been $n")) + end + + # allocate + info = Ref{BlasInt}() + + ccall((@blasfunc($syconvf), liblapack), Void, + (Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, + Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}), + &uplo, &way, &n, A, + &lda, ipiv, e, info) + + chklapackerror(info[]) + return A, e + end end end @@ -4548,9 +4587,9 @@ for (sysv, sytrf, sytri, sytrs, elty, relty) in end end -for (sysv, sytrf, sytri, sytrs, elty, relty) in - ((:zsysv_rook_,:zsytrf_rook_,:zsytri_rook_,:zsytrs_rook_,:Complex128, :Float64), - (:csysv_rook_,:csytrf_rook_,:csytri_rook_,:csytrs_rook_,:Complex64, :Float32)) +for (sysv, sytrf, sytri, sytrs, syconvf, elty, relty) in + ((:zsysv_rook_,:zsytrf_rook_,:zsytri_rook_,:zsytrs_rook_,:zsyconvf_rook_,:Complex128, :Float64), + (:csysv_rook_,:csytrf_rook_,:csytri_rook_,:csytrs_rook_,:csyconvf_rook_,:Complex64, :Float32)) @eval begin # SUBROUTINE ZSYSV_ROOK(UPLO, N, NRHS, A, LDA, IPIV, B, LDB, WORK, # $ LWORK, INFO ) @@ -4667,6 +4706,46 @@ for (sysv, sytrf, sytri, sytrs, elty, relty) in chklapackerror(info[]) B end + + # SUBROUTINE ZSYCONVF_ROOK( UPLO, WAY, N, A, LDA, IPIV, E, INFO ) + # + # .. Scalar Arguments .. + # CHARACTER UPLO, WAY + # INTEGER INFO, LDA, N + # .. + # .. Array Arguments .. + # INTEGER IPIV( * ) + # COMPLEX*16 A( LDA, * ), E( * ) + function syconvf_rook!(uplo::Char, way::Char, A::StridedMatrix{$elty}, + ipiv::StridedVector{BlasInt}, e::StridedVector{$elty} = Vector{$elty}(length(ipiv))) + # extract + n = checksquare(A) + lda = stride(A, 2) + + # check + chkuplo(uplo) + if way != 'C' && way != 'R' + throw(ArgumentError("way must be 'C' or 'R'")) + end + if length(ipiv) != n + throw(ArgumentError("length of pivot vector was $(length(ipiv)) but should have been $n")) + end + if length(e) != n + throw(ArgumentError("length of e vector was $(length(ipiv)) but should have been $n")) + end + + # allocate + info = Ref{BlasInt}() + + ccall((@blasfunc($syconvf), liblapack), Void, + (Ptr{UInt8}, Ptr{UInt8}, Ptr{BlasInt}, Ptr{$elty}, + Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}), + &uplo, &way, &n, A, + &max(1, lda), ipiv, e, info) + + chklapackerror(info[]) + return A, e + end end end diff --git a/doc/src/manual/linear-algebra.md b/doc/src/manual/linear-algebra.md index 176d4b77b6de0..a8303a228502a 100644 --- a/doc/src/manual/linear-algebra.md +++ b/doc/src/manual/linear-algebra.md @@ -75,7 +75,23 @@ julia> B = [1.5 2 -4; 2 -1 -3; -4 -3 5] -4.0 -3.0 5.0 julia> factorize(B) -Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}}([-1.64286 0.142857 -0.8; 2.0 -2.8 -0.6; -4.0 -3.0 5.0], [1, 2, 3], 'U', true, false, 0) +Base.LinAlg.BunchKaufman{Float64,Array{Float64,2}} +D factor: +3×3 Tridiagonal{Float64}: + -1.64286 0.0 ⋅ + 0.0 -2.8 0.0 + ⋅ 0.0 5.0 +U factor: +3×3 Base.LinAlg.UnitUpperTriangular{Float64,Array{Float64,2}}: + 1.0 0.142857 -0.8 + 0.0 1.0 -0.6 + 0.0 0.0 1.0 +permutation: +3-element Array{Int64,1}: + 1 + 2 + 3 +successful: true ``` Here, Julia was able to detect that `B` is in fact symmetric, and used a more appropriate factorization. diff --git a/test/linalg/bunchkaufman.jl b/test/linalg/bunchkaufman.jl index 46e77b3886c4b..80a3781a3f4e6 100644 --- a/test/linalg/bunchkaufman.jl +++ b/test/linalg/bunchkaufman.jl @@ -39,6 +39,11 @@ bimg = randn(n,2)/2 @testset for eltyb in (Float32, Float64, Complex64, Complex128, Int) b = eltyb == Int ? rand(1:5, n, 2) : convert(Matrix{eltyb}, eltyb <: Complex ? complex.(breal, bimg) : breal) + + # check that factorize gives a Bunch-Kaufman + @test isa(factorize(asym), LinAlg.BunchKaufman) + @test isa(factorize(aher), LinAlg.BunchKaufman) + @testset for btype in ("Array", "SubArray") if btype == "Array" b = b @@ -49,10 +54,6 @@ bimg = randn(n,2)/2 εb = eps(abs(float(one(eltyb)))) ε = max(εa,εb) - # check that factorize gives a Bunch-Kaufman - @test isa(factorize(asym), LinAlg.BunchKaufman) - @test isa(factorize(aher), LinAlg.BunchKaufman) - @testset "$uplo Bunch-Kaufman factor of indefinite matrix" for uplo in (:L, :U) bc1 = bkfact(Hermitian(aher, uplo)) @test LinAlg.issuccess(bc1) @@ -73,6 +74,15 @@ bimg = randn(n,2)/2 @test_throws ArgumentError bkfact(a) end end + # Test extraction of factors + # syconvf_rook just added to LAPACK 3.7.0. Test when we distribute LAPACK 3.7.0 + @test bc1[uplo]*bc1[:D]*bc1[uplo]' ≈ aher[bc1[:p], bc1[:p]] + @test bc1[uplo]*bc1[:D]*bc1[uplo]' ≈ bc1[:P]*aher*bc1[:P]' + if eltya <: Complex + bc1 = bkfact(Symmetric(asym, uplo)) + @test bc1[uplo]*bc1[:D]*bc1[uplo].' ≈ asym[bc1[:p], bc1[:p]] + @test bc1[uplo]*bc1[:D]*bc1[uplo].' ≈ bc1[:P]*asym*bc1[:P]' + end end @testset "$uplo Bunch-Kaufman factors of a pos-def matrix" for uplo in (:U, :L) @@ -122,9 +132,7 @@ end end end - -# test example due to @timholy in PR 15354 -let +@testset "test example due to @timholy in PR 15354" begin A = rand(6,5); A = complex(A'*A) # to avoid calling the real-lhs-complex-rhs method F = cholfact(A); v6 = rand(Complex128, 6)