Skip to content

Commit 1274740

Browse files
committed
Add diagview to obtain a view along a diagonal
1 parent 88201cf commit 1274740

File tree

10 files changed

+68
-18
lines changed

10 files changed

+68
-18
lines changed

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ Standard library changes
151151
* The matrix multiplication `A * B` calls `matprod_dest(A, B, T::Type)` to generate the destination.
152152
This function is now public ([#55537]).
153153
* The function `haszero(T::Type)` is used to check if a type `T` has a unique zero element defined as `zero(T)`.
154-
This is now public.
154+
This is now public ([#56223]).
155+
* A new function `diagview` is added that returns a view into a specific band of an `AbstractMatrix` ([#56175]).
155156

156157
#### Logging
157158

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ export
8787
diag,
8888
diagind,
8989
diagm,
90+
diagview,
9091
dot,
9192
eigen!,
9293
eigen,

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ function Matrix{T}(A::Bidiagonal) where T
191191
B = Matrix{T}(undef, size(A))
192192
if haszero(T) # optimized path for types with zero(T) defined
193193
size(B,1) > 1 && fill!(B, zero(T))
194-
copyto!(view(B, diagind(B)), A.dv)
195-
copyto!(view(B, diagind(B, _offdiagind(A.uplo))), A.ev)
194+
copyto!(diagview(B), A.dv)
195+
copyto!(diagview(B, _offdiagind(A.uplo)), A.ev)
196196
else
197197
copyto!(B, A)
198198
end
@@ -570,7 +570,7 @@ end
570570
# to avoid allocations in _mul! below (#24324, #24578)
571571
_diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du
572572
_diag(A::SymTridiagonal{<:Number}, k) = k == 0 ? A.dv : A.ev
573-
_diag(A::SymTridiagonal, k) = k == 0 ? view(A, diagind(A, IndexStyle(A))) : view(A, diagind(A, 1, IndexStyle(A)))
573+
_diag(A::SymTridiagonal, k) = diagview(A,k)
574574
function _diag(A::Bidiagonal, k)
575575
if k == 0
576576
return A.dv

stdlib/LinearAlgebra/src/dense.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,35 @@ julia> diag(A,1)
290290
"""
291291
diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))]
292292

293+
"""
294+
diagview(M, k::Integer=0)
295+
296+
Return a view into the `k`th diagonal of the matrix `M`.
297+
298+
See also [`diag`](@ref), [`diagind`](@ref).
299+
300+
# Examples
301+
```jldoctest
302+
julia> A = [1 2 3; 4 5 6; 7 8 9]
303+
3×3 Matrix{Int64}:
304+
1 2 3
305+
4 5 6
306+
7 8 9
307+
308+
julia> diagview(A)
309+
3-element view(::Vector{Int64}, 1:4:9) with eltype Int64:
310+
1
311+
5
312+
9
313+
314+
julia> diagview(A, 1)
315+
2-element view(::Vector{Int64}, 4:4:8) with eltype Int64:
316+
2
317+
6
318+
```
319+
"""
320+
diagview(A::AbstractMatrix, k::Integer=0) = @view A[diagind(A, k, IndexStyle(A))]
321+
293322
"""
294323
diagm(kv::Pair{<:Integer,<:AbstractVector}...)
295324
diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function Matrix{T}(D::Diagonal) where {T}
120120
B = Matrix{T}(undef, size(D))
121121
if haszero(T) # optimized path for types with zero(T) defined
122122
size(B,1) > 1 && fill!(B, zero(T))
123-
copyto!(view(B, diagind(B)), D.diag)
123+
copyto!(diagview(B), D.diag)
124124
else
125125
copyto!(B, D)
126126
end
@@ -1041,7 +1041,7 @@ dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x,
10411041
dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
10421042
function dot(D::Diagonal, B::AbstractMatrix)
10431043
size(D) == size(B) || throw(DimensionMismatch(lazy"Matrix sizes $(size(D)) and $(size(B)) differ"))
1044-
return dot(D.diag, view(B, diagind(B, IndexStyle(B))))
1044+
return dot(D.diag, diagview(B))
10451045
end
10461046

10471047
dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A))

stdlib/LinearAlgebra/src/special.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function Tridiagonal(A::Bidiagonal)
2222
end
2323

2424
_diagview(S::SymTridiagonal{<:Number}) = S.dv
25-
_diagview(S::SymTridiagonal) = view(S, diagind(S, IndexStyle(S)))
25+
_diagview(S::SymTridiagonal) = diagview(S)
2626

2727
# conversions from SymTridiagonal to other special matrix types
2828
Diagonal(A::SymTridiagonal) = Diagonal(_diagview(A))
@@ -370,20 +370,20 @@ function copyto!(dest::BandedMatrix, src::BandedMatrix)
370370
end
371371
function _copyto_banded!(T::Tridiagonal, D::Diagonal)
372372
T.d .= D.diag
373-
T.dl .= view(D, diagind(D, -1, IndexStyle(D)))
374-
T.du .= view(D, diagind(D, 1, IndexStyle(D)))
373+
T.dl .= diagview(D, -1)
374+
T.du .= diagview(D, 1)
375375
return T
376376
end
377377
function _copyto_banded!(SymT::SymTridiagonal, D::Diagonal)
378378
issymmetric(D) || throw(ArgumentError("cannot copy a non-symmetric Diagonal matrix to a SymTridiagonal"))
379379
SymT.dv .= D.diag
380380
_ev = _evview(SymT)
381-
_ev .= view(D, diagind(D, 1, IndexStyle(D)))
381+
_ev .= diagview(D, 1)
382382
return SymT
383383
end
384384
function _copyto_banded!(B::Bidiagonal, D::Diagonal)
385385
B.dv .= D.diag
386-
B.ev .= view(D, diagind(D, B.uplo == 'U' ? 1 : -1, IndexStyle(D)))
386+
B.ev .= diagview(D, _offdiagind(B.uplo))
387387
return B
388388
end
389389
function _copyto_banded!(D::Diagonal, B::Bidiagonal)
@@ -411,10 +411,10 @@ function _copyto_banded!(T::Tridiagonal, B::Bidiagonal)
411411
T.d .= B.dv
412412
if B.uplo == 'U'
413413
T.du .= B.ev
414-
T.dl .= view(B, diagind(B, -1, IndexStyle(B)))
414+
T.dl .= diagview(B,-1)
415415
else
416416
T.dl .= B.ev
417-
T.du .= view(B, diagind(B, 1, IndexStyle(B)))
417+
T.du .= diagview(B, 1)
418418
end
419419
return T
420420
end

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2041,7 +2041,7 @@ function _find_params_log_quasitriu!(A)
20412041

20422042
# Find s0, the smallest s such that the ρ(triu(A)^(1/2^s) - I) ≤ theta[tmax], where ρ(X)
20432043
# is the spectral radius of X
2044-
d = complex.(@view(A[diagind(A)]))
2044+
d = complex.(diagview(A))
20452045
dm1 = d .- 1
20462046
s = 0
20472047
while norm(dm1, Inf) > theta[tmax] && s < maxsqrt

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,9 @@ function Matrix{T}(M::Tridiagonal) where {T}
612612
A = Matrix{T}(undef, size(M))
613613
if haszero(T) # optimized path for types with zero(T) defined
614614
size(A,1) > 2 && fill!(A, zero(T))
615-
copyto!(view(A, diagind(A)), M.d)
616-
copyto!(view(A, diagind(A,1)), M.du)
617-
copyto!(view(A, diagind(A,-1)), M.dl)
615+
copyto!(diagview(A), M.d)
616+
copyto!(diagview(A,1), M.du)
617+
copyto!(diagview(A,-1), M.dl)
618618
else
619619
copyto!(A, M)
620620
end
@@ -1092,7 +1092,7 @@ function show(io::IO, T::Tridiagonal)
10921092
end
10931093
function show(io::IO, S::SymTridiagonal)
10941094
print(io, "SymTridiagonal(")
1095-
show(io, eltype(S) <: Number ? S.dv : view(S, diagind(S, IndexStyle(S))))
1095+
show(io, _diagview(S))
10961096
print(io, ", ")
10971097
show(io, S.ev)
10981098
print(io, ")")

stdlib/LinearAlgebra/test/dense.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,15 @@ end
10241024
@test diag(zeros(0,1),2) == []
10251025
end
10261026

1027+
@testset "diagview" begin
1028+
for sz in ((3,3), (3,5), (5,3))
1029+
A = rand(sz...)
1030+
for k in -5:5
1031+
@test diagview(A,k) == diag(A,k)
1032+
end
1033+
end
1034+
end
1035+
10271036
@testset "issue #39857" begin
10281037
@test lyap(1.0+2.0im, 3.0+4.0im) == -1.5 - 2.0im
10291038
end

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,4 +1065,14 @@ end
10651065
end
10661066
end
10671067

1068+
@testset "diagview" begin
1069+
A = Tridiagonal(rand(3), rand(4), rand(3))
1070+
for k in -5:5
1071+
@test diagview(A,k) == diag(A,k)
1072+
end
1073+
v = diagview(A,1)
1074+
v .= 0
1075+
@test all(iszero, diag(A,1))
1076+
end
1077+
10681078
end # module TestTridiagonal

0 commit comments

Comments
 (0)