Skip to content

Commit aebf3bc

Browse files
committed
Porting SDiagonal from Bridge.jl
Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer at https://github.com/mschauer/Bridge.jl under MIT License
1 parent 80b6bac commit aebf3bc

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

src/SDiagonal.jl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer
2+
# at https://github.com/mschauer/Bridge.jl under MIT License
3+
4+
import Base: getindex,setindex!,==,-,+,*,/,\,transpose,ctranspose,convert, size, abs, real, imag, conj, eye, inv
5+
import Base.LinAlg: ishermitian, issymmetric, isposdef, factorize, diag, trace, det, logdet, expm, logm, sqrtm
6+
7+
@generated function scalem{T, M, N}(a::SMatrix{M,N, T}, b::SVector{N, T})
8+
expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
9+
:(SMatrix{M,N,T}($(expr...)))
10+
end
11+
@generated function scalem{T, M, N}(a::SVector{M,T}, b::SMatrix{M, N, T})
12+
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
13+
:(SMatrix{M,N,T}($(expr...)))
14+
end
15+
16+
struct SDiagonal{N,T}
17+
diag::SVector{N,T}
18+
end
19+
20+
function \{T,M}(D::SDiagonal, b::SVector{M,T} )
21+
D.diag .* b
22+
end
23+
24+
SDiagonal(A::SMatrix) = SDiagonal(diag(A))
25+
26+
27+
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) = D
28+
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
29+
30+
size(D::SDiagonal) = (length(D.diag),length(D.diag))
31+
32+
function size(D::SDiagonal,d::Integer)
33+
if d<1
34+
throw(ArgumentError("dimension must be ≥ 1, got $d"))
35+
end
36+
return d<=2 ? length(D.diag) : 1
37+
end
38+
39+
function getindex{T}(D::SDiagonal{T}, i::Int, j::Int)
40+
if i == j
41+
D.diag[i]
42+
else
43+
zero(T)
44+
end
45+
end
46+
function setindex!(D::SDiagonal, v, i::Int, j::Int)
47+
if i == j
48+
unsafe_setindex!(D.diag, v, i)
49+
elseif v != 0
50+
throw(ArgumentError("cannot set an off-diagonal index ($i, $j) to a nonzero value ($v)"))
51+
end
52+
D
53+
end
54+
55+
ishermitian{T<:Real}(D::SDiagonal{T}) = true
56+
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
57+
issym(D::SDiagonal) = true
58+
isposdef(D::SDiagonal) = all(D.diag .> 0)
59+
60+
factorize(D::SDiagonal) = D
61+
62+
abs(D::SDiagonal) = SDiagonal(abs(D.diag))
63+
real(D::SDiagonal) = SDiagonal(real(D.diag))
64+
imag(D::SDiagonal) = SDiagonal(imag(D.diag))
65+
66+
==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
67+
-(A::SDiagonal) = SDiagonal(-A.diag)
68+
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
69+
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
70+
-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B
71+
72+
73+
*{T<:Number}(x::T, D::SDiagonal) = SDiagonal(x * D.diag)
74+
*{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag * x)
75+
/{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag / x)
76+
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
77+
*(D::SDiagonal, V::SVector) = D.diag .* V
78+
*(V::SVector, D::SDiagonal) = D.diag .* V
79+
*(A::SMatrix, D::SDiagonal) = scalem(A,D.diag)
80+
*(D::SDiagonal, A::SMatrix) = scalem(D.diag,A)
81+
82+
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )
83+
84+
conj(D::SDiagonal) = SDiagonal(conj(D.diag))
85+
transpose(D::SDiagonal) = D
86+
ctranspose(D::SDiagonal) = conj(D)
87+
88+
diag(D::SDiagonal) = D.diag
89+
trace(D::SDiagonal) = sum(D.diag)
90+
det(D::SDiagonal) = prod(D.diag)
91+
logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag))
92+
function logdet{N,T<:Complex}(D::SDiagonal{N,T}) #Make sure branch cut is correct
93+
x = sum(log.(D.diag))
94+
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
95+
end
96+
97+
98+
eye{N,T}(::Type{SDiagonal{N,T}}) = SDiagonal(one(SVector{n,Int}))
99+
100+
expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
101+
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
102+
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))
103+
104+
\(D::SDiagonal, B::SMatrix) = scalem(1 ./ D.diag, B)
105+
/(B::SMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
106+
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
107+
108+
function inv{N,T}(D::SDiagonal{N,T})
109+
for i = 1:length(D.diag)
110+
if D.diag[i] == zero(T)
111+
throw(SingularException(i))
112+
end
113+
end
114+
SDiagonal(one(T)./D.diag)
115+
end
116+

src/StaticArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export Scalar, SArray, SVector, SMatrix
1919
export MArray, MVector, MMatrix
2020
export FieldVector
2121
export SizedArray, SizedVector, SizedMatrix
22+
export SDiagonal
2223

2324
export Size, Length
2425

@@ -79,6 +80,7 @@ include("MArray.jl")
7980
include("MVector.jl")
8081
include("MMatrix.jl")
8182
include("SizedArray.jl")
83+
include("SDiagonal.jl")
8284

8385
include("abstractarray.jl")
8486
include("indexing.jl")

0 commit comments

Comments
 (0)