From 0a3220b4c35417ae9bc5547afbf3d7cd6e4535d8 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 10 Dec 2020 14:27:13 +0000 Subject: [PATCH 1/3] view returns a Fill (#84) --- Project.toml | 2 +- src/FillArrays.jl | 35 ++++++++++++++++++++++++----------- test/runtests.jl | 7 +++++++ 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index cbb0c6dc..232dfbe4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FillArrays" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.10.1" +version = "0.11" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/FillArrays.jl b/src/FillArrays.jl index e12314e3..893e5ba7 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show + show, view import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec @@ -152,7 +152,9 @@ convert(::Type{Fill}, arr::AbstractArray{T}) where T = Fill{T}(unique_value(arr) convert(::Type{Fill{T}}, arr::AbstractArray) where T = Fill{T}(unique_value(arr), axes(arr)) convert(::Type{Fill{T,N}}, arr::AbstractArray{<:Any,N}) where {T,N} = Fill{T,N}(unique_value(arr), axes(arr)) convert(::Type{Fill{T,N,Axes}}, arr::AbstractArray{<:Any,N}) where {T,N,Axes} = Fill{T,N,Axes}(unique_value(arr), axes(arr)) -convert(::Type{T}, F::T) where T<:Fill = F # ambiguity fix +# ambiguity fix +convert(::Type{Fill}, arr::Fill{T}) where T = Fill{T}(unique_value(arr), axes(arr)) +convert(::Type{T}, F::T) where T<:Fill = F @@ -211,14 +213,14 @@ reshape(parent::AbstractFill, dims::Integer...) = reshape(parent, dims) reshape(parent::AbstractFill, dims::Union{Int,Colon}...) = reshape(parent, dims) reshape(parent::AbstractFill, dims::Union{Integer,Colon}...) = reshape(parent, dims) -reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Integer,Colon}}}) = +reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Integer,Colon}}}) = fill_reshape(parent, Base._reshape_uncolon(parent, dims)...) -reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Int,Colon}}}) = +reshape(parent::AbstractFill, dims::Tuple{Vararg{Union{Int,Colon}}}) = fill_reshape(parent, Base._reshape_uncolon(parent, dims)...) -reshape(parent::AbstractFill, shp::Tuple{Union{Integer,Base.OneTo}, Vararg{Union{Integer,Base.OneTo}}}) = - reshape(parent, Base.to_shape(shp)) -reshape(parent::AbstractFill, dims::Dims) = Base._reshape(parent, dims) -reshape(parent::AbstractFill, dims::Tuple{Integer, Vararg{Integer}}) = Base._reshape(parent, dims) +reshape(parent::AbstractFill, shp::Tuple{Union{Integer,Base.OneTo}, Vararg{Union{Integer,Base.OneTo}}}) = + reshape(parent, Base.to_shape(shp)) +reshape(parent::AbstractFill, dims::Dims) = Base._reshape(parent, dims) +reshape(parent::AbstractFill, dims::Tuple{Integer, Vararg{Integer}}) = Base._reshape(parent, dims) Base._reshape(parent::AbstractFill, dims::Dims) = fill_reshape(parent, dims...) Base._reshape(parent::AbstractFill, dims::Tuple{Integer,Vararg{Integer}}) = fill_reshape(parent, dims...) # Resolves ambiguity error with `_reshape(v::AbstractArray{T, 1}, dims::Tuple{Int})` @@ -344,7 +346,7 @@ for f in (:triu, :triu!, :tril, :tril!) end -Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::AbstractString) = +Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::AbstractString) = i == j ? s : Base.replace_with_centered_mark(s) @@ -378,7 +380,7 @@ end Eye(n::Integer, m::Integer) = RectDiagonal(Ones(min(n,m)), n, m) Eye{T}(n::Integer, m::Integer) where T = RectDiagonal{T}(Ones{T}(min(n,m)), n, m) -function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T +function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T ab = length(a) ≤ length(b) ? a : b RectDiagonal{T}(Ones{T}((ab,)), (a,b)) end @@ -600,7 +602,7 @@ if VERSION ≥ v"1.5" Base.array_summary(io::IO, a::Fill{T}, inds::Tuple{Vararg{Base.OneTo}}) where T = print(io, Base.dims2string(length.(inds)), " Fill{$T}") Base.array_summary(io::IO, a::Eye{T}, inds::Tuple{Vararg{Base.OneTo}}) where T = - print(io, Base.dims2string(length.(inds)), " Eye{$T}") + print(io, Base.dims2string(length.(inds)), " Eye{$T}") end Base.show(io::IO, ::MIME"text/plain", x::Union{Eye,AbstractFill}) = show(io, x) @@ -612,4 +614,15 @@ Base.show(io::IO, ::MIME"text/plain", x::Union{Eye,AbstractFill}) = show(io, x) getindex_value(a::LinearAlgebra.AdjOrTrans) = getindex_value(parent(a)) getindex_value(a::SubArray) = getindex_value(parent(a)) + +## +# view +## + +Base.@propagate_inbounds view(A::AbstractFill, kr::AbstractVector{Bool}) = getindex(F, kr) +Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Union{Real, AbstractArray}, N}) where N = + getindex(A, I...) +Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Real, N}) where N = + Base.invoke(view, Tuple{AbstractArray,Vararg{Any,N}}, A, I...) + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 74f01d69..4baecd69 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1140,9 +1140,16 @@ end end @testset "FillArray interface" begin + @testset "SubArray" begin + a = Fill(2.0,5) + v = SubArray(a,(1:2,)) + @test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0 + @test convert(Fill, v) ≡ Fill(2.0,2) + end @testset "views" begin a = Fill(2.0,5) v = view(a,1:2) + @test v isa Fill @test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0 @test convert(Fill, v) ≡ Fill(2.0,2) end From 048622aea33a0969594253ccffce587c6dd2a10b Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 10 Dec 2020 21:18:19 +0000 Subject: [PATCH 2/3] add tests --- src/FillArrays.jl | 3 ++- test/runtests.jl | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/FillArrays.jl b/src/FillArrays.jl index 893e5ba7..d3b1a04c 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -619,7 +619,8 @@ getindex_value(a::SubArray) = getindex_value(parent(a)) # view ## -Base.@propagate_inbounds view(A::AbstractFill, kr::AbstractVector{Bool}) = getindex(F, kr) +Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, kr::AbstractArray{Bool,N}) where N = getindex(A, kr) +Base.@propagate_inbounds view(A::AbstractFill{<:Any,1}, kr::AbstractVector{Bool}) = getindex(A, kr) Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Union{Real, AbstractArray}, N}) where N = getindex(A, I...) Base.@propagate_inbounds view(A::AbstractFill{<:Any,N}, I::Vararg{Real, N}) where N = diff --git a/test/runtests.jl b/test/runtests.jl index 4baecd69..e6900b11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1146,12 +1146,21 @@ end @test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0 @test convert(Fill, v) ≡ Fill(2.0,2) end + @testset "views" begin a = Fill(2.0,5) v = view(a,1:2) @test v isa Fill @test FillArrays.getindex_value(v) == FillArrays.unique_value(v) == 2.0 @test convert(Fill, v) ≡ Fill(2.0,2) + @test view(a,1) isa SubArray + end + + @testset "view with bool" begin + a = Fill(2.0,5) + @test a[[true,false,false,true,false]] ≡ view(a,[true,false,false,true,false]) + a = Fill(2.0,2,2) + @test a[[true false; false true]] ≡ view(a, [true false; false true]) end @testset "adjtrans" begin From 64adc1094c897e6eed2d1973d19c9ca59ec54998 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 18 Dec 2020 08:11:25 +0000 Subject: [PATCH 3/3] size -> axes in broadcasted --- src/fillbroadcast.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 908230e4..142422a9 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -162,7 +162,7 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::A return broadcasted(*, a, _broadcast_getindex_value(b)) end -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), size(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), size(r)) -broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), size(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), size(r)) \ No newline at end of file +broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), axes(r)) +broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), axes(r)) +broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), axes(r)) +broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r)) \ No newline at end of file