Skip to content

Commit 5fd053f

Browse files
authored
diag of SparseMatrixCSC should always return SparseVector (#23261)
* diag of SparseMatrixCSC should always return SparseVector * remove SpDiagIterator
1 parent 88a553a commit 5fd053f

File tree

4 files changed

+49
-27
lines changed

4 files changed

+49
-27
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ Deprecated or removed
312312
* `Base.cpad` has been removed; use an appropriate combination of `rpad` and `lpad`
313313
instead ([#23187]).
314314

315+
* `Base.SparseArrays.SpDiagIterator` has been removed ([#23261]).
316+
315317
Command-line option changes
316318
---------------------------
317319

base/sparse/linalg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,7 @@ for f in (:\, :Ac_ldiv_B, :At_ldiv_B)
879879
if m == n
880880
if istril(A)
881881
if istriu(A)
882-
return ($f)(Diagonal(A), B)
882+
return ($f)(Diagonal(Vector(diag(A))), B)
883883
else
884884
return ($f)(LowerTriangular(A), B)
885885
end

base/sparse/sparsematrix.jl

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3380,40 +3380,39 @@ function expandptr(V::Vector{<:Integer})
33803380
res
33813381
end
33823382

3383-
## diag and related using an iterator
33843383

3385-
mutable struct SpDiagIterator{Tv,Ti}
3386-
A::SparseMatrixCSC{Tv,Ti}
3387-
n::Int
3388-
end
3389-
SpDiagIterator(A::SparseMatrixCSC) = SpDiagIterator(A,minimum(size(A)))
3390-
3391-
length(d::SpDiagIterator) = d.n
3392-
start(d::SpDiagIterator) = 1
3393-
done(d::SpDiagIterator, j) = j > d.n
3394-
3395-
function next(d::SpDiagIterator{Tv}, j) where Tv
3396-
A = d.A
3397-
r1 = Int(A.colptr[j])
3398-
r2 = Int(A.colptr[j+1]-1)
3399-
(r1 > r2) && (return (zero(Tv), j+1))
3400-
r1 = searchsortedfirst(A.rowval, j, r1, r2, Forward)
3401-
(((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1)
3384+
function diag(A::SparseMatrixCSC{Tv,Ti}, d::Integer=0) where {Tv,Ti}
3385+
m, n = size(A)
3386+
k = Int(d)
3387+
if !(-m <= k <= n)
3388+
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($m, $n)"))
3389+
end
3390+
l = k < 0 ? min(m+k,n) : min(n-k,m)
3391+
r, c = k <= 0 ? (-k, 0) : (0, k) # start row/col -1
3392+
ind = Vector{Ti}()
3393+
val = Vector{Tv}()
3394+
for i in 1:l
3395+
r += 1; c += 1
3396+
r1 = Int(A.colptr[c])
3397+
r2 = Int(A.colptr[c+1]-1)
3398+
r1 > r2 && continue
3399+
r1 = searchsortedfirst(A.rowval, r, r1, r2, Forward)
3400+
((r1 > r2) || (A.rowval[r1] != r)) && continue
3401+
push!(ind, i)
3402+
push!(val, A.nzval[r1])
3403+
end
3404+
return SparseVector{Tv,Ti}(l, ind, val)
34023405
end
34033406

34043407
function trace(A::SparseMatrixCSC{Tv}) where Tv
3405-
if size(A,1) != size(A,2)
3406-
throw(DimensionMismatch("expected square matrix"))
3407-
end
3408+
n = checksquare(A)
34083409
s = zero(Tv)
3409-
for d in SpDiagIterator(A)
3410-
s += d
3410+
for i in 1:n
3411+
s += A[i,i]
34113412
end
3412-
s
3413+
return s
34133414
end
34143415

3415-
diag(A::SparseMatrixCSC{Tv}) where {Tv} = Tv[d for d in SpDiagIterator(A)]
3416-
34173416
function diagm(v::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
34183417
if size(v,1) != 1 && size(v,2) != 1
34193418
throw(DimensionMismatch("input should be nx1 or 1xn"))

test/sparse/sparse.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,27 @@ end
13231323
@test diagm(sparse(ones(5,1))) == speye(5)
13241324
end
13251325

1326+
@testset "diag" begin
1327+
for T in (Float64, Complex128)
1328+
S1 = sprand(T, 5, 5, 0.5)
1329+
S2 = sprand(T, 10, 5, 0.5)
1330+
S3 = sprand(T, 5, 10, 0.5)
1331+
for S in (S1, S2, S3)
1332+
A = Matrix(S)
1333+
@test diag(S)::SparseVector{T,Int} == diag(A)
1334+
for k in -size(S,1):size(S,2)
1335+
@test diag(S, k)::SparseVector{T,Int} == diag(A, k)
1336+
end
1337+
@test_throws ArgumentError diag(S, -size(S,1)-1)
1338+
@test_throws ArgumentError diag(S, size(S,2)+1)
1339+
end
1340+
end
1341+
# test that stored zeros are still stored zeros in the diagonal
1342+
S = sparse([1,3],[1,3],[0.0,0.0]); V = diag(S)
1343+
@test V.nzind == [1,3]
1344+
@test V.nzval == [0.0,0.0]
1345+
end
1346+
13261347
@testset "expandptr" begin
13271348
A = speye(5)
13281349
@test Base.SparseArrays.expandptr(A.colptr) == collect(1:5)

0 commit comments

Comments
 (0)