Skip to content

Commit b96dd48

Browse files
committed
Make return type of map inferrable with heterogeneous arrays
Inference is not able to detect the element type automatically, but we can do it manually since we know promote_typejoin is used for widening. This is similar to the approach used for `broadcast` at #30485.
1 parent 6e91085 commit b96dd48

File tree

5 files changed

+73
-51
lines changed

5 files changed

+73
-51
lines changed

base/array.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,8 +775,13 @@ function collect(itr::Generator)
775775
return _array_for(et, itr.iter, isz)
776776
end
777777
v1, st = y
778-
arr = _array_for(typeof(v1), itr.iter, isz, shape)
779-
return collect_to_with_first!(arr, v1, itr, st)
778+
dest = _array_for(typeof(v1), itr.iter, isz, shape)
779+
# The typeassert gives inference a helping hand on the element type and dimensionality
780+
# (work-around for #28382)
781+
ElType = promote_typejoin_union(et)
782+
ElType′ = ElType <: Type ? Type : ElType
783+
RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any
784+
collect_to_with_first!(dest, v1, itr, st)::RT
780785
end
781786
end
782787

base/broadcast.jl

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
88
module Broadcast
99

1010
using .Base.Cartesian
11-
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
11+
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
1212
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
1313
import .Base: copy, copyto!, axes
1414
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
@@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
713713
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
714714
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}
715715

716-
function promote_typejoin_union(::Type{T}) where T
717-
if T === Union{}
718-
return Union{}
719-
elseif T isa UnionAll
720-
return Any # TODO: compute more precise bounds
721-
elseif T isa Union
722-
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
723-
elseif T <: Tuple
724-
return typejoin_union_tuple(T)
725-
else
726-
return T
727-
end
728-
end
729-
730-
@pure function typejoin_union_tuple(T::Type)
731-
u = Base.unwrap_unionall(T)
732-
u isa Union && return typejoin(
733-
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
734-
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
735-
p = (u::DataType).parameters
736-
lr = length(p)::Int
737-
if lr == 0
738-
return Tuple{}
739-
end
740-
c = Vector{Any}(undef, lr)
741-
for i = 1:lr
742-
pi = p[i]
743-
U = Core.Compiler.unwrapva(pi)
744-
if U === Union{}
745-
ci = Union{}
746-
elseif U isa Union
747-
ci = typejoin(U.a, U.b)
748-
else
749-
ci = U
750-
end
751-
if i == lr && Core.Compiler.isvarargtype(pi)
752-
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
753-
else
754-
c[i] = ci
755-
end
756-
end
757-
return Base.rewrap_unionall(Tuple{c...}, T)
758-
end
759-
760716
# Inferred eltype of result of broadcast(f, args...)
761717
combine_eltypes(f, args::Tuple) =
762718
promote_typejoin_union(Base._return_type(f, eltypes(args)))

base/promotion.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
161161
end
162162
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})
163163

164+
function promote_typejoin_union(::Type{T}) where T
165+
if T === Union{}
166+
return Union{}
167+
elseif T isa UnionAll
168+
return Any # TODO: compute more precise bounds
169+
elseif T isa Union
170+
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
171+
elseif T <: Tuple
172+
return typejoin_union_tuple(T)
173+
else
174+
return T
175+
end
176+
end
177+
178+
function typejoin_union_tuple(T::Type)
179+
@_pure_meta
180+
u = Base.unwrap_unionall(T)
181+
u isa Union && return typejoin(
182+
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
183+
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
184+
p = (u::DataType).parameters
185+
lr = length(p)::Int
186+
if lr == 0
187+
return Tuple{}
188+
end
189+
c = Vector{Any}(undef, lr)
190+
for i = 1:lr
191+
pi = p[i]
192+
U = Core.Compiler.unwrapva(pi)
193+
if U === Union{}
194+
ci = Union{}
195+
elseif U isa Union
196+
ci = typejoin(U.a, U.b)
197+
else
198+
ci = U
199+
end
200+
if i == lr && Core.Compiler.isvarargtype(pi)
201+
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
202+
else
203+
c[i] = ci
204+
end
205+
end
206+
return Base.rewrap_unionall(Tuple{c...}, T)
207+
end
164208

165209
# Returns length, isfixed
166210
function full_va_len(p)

test/broadcast.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -991,10 +991,6 @@ end
991991
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
992992
Vector{Union{Float64, Missing}}}) ==
993993
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
994-
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
995-
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
996-
Vector{Union{Float64, Missing}}}) ==
997-
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
998994
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
999995
Vector{Union{Float64, Missing}}}) ==
1000996
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}

test/generic_map_tests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,27 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
5353
@test A == map(x->x*x*x, Float64[1:10...])
5454
@test A === B
5555
end
56+
57+
# Issue #28382: inferrability of map with Union eltype
58+
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
59+
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
60+
Vector{Union{Float64, Missing}}}) ==
61+
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
62+
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
63+
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
64+
Vector{Union{Float64, Missing}}}) ==
65+
Vector{<:Tuple{Int, Any}}
66+
# Check that corner cases do not throw an error
67+
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
68+
[nothing, 2, missing])
69+
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
70+
[nothing, 2, 3, missing])
71+
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
72+
[(1.0, "a"), (2, "b"), (3, "c")]
73+
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
74+
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
75+
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
76+
[[1, 2],[missing, 1]])
5677
end
5778

5879
function testmap_equivalence(mapf, f, c...)

0 commit comments

Comments
 (0)