diff --git a/src/fillalgebra.jl b/src/fillalgebra.jl index 928c04a7..c17e46cd 100644 --- a/src/fillalgebra.jl +++ b/src/fillalgebra.jl @@ -551,3 +551,6 @@ end function LinearAlgebra.istril(A::AbstractFillMatrix, k::Integer = 0) iszero(A) || k >= size(A,2)-1 end + +triu(A::AbstractZerosMatrix, k::Integer=0) = A +tril(A::AbstractZerosMatrix, k::Integer=0) = A diff --git a/src/oneelement.jl b/src/oneelement.jl index 28bc2c4b..17c8309c 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -296,7 +296,20 @@ end adjoint(A::OneElementMatrix) = OneElement(adjoint(A.val), reverse(A.ind), reverse(A.axes)) transpose(A::OneElementMatrix) = OneElement(transpose(A.val), reverse(A.ind), reverse(A.axes)) +# tril/triu + +function tril(A::OneElementMatrix, k::Integer=0) + nzband = A.ind[2] - A.ind[1] + OneElement(nzband > k ? zero(A.val) : A.val, A.ind, axes(A)) +end + +function triu(A::OneElementMatrix, k::Integer=0) + nzband = A.ind[2] - A.ind[1] + OneElement(nzband < k ? zero(A.val) : A.val, A.ind, axes(A)) +end + # broadcast + function broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::OneElement{<:Any,N}) where {N} OneElement(conj(r.val), r.ind, axes(r)) end diff --git a/test/runtests.jl b/test/runtests.jl index e7a5eda6..ddc248e3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2447,6 +2447,17 @@ end @test A / (2 + 3.0im) === OneElement(val / (2 + 3.0im), (2,2), (3,4)) == B / (2 + 3.0im) end + + @testset "tril/triu" begin + for A in (OneElement(3, (4,2), (4,5)), OneElement(3, (2,3), (4,5)), OneElement(3, (3,3), (4,5))) + B = Array(A) + for k in -5:5 + @test tril(A, k) == tril(B, k) + @test triu(A, k) == triu(B, k) + end + end + end + @testset "broadcasting" begin for v in (OneElement(2, 3, 4), OneElement(2im, (1,2), (3,4))) w = Array(v) @@ -2634,3 +2645,11 @@ end end end end + +@testset "triu/tril for Zeros" begin + Z = Zeros(3, 4) + @test triu(Z) === Z + @test tril(Z) === Z + @test triu(Z, 2) === Z + @test tril(Z, 2) === Z +end