Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ as well as identity matrices. This package exports the following types:


The primary purpose of this package is to present a unified way of constructing
matrices. For example, to construct a 5-by-5 `CLArray` of all zeros, one would use
```julia
julia> CLArray(Zeros(5,5))
```
Because `Zeros` is lazy, this can be accomplished on the GPU with no memory transfer.
Similarly, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
matrices.
For example, to construct a 5-by-5 `BandedMatrix` of all zeros with bandwidths `(1,2)`, one would use
```julia
julia> BandedMatrix(Zeros(5,5), (1, 2))
```
Expand Down
5 changes: 4 additions & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
import Statistics: mean, std, var, cov, cor


export Zeros, Ones, Fill, Eye, Trues, Falses
export Zeros, Ones, Fill, Eye, Trues, Falses, OneElement

import Base: oneto

Expand Down Expand Up @@ -262,6 +262,7 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
@inline $Typ{T,N}(A::AbstractArray{V,N}) where{T,V,N} = $Typ{T,N}(size(A))
@inline $Typ{T}(A::AbstractArray) where{T} = $Typ{T}(size(A))
@inline $Typ(A::AbstractArray) = $Typ{eltype(A)}(A)
@inline $Typ(::Type{T}, m...) where T = $Typ{T}(m...)

@inline axes(Z::$Typ) = Z.axes
@inline size(Z::$Typ) = length.(Z.axes)
Expand Down Expand Up @@ -717,4 +718,6 @@ Base.@propagate_inbounds function view(A::AbstractFill{<:Any,N}, I::Vararg{Real,
fillsimilar(A)
end

include("oneelement.jl")

end # module
1 change: 0 additions & 1 deletion src/fillalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ end
*(a::ZerosMatrix, b::AbstractMatrix) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosVector) = mult_zeros(a, b)
*(a::AbstractMatrix, b::ZerosMatrix) = mult_zeros(a, b)
*(a::ZerosVector, b::AbstractVector) = mult_zeros(a, b)
*(a::ZerosMatrix, b::AbstractVector) = mult_zeros(a, b)
*(a::AbstractVector, b::ZerosMatrix) = mult_zeros(a, b)

Expand Down
27 changes: 27 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
OneElement(val, ind, axes) <: AbstractArray
Extremely simple `struct` used for the gradient of scalar `getindex`.
"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end

OneElement(val, inds::NTuple{N,Int}, sz::NTuple{N,Integer}) where N = OneElement(val, inds, oneto.(sz))
OneElement(val, inds::Int, sz::Int) = OneElement(val, (inds,), (sz,))
OneElement(inds::Int, sz::Int) = OneElement(1, inds, sz)
OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)

Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))

Base.replace_in_print_matrix(o::OneElement{<:Any,2}, k::Integer, j::Integer, s::AbstractString) =
o.ind == (k,j) ? s : Base.replace_with_centered_mark(s)

function Base.setindex(A::Zeros{T,N}, v, kj::Vararg{Int,N}) where {T,N}
@boundscheck checkbounds(A, kj...)
OneElement(convert(T, v), kj, axes(A))
end
20 changes: 17 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ include("infinitearrays.jl")

for T in (Int, Float64)
Z = $Typ{T}(5)
@test $Typ(T, 5) ≡ Z
@test eltype(Z) == T
@test Array(Z) == $funcs(T,5)
@test Array{T}(Z) == $funcs(T,5)
Expand All @@ -34,6 +35,7 @@ include("infinitearrays.jl")
@test $Typ(2ones(T,5)) == Z

Z = $Typ{T}(5, 5)
@test $Typ(T, 5, 5) ≡ Z
@test eltype(Z) == T
@test Array(Z) == $funcs(T,5,5)
@test Array{T}(Z) == $funcs(T,5,5)
Expand Down Expand Up @@ -508,9 +510,9 @@ end
@test_throws MethodError [1,2,3]*Zeros(3) # Not defined for [1,2,3]*[0,0,0] either

@testset "Check multiplication by Adjoint vectors works as expected." begin
@test randn(4, 3)' * Zeros(4) === Zeros(3)
@test randn(4)' * Zeros(4) === zero(Float64)
@test [1, 2, 3]' * Zeros{Int}(3) === zero(Int)
@test randn(4, 3)' * Zeros(4) Zeros(3)
@test randn(4)' * Zeros(4) ≡ transpose(randn(4)) * Zeros(4) ≡ zero(Float64)
@test [1, 2, 3]' * Zeros{Int}(3) zero(Int)
@test [SVector(1,2)', SVector(2,3)', SVector(3,4)']' * Zeros{Int}(3) === SVector(0,0)
@test_throws DimensionMismatch randn(4)' * Zeros(3)
@test Zeros(5)' * randn(5,3) ≡ Zeros(5)'*Zeros(5,3) ≡ Zeros(5)'*Ones(5,3) ≡ Zeros(3)'
Expand Down Expand Up @@ -1486,4 +1488,16 @@ end
@test Zeros(5,5) .+ D isa Diagonal
f = (x,y) -> x+1
@test f.(D, Zeros(5,5)) isa Matrix
end

@testset "OneElement" begin
e₁ = OneElement(2, 5)
@test e₁ == [0,1,0,0,0]

e₁ = OneElement{Float64}(2, 5)
@test e₁ == [0,1,0,0,0]

@test Base.setindex(Zeros(5), 2, 2) ≡ OneElement(2.0, 2, 5)
@test Base.setindex(Zeros(5,3), 2, 2, 3) ≡ OneElement(2.0, (2,3), (5,3))
@test_throws BoundsError Base.setindex(Zeros(5), 2, 6)
end