diff --git a/src/FillArrays.jl b/src/FillArrays.jl index d2d678ad..f121831f 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -6,7 +6,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert, +, -, *, /, \, diff, sum, cumsum, maximum, minimum, sort, sort!, any, all, axes, isone, iterate, unique, allunique, permutedims, inv, copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero, - show, view, in + show, view, in, mapreduce import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!, dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular, AdjointAbsVec diff --git a/src/fillbroadcast.jl b/src/fillbroadcast.jl index 142422a9..44586f03 100644 --- a/src/fillbroadcast.jl +++ b/src/fillbroadcast.jl @@ -2,6 +2,84 @@ map(f::Function, r::AbstractFill) = Fill(f(getindex_value(r)), axes(r)) +function map(f::Function, vs::AbstractFill{<:Any,1}...) + stop = mapreduce(length, min, vs) + val = f(map(getindex_value, vs)...) + Fill(val, stop) +end + +function map(f::Function, rs::AbstractFill...) + if _maplinear(rs...) + map(f, map(vec, rs)...) + else + val = f(map(getindex_value, rs)...) + Fill(val, axes(first(rs))) + end +end + +function _maplinear(rs...) # tries to match Base's behaviour, could perhaps hook in more deeply + if any(ndims(r)==1 for r in rs) + return true + else + r1 = axes(first(rs)) + for r in rs + axes(r) == r1 || throw(DimensionMismatch( + "dimensions must match: a has dims $r1, b has dims $(axes(r))")) + end + return false + end +end + +### mapreduce + +if VERSION >= v"1.4" + # _InitialValue was introduced after 1.0, before 1.4, not sure exact version. + # Without these methods, some reductions will give an Array not a Fill. + + function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, ::Colon) + fval = f(getindex_value(A)) + out = fval + for _ in 2:length(A) + out = op(out, fval) + end + out + end + + function Base._mapreduce_dim(f, op, ::Base._InitialValue, A::AbstractFill, dims) + fval = f(getindex_value(A)) + red = *(ntuple(d -> d in dims ? size(A,d) : 1, ndims(A))...) + out = fval + for _ in 2:red + out = op(out, fval) + end + Fill(out, ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))) + end + +end +if VERSION >= v"1.2" # Vararg mapreduce was added in Julia 1.2 + + function mapreduce(f, op, A::AbstractFill, B::AbstractFill; kw...) + val(_...) = f(getindex_value(A), getindex_value(B)) + reduce(op, map(val, A, B); kw...) + end + + # These are particularly useful because mapreduce(*, +, A, B; dims) is slow in Base, + # but can be re-written as some mapreduce(g, +, C; dims) which is fast. + + function mapreduce(f, op, A::AbstractFill, B::AbstractArray, Cs::AbstractArray...; kw...) + g(b, cs...) = f(getindex_value(A), b, cs...) + mapreduce(g, op, B, Cs...; kw...) + end + function mapreduce(f, op, A::AbstractArray, B::AbstractFill, Cs::AbstractArray...; kw...) + h(a, cs...) = f(a, getindex_value(B), cs...) + mapreduce(h, op, A, Cs...; kw...) + end + function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray...; kw...) + gh(cs...) = f(getindex_value(A), getindex_value(B), cs...) + mapreduce(gh, op, Cs...; kw...) + end + +end ### Unary broadcasting @@ -165,4 +243,4 @@ end broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = Fill(op(getindex_value(r),x), axes(r)) broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = Fill(op(x, getindex_value(r)), axes(r)) broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = Fill(op(getindex_value(r),x[]), axes(r)) -broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r)) \ No newline at end of file +broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = Fill(op(x[], getindex_value(r)), axes(r)) diff --git a/test/runtests.jl b/test/runtests.jl index e3c69d24..3634cfd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -536,6 +536,7 @@ end @testset "Cumsum and diff" begin @test sum(Fill(3,10)) ≡ 30 + @test reduce(+, Fill(3,10)) ≡ 30 @test sum(x -> x + 1, Fill(3,10)) ≡ 40 @test cumsum(Fill(3,10)) ≡ 3:3:30 @@ -758,15 +759,63 @@ end end @testset "map" begin - x = Ones(5) - @test map(exp,x) === Fill(exp(1.0),5) - @test map(isone,x) === Fill(true,5) + x1 = Ones(5) + @test map(exp,x1) === Fill(exp(1.0),5) + @test map(isone,x1) === Fill(true,5) - x = Zeros(5) - @test map(exp,x) === exp.(x) + x0 = Zeros(5) + @test map(exp,x0) === exp.(x0) - x = Fill(2,5,3) - @test map(exp,x) === Fill(exp(2),5,3) + x2 = Fill(2,5,3) + @test map(exp,x2) === Fill(exp(2),5,3) + + @test map(+, x1, x2) === Fill(3.0, 5) + @test map(+, x2, x2) === x2 .+ x2 + @test_throws DimensionMismatch map(+, x2', x2) +end + +@testset "mapreduce" begin + x = rand(3, 4) + y = fill(1.0, 3, 4) + Y = Fill(1.0, 3, 4) + O = Ones(3, 4) + + @test mapreduce(exp, +, Y) == mapreduce(exp, +, y) + @test mapreduce(exp, +, Y; dims=2) == mapreduce(exp, +, y; dims=2) + @test mapreduce(identity, +, Y) == sum(y) == sum(Y) + @test mapreduce(identity, +, Y, dims=1) == sum(y, dims=1) == sum(Y, dims=1) + + if VERSION >= v"1.4" + @test mapreduce(exp, +, Y; dims=(1,), init=5.0) == mapreduce(exp, +, y; dims=(1,), init=5.0) + end + + if VERSION >= v"1.2" # Vararg mapreduce was added in Julia 1.2 + + # Two arrays + @test mapreduce(*, +, x, Y) == mapreduce(*, +, x, y) + @test mapreduce(*, +, Y, x) == mapreduce(*, +, y, x) + @test mapreduce(*, +, x, O) == mapreduce(*, +, x, y) + @test mapreduce(*, +, Y, O) == mapreduce(*, +, y, y) + + f2(x,y) = 1 + x/y + op2(x,y) = x^2 + 3y + @test mapreduce(f2, op2, x, Y) == mapreduce(f2, op2, x, y) + + if VERSION >= v"1.4" + @test mapreduce(f2, op2, x, Y, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0) + @test mapreduce(f2, op2, Y, x, dims=1, init=5.0) == mapreduce(f2, op2, y, x, dims=1, init=5.0) + @test mapreduce(f2, op2, x, O, dims=1, init=5.0) == mapreduce(f2, op2, x, y, dims=1, init=5.0) + @test mapreduce(f2, op2, Y, O, dims=1, init=5.0) == mapreduce(f2, op2, y, y, dims=1, init=5.0) + end + + # More than two + @test mapreduce(+, +, x, Y, x) == mapreduce(+, +, x, y, x) + @test mapreduce(+, +, Y, x, x) == mapreduce(+, +, y, x, x) + @test mapreduce(+, +, x, O, Y) == mapreduce(+, +, x, y, y) + @test mapreduce(+, +, Y, O, Y) == mapreduce(+, +, y, y, y) + @test mapreduce(+, +, Y, O, Y, x) == mapreduce(+, +, y, y, y, x) + + end end @testset "Offset indexing" begin