Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,13 @@ function collect(itr::Generator)
return _array_for(et, itr.iter, isz)
end
v1, st = y
arr = _array_for(typeof(v1), itr.iter, isz, shape)
return collect_to_with_first!(arr, v1, itr, st)
dest = _array_for(typeof(v1), itr.iter, isz, shape)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
ElType = promote_typejoin_union(et)
ElType′ = ElType <: Type ? Type : ElType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this should be part of the et computation at the top of the function (so that the total function here is more type-stable), no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I did the same for broadcast (see here) so that would make sense, even if that changes the return type in corner cases. I've just pushed a commit to do this.

RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end
end

Expand Down
46 changes: 1 addition & 45 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
Expand Down Expand Up @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))
Expand Down
44 changes: 44 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
end
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

function typejoin_union_tuple(T::Type)
@_pure_meta
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Returns length, isfixed
function full_va_len(p)
Expand Down
4 changes: 0 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,6 @@ end
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
Expand Down
21 changes: 21 additions & 0 deletions test/generic_map_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
@test A == map(x->x*x*x, Float64[1:10...])
@test A === B
end

# Issue #28382: inferrability of map with Union eltype
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
[(1.0, "a"), (2, "b"), (3, "c")]
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
[[1, 2],[missing, 1]])
end

function testmap_equivalence(mapf, f, c...)
Expand Down