diff --git a/src/SDiagonal.jl b/src/SDiagonal.jl new file mode 100644 index 00000000..4ff83cef --- /dev/null +++ b/src/SDiagonal.jl @@ -0,0 +1,106 @@ +# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer +# at https://github.com/mschauer/Bridge.jl under MIT License + +import Base: ==, -, +, *, /, \, abs, real, imag, conj + +@generated function scalem(a::StaticMatrix{M,N}, b::StaticVector{N}) where {M, N} + expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N]) + :(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end) +end +@generated function scalem(a::StaticVector{M}, b::StaticMatrix{M, N}) where {M, N} + expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N]) + :(@_inline_meta; let val1 = ($(expr[1])); similar_type(SMatrix{M,N},typeof(val1))(val1, $(expr[2:end]...)); end) +end + +struct SDiagonal{N,T} <: StaticMatrix{N,N,T} + diag::SVector{N,T} + SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag) +end +diagtype(::Type{SDiagonal{N,T}}) where {N, T} = SVector{N,T} +diagtype(::Type{SDiagonal{N}}) where {N} = SVector{N} +diagtype(::Type{SDiagonal}) = SVector + +# this is to deal with convert.jl +@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a)) +@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(convert(diagtype(SD), a)) +@inline (::Type{SDiagonal})(a::SVector{N,T}) where {N,T} = SDiagonal{N,T}(a) + +@generated function SDiagonal(a::StaticMatrix{N,N,T}) where {N,T} + expr = [:(a[$i,$i]) for i=1:N] + :(SDiagonal{N,T}($(expr...))) +end + +convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) where {N,T} = D +convert(::Type{SDiagonal{N,T}}, D::SDiagonal{N}) where {N,T} = SDiagonal{N,T}(convert(SVector{N,T}, D.diag)) + +function getindex(D::SDiagonal{N,T}, i::Int, j::Int) where {N,T} + @boundscheck checkbounds(D, i, j) + @inbounds return ifelse(i == j, D.diag[i], zero(T)) +end + +# avoid linear indexing? +@propagate_inbounds function getindex(D::SDiagonal{N,T}, k::Int) where {N,T} + i, j = ind2sub(size(D), k) + D[i,j] +end + +ishermitian(D::SDiagonal{N, T}) where {N,T<:Real} = true +ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag)) +issymmetric(D::SDiagonal) = true +isposdef(D::SDiagonal) = all(D.diag .> 0) + +factorize(D::SDiagonal) = D + +==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag +-(A::SDiagonal) = SDiagonal(-A.diag) ++(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag) +-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag) +-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B + +*(x::T, D::SDiagonal) where {T<:Number} = SDiagonal(x * D.diag) +*(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag * x) +/(D::SDiagonal, x::T) where {T<:Number} = SDiagonal(D.diag / x) +*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag) +*(D::SDiagonal, V::AbstractVector) = D.diag .* V +*(D::SDiagonal, V::StaticVector) = D.diag .* V +*(A::StaticMatrix, D::SDiagonal) = scalem(A,D.diag) +*(D::SDiagonal, A::StaticMatrix) = scalem(D.diag,A) +\(D::SDiagonal, b::AbstractVector) = D.diag .\ b +\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity + +conj(D::SDiagonal) = SDiagonal(conj(D.diag)) +transpose(D::SDiagonal) = D +ctranspose(D::SDiagonal) = conj(D) + +diag(D::SDiagonal) = D.diag +trace(D::SDiagonal) = sum(D.diag) +det(D::SDiagonal) = prod(D.diag) +logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag)) +function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is correct + x = sum(log.(D.diag)) + -pi(@inbounds iszero(D.diag[i]) && throw(Base.LinAlg.SingularException(i))) + end +end + +function inv(D::SDiagonal) + check_singular(D) + SDiagonal(inv.(D.diag)) +end + diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 4b8a6bb5..3547bc71 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -8,8 +8,8 @@ import Base: getindex, setindex!, size, similar, vec, show, length, convert, promote_op, promote_rule, map, map!, reduce, reducedim, mapreducedim, mapreduce, broadcast, broadcast!, conj, transpose, ctranspose, hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill, - fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag, - lu, svd, svdvals, svdfact, + fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, trace, diag, vecnorm, norm, dot, diagm, diag, + lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef, sum, diff, prod, count, any, all, minimum, maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!, randexp!, normalize, normalize!, read, read!, write @@ -19,6 +19,7 @@ export Scalar, SArray, SVector, SMatrix export MArray, MVector, MMatrix export FieldVector export SizedArray, SizedVector, SizedMatrix +export SDiagonal export Size, Length @@ -79,6 +80,7 @@ include("MArray.jl") include("MVector.jl") include("MMatrix.jl") include("SizedArray.jl") +include("SDiagonal.jl") include("abstractarray.jl") include("indexing.jl") diff --git a/test/SDiagonal.jl b/test/SDiagonal.jl new file mode 100644 index 00000000..030898d1 --- /dev/null +++ b/test/SDiagonal.jl @@ -0,0 +1,104 @@ +@testset "SDiagonal" begin + @testset "Constructors" begin + @test SDiagonal{1,Int64}((1,)).diag === SVector{1,Int64}((1,)) + @test SDiagonal{1,Float64}((1,)).diag === SVector{1,Float64}((1,)) + + @test SDiagonal{4,Float64}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0) + @test SDiagonal{4}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0) + @test SDiagonal((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0) + + # Bad input + @test_throws Exception SMatrix{1,Int}() + @test_throws Exception SMatrix{2,Int}((1,)) + + # From SMatrix + @test SDiagonal(SMatrix{2,2,Int}((1,2,3,4))).diag.data === (1,4) + + @test SDiagonal{1,Int}(SDiagonal{1,Float64}((1,))).diag[1] === 1 + + end + + @testset "Methods" begin + + @test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3] + @test StaticArrays.scalem(@SVector([1,2,3]),@SMatrix [1 1 1;1 1 1; 1 1 1])' === @SArray [1 2 3; 1 2 3; 1 2 3] + + m = SDiagonal(@SVector [11, 12, 13, 14]) + + + + @test diag(m) === m.diag + + + m2 = diagm([11, 12, 13, 14]) + + @test logdet(m) == logdet(m2) + @test logdet(im*m) ≈ logdet(im*m2) + @test det(m) == det(m2) + @test trace(m) == trace(m2) + @test logm(m) == logm(m2) + @test expm(m) == expm(m2) + @test sqrtm(m) == sqrtm(m2) + + + @test isimmutable(m) == true + + @test m[1,1] === 11 + @test m[2,2] === 12 + @test m[3,3] === 13 + @test m[4,4] === 14 + + for i in 1:4 + for j in 1:4 + i == j || @test m[i,j] === 0 + end + end + + @test_throws Exception m[5,5] + + @test_throws Exception m[1,5] + + + @test size(m) === (4, 4) + @test size(typeof(m)) === (4, 4) + @test size(SDiagonal{4}) === (4, 4) + + @test size(m, 1) === 4 + @test size(m, 2) === 4 + @test size(typeof(m), 1) === 4 + @test size(typeof(m), 2) === 4 + + @test length(m) === 4*4 + + @test_throws Exception m[1] = 1 + + b = @SVector [2,-1,2,1] + b2 = Vector(b) + + + @test m*b == @SVector [22,-12,26,14] + @test (b'*m)' == @SVector [22,-12,26,14] + + @test m\b == m2\b + + @test b'/m == b'/m2 + @test_throws Exception b/m + @test m*m == m2*m + + @test ishermitian(m) == ishermitian(m2) + @test ishermitian(m/2) + + @test isposdef(m) == isposdef(m2) + @test issymmetric(m) == issymmetric(m2) + + @test (2*m/2)' == m + @test 2m == m + m + @test m*0 == m - m + + @test m*inv(m) == m/m == m\m == eye(SDiagonal{4,Float64}) + + + + + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7e597d1e..a4414f45 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,8 @@ include("FieldVector.jl") include("Scalar.jl") include("SUnitRange.jl") include("SizedArray.jl") +include("SDiagonal.jl") + include("custom_types.jl") include("core.jl")