Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions src/SDiagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Originally contributed by D. Getz (https://github.com/getzdan), M. Schauer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, if you're porting someone else's code, and you want to attribute them, it's also quite possible to make them the author in git using something like git commit --author="Some Body <[email protected]>"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I do this kind of thing, I'd be inclined to add the original chunk of code (in probably non-working form) under the other author's name as a single commit (@mentioning them to make sure they're happy with that). Then add any necessary changes under your own name in further commits.

The nice thing about doing it this way is you avoid baking authorship into comments which people feel they can't remove (like the one above), but you also get to do proper attribution which I feel is quite important.

Just some thoughts, I'm happy this was merged already.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you port entire files (like in this case) it is also possible to preserve the full git history http://gbayer.com/development/moving-files-from-one-git-repository-to-another-preserving-history/

# at https://github.com/mschauer/Bridge.jl under MIT License

import Base: getindex,setindex!,==,-,+,*,/,\,transpose,ctranspose,convert, size, abs, real, imag, conj, eye, inv
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge these imports with the global ones?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

import Base.LinAlg: ishermitian, issymmetric, isposdef, factorize, diag, trace, det, logdet, expm, logm, sqrtm

@generated function scalem{T, M, N}(a::SMatrix{M,N, T}, b::SVector{N, T})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use StaticMatrix and StaticVector not SMatrix and SVector in the signature, please. Same with the method below.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we don't need to assume the eltype is the same.

expr = vec([:(a[$j,$i]*b[$i]) for j=1:M, i=1:N])
:(SMatrix{M,N,T}($(expr...)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use the similar_type interface here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

end
@generated function scalem{T, M, N}(a::SVector{M,T}, b::SMatrix{M, N, T})
expr = vec([:(b[$j,$i]*a[$j]) for j=1:M, i=1:N])
:(SMatrix{M,N,T}($(expr...)))
end

struct SDiagonal{N,T} <: StaticMatrix{N, N, T}
diag::SVector{N,T}
SDiagonal{N,T}(diag::SVector{N,T}) where {N,T} = new(diag)
end
diagtype{N,T}(::Type{SDiagonal{N,T}}) = SVector{N,T}
diagtype{N}(::Type{SDiagonal{N}}) = SVector{N}
diagtype(::Type{SDiagonal}) = SVector

# this is to deal with convert.jl
@inline (::Type{SD})(a::AbstractVector) where {SD <: SDiagonal} = SDiagonal(diagtype(SD)(a))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this should be SDiagonal(convert(diagtype(SD), a))

@inline (::Type{SD})(a::Tuple) where {SD <: SDiagonal} = SDiagonal(diagtype(SD)(a))
@inline (::Type{SDiagonal}){N,T}(a::SVector{N,T}) = SDiagonal{N,T}(a)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a::StaticVector{N,T}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this work? this calls the inner constructor


@generated function SDiagonal{N,T}(a::SMatrix{N,N,T})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a convert method to me

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In diagonal.jl this is a constructor.

expr = [:(a[$i,$i]) for i=1:N]
:(SDiagonal{N,T}($(expr...)))
end


convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal{N,T}) = D
convert{N,T}(::Type{SDiagonal{N,T}}, D::SDiagonal) = SDiagonal{N,T}(convert(SVector{N,T}, D.diag))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

D::SDiagonal{N} maybe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


size{N}(D::SDiagonal{N}) = (N,N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? There should be a fallback already.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


function size{N}(D::SDiagonal{N},d::Int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a global StaticArray fallback for this also.

if d<1
throw(ArgumentError("dimension must be ≥ 1, got $d"))
end
return d<=2 ? N : 1
end

Base.@propagate_inbounds function getindex{N,T}(D::SDiagonal{N,T}, i::Int, j::Int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@propagate_inbounds is needed if you want to defer bounds checking to an inner function. This should be simply @inline since you are doing your own @boundscheck and @inbounds here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

@boundscheck checkbounds(D, i, j)
if i == j
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a perfect place to use ifelse instead of branching (will produce faster code)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also looks nicer

@inbounds return D.diag[i]
else
zero(T)
end
end

# avoid linear indexing?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be nice. Unfortunately most the static arrays internal code assumes a linear indexing style - we should change this, but it won't be easy...

Base.@propagate_inbounds function getindex{N,T}(D::SDiagonal{N,T}, k::Int)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good use of @propagate_inbounds.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, it has already been imported so you don't need the Base.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

i, j = ind2sub(size(D), k)
D[i,j]
end

ishermitian{T<:Real}(D::SDiagonal{T}) = true
ishermitian(D::SDiagonal) = all(D.diag .== real(D.diag))
issymmetric(D::SDiagonal) = true
isposdef(D::SDiagonal) = all(D.diag .> 0)

factorize(D::SDiagonal) = D

==(Da::SDiagonal, Db::SDiagonal) = Da.diag == Db.diag
-(A::SDiagonal) = SDiagonal(-A.diag)
+(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag + Db.diag)
-(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag - Db.diag)
-(A::SDiagonal, B::SMatrix) = eye(typeof(B))*A - B

*{T<:Number}(x::T, D::SDiagonal) = SDiagonal(x * D.diag)
*{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag * x)
/{T<:Number}(D::SDiagonal, x::T) = SDiagonal(D.diag / x)
*(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag .* Db.diag)
*(D::SDiagonal, V::SVector) = D.diag .* V
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

V::AbstractVector

*(V::SVector, D::SDiagonal) = D.diag .* V
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should V be a RowVector?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or rather: this is already a fallback https://github.com/JuliaLang/julia/blob/master/base/linalg/rowvector.jl#L181 which looks like it is no overhead

*(A::SMatrix, D::SDiagonal) = scalem(A,D.diag)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A::AbstractMatrix (and below)

*(D::SDiagonal, A::SMatrix) = scalem(D.diag,A)
\(D::SDiagonal, b::SVector) = D.diag .\ b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this work for matrices as well?

also /(vector, diagonal)?


conj(D::SDiagonal) = SDiagonal(conj(D.diag))
transpose(D::SDiagonal) = D
ctranspose(D::SDiagonal) = conj(D)

diag(D::SDiagonal) = D.diag
trace(D::SDiagonal) = sum(D.diag)
det(D::SDiagonal) = prod(D.diag)
logdet{N,T<:Real}(D::SDiagonal{N,T}) = sum(log.(D.diag))
function logdet{N,T<:Complex}(D::SDiagonal{N,T}) #Make sure branch cut is correct
x = sum(log.(D.diag))
-pi<imag(x)<pi ? x : real(x)+(mod2pi(imag(x)+pi)-pi)*im
end

eye{N,T}(::Type{SDiagonal{N,T}}) = SDiagonal(ones(SVector{N,T}))

expm(D::SDiagonal) = SDiagonal(exp.(D.diag))
logm(D::SDiagonal) = SDiagonal(log.(D.diag))
sqrtm(D::SDiagonal) = SDiagonal(sqrt.(D.diag))

\(D::SDiagonal, B::SMatrix) = scalem(1 ./ D.diag, B)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

B::AbstractMatrix (and below)

/(B::SMatrix, D::SDiagonal) = scalem(1 ./ D.diag, B)
\(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Db.diag ./ Da.diag)
/(Da::SDiagonal, Db::SDiagonal) = SDiagonal(Da.diag ./ Db.diag )

function inv{N,T}(D::SDiagonal{N,T})
for i = 1:N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly @inbounds ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we generally unroll such loops elsewhere...

It's also slightly annoying that the code to check if we should throw a singular exception will be slower than taking the inverse itself... but I guess Base.Diagonal does this too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I guess the divisions will still be the most expensive operation here.

if D.diag[i] == zero(T)
throw(SingularException(i))
end
end
SDiagonal(inv.(D.diag))
end

2 changes: 2 additions & 0 deletions src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export Scalar, SArray, SVector, SMatrix
export MArray, MVector, MMatrix
export FieldVector
export SizedArray, SizedVector, SizedMatrix
export SDiagonal

export Size, Length

Expand Down Expand Up @@ -79,6 +80,7 @@ include("MArray.jl")
include("MVector.jl")
include("MMatrix.jl")
include("SizedArray.jl")
include("SDiagonal.jl")

include("abstractarray.jl")
include("indexing.jl")
Expand Down
72 changes: 72 additions & 0 deletions test/SDiagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
@testset "SDiagonal" begin
@testset "Constructors" begin
@test SDiagonal{1,Int64}((1,)).diag === SVector{1,Int64}((1,))
@test SDiagonal{1,Float64}((1,)).diag === SVector{1,Float64}((1,))

@test SDiagonal{4,Float64}((1, 1.0, 1, 1)).diag.data === (1.0, 1.0, 1.0, 1.0)

# Bad input
@test_throws Exception SMatrix{1,Int}()
@test_throws Exception SMatrix{2,Int}((1,))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@test_throws Exception is generally a very bad idea. It will pass if the code has a typo.


# From SMatrix
@test SDiagonal(SMatrix{2,2,Int}((1,2,3,4))).diag.data === (1,4)

end

@testset "Methods" begin

m = SDiagonal(@SVector [11, 12, 13, 14])
m2 = diagm([11, 12, 13, 14])

b = @SVector [2,-1,2,1]
b2 = Vector(b)

@test isimmutable(m) == true

@test m[1,1] === 11
@test m[2,2] === 12
@test m[3,3] === 13
@test m[4,4] === 14

for i in 1:4
for j in 1:4
i == j || @test m[i,j] === 0
end
end

@test_throws Exception m[5,5]

@test_throws Exception m[1,5]


@test size(m) === (4, 4)
@test size(typeof(m)) === (4, 4)
@test size(SDiagonal{4}) === (4, 4)

@test size(m, 1) === 4
@test size(m, 2) === 4
@test size(typeof(m), 1) === 4
@test size(typeof(m), 2) === 4

@test length(m) === 4*4

@test_throws Exception m[1] = 1

@test m*b == @SVector [22,-12,26,14]
@test m\b == m2\b
@test m*m == m2*m

@test ishermitian(m) == ishermitian(m2)
@test isposdef(m) == isposdef(m2)
@test issymmetric(m) == issymmetric(m2)

@test m' == m
@test 2m == m + m
@test 0m == m - m

@test m\m == eye(SDiagonal{4,Float64})


end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ include("FieldVector.jl")
include("Scalar.jl")
include("SUnitRange.jl")
include("SizedArray.jl")
include("SDiagonal.jl")

include("custom_types.jl")

include("core.jl")
Expand Down