Skip to content

Commit f1c9ce7

Browse files
committed
fix #37610, allow constant prop on signatures with unions
1 parent 2e01da6 commit f1c9ce7

File tree

5 files changed

+47
-8
lines changed

5 files changed

+47
-8
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
2929
fullmatch = Bool[]
3030
if splitunions
3131
splitsigs = switchtupleunion(atype)
32+
split_argtypes = switchtupleunion(argtypes)
3233
applicable = Any[]
34+
# arrays like `argtypes`, including constants, for each match
35+
applicable_argtypes = Vector{Any}[]
3336
infos = MethodMatchInfo[]
34-
for sig_n in splitsigs
37+
for j in 1:length(splitsigs)
38+
sig_n = splitsigs[j]
3539
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
3640
if mt === nothing
3741
add_remark!(interp, sv, "Could not identify method table for call")
@@ -45,6 +49,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
4549
end
4650
push!(infos, MethodMatchInfo(matches))
4751
append!(applicable, matches)
52+
for _ in 1:length(matches)
53+
push!(applicable_argtypes, split_argtypes[j])
54+
end
4855
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
4956
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
5057
found = false
@@ -80,6 +87,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8087
info = MethodMatchInfo(matches)
8188
applicable = matches.matches
8289
valid_worlds = matches.valid_worlds
90+
applicable_argtypes = nothing
8391
end
8492
update_valid_age!(sv, valid_worlds)
8593
applicable = applicable::Array{Any,1}
@@ -136,6 +144,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
136144
if this_rt !== Bottom
137145
if nonbot === 0
138146
nonbot = i
147+
elseif nonbot === -1
148+
elseif method === (applicable[nonbot]::MethodMatch).method
149+
# another entry from the same method, due to union splitting
139150
else
140151
nonbot = -1
141152
end
@@ -147,12 +158,23 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
147158
# try constant propagation if only 1 method is inferred to non-Bottom
148159
# this is in preparation for inlining, or improving the return result
149160
is_unused = call_result_unused(sv)
150-
if nonbot > 0 && seen == napplicable && (!edgecycle || !is_unused) && isa(rettype, Type) && InferenceParams(interp).ipo_constant_propagation
161+
if nonbot > 0 && seen == napplicable && (!edgecycle || !is_unused) &&
162+
(isa(rettype, Type) || isa(rettype, PartialStruct)) && InferenceParams(interp).ipo_constant_propagation
151163
# if there's a possibility we could constant-propagate a better result
152164
# (hopefully without doing too much work), try to do that now
153165
# TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge
154-
const_rettype = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle)
155-
if const_rettype rettype
166+
const_rettype = Bottom
167+
for i in 1:napplicable
168+
mm = applicable[i]::MethodMatch
169+
if i === nonbot || mm.method === (applicable[nonbot]::MethodMatch).method
170+
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
171+
one_rettype = abstract_call_method_with_const_args(interp, rettype, f, this_argtypes, mm, sv, edgecycle)
172+
const_rettype = tmerge(const_rettype, one_rettype)
173+
const_rettype rettype || (const_rettype = Any)
174+
const_rettype === Any && break
175+
end
176+
end
177+
if const_rettype !== Any
156178
# use the better result, if it's a refinement of rettype
157179
rettype = const_rettype
158180
end

base/compiler/typeutils.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,16 @@ function switchtupleunion(@nospecialize(ty))
133133
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
134134
end
135135

136+
switchtupleunion(t::Vector{Any}) = _switchtupleunion(t, length(t), [], nothing)
137+
136138
function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
137139
if i == 0
138-
tpl = rewrap_unionall(Tuple{t...}, origt)
139-
push!(tunion, tpl)
140+
if origt === nothing
141+
push!(tunion, t)
142+
else
143+
tpl = rewrap_unionall(Tuple{t...}, origt)
144+
push!(tunion, tpl)
145+
end
140146
else
141147
ti = t[i]
142148
if isa(ti, Union)

base/namedtuple.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ firstindex(t::NamedTuple) = 1
114114
lastindex(t::NamedTuple) = nfields(t)
115115
getindex(t::NamedTuple, i::Int) = getfield(t, i)
116116
getindex(t::NamedTuple, i::Symbol) = getfield(t, i)
117-
indexed_iterate(t::NamedTuple, i::Int, state=1) = (getfield(t, i), i+1)
117+
indexed_iterate(t::NamedTuple, i::Int, state=1) = (@_inline_meta; (getfield(t, i), i+1))
118118
isempty(::NamedTuple{()}) = true
119119
isempty(::NamedTuple) = false
120120
empty(::NamedTuple) = NamedTuple()

base/pair.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Pair, =>
4747

4848
eltype(p::Type{Pair{A, B}}) where {A, B} = Union{A, B}
4949
iterate(p::Pair, i=1) = i > 2 ? nothing : (getfield(p, i), i + 1)
50-
indexed_iterate(p::Pair, i::Int, state=1) = (getfield(p, i), i + 1)
50+
indexed_iterate(p::Pair, i::Int, state=1) = (@_inline_meta; (getfield(p, i), i + 1))
5151

5252
hash(p::Pair, h::UInt) = hash(p.second, hash(p.first, h))
5353

test/compiler/inference.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,3 +2813,14 @@ f_apply_cglobal(args...) = cglobal(args...)
28132813
@test Core.Compiler.return_type(f_apply_cglobal, Tuple{Any, Vararg{Type{Int}}}) == Ptr
28142814
@test Core.Compiler.return_type(f_apply_cglobal, Tuple{Any, Type{Int}, Vararg{Type{Int}}}) == Ptr{Int}
28152815
@test Core.Compiler.return_type(f_apply_cglobal, Tuple{Any, Type{Int}, Type{Int}, Vararg{Type{Int}}}) == Union{}
2816+
2817+
# issue #37610
2818+
function f37610(a, i)
2819+
y = iterate(a, i)
2820+
if y !== nothing
2821+
(k, v), st = y
2822+
return k, v
2823+
end
2824+
return y
2825+
end
2826+
@test Base.return_types(f37610, (typeof(("foo" => "bar", "baz" => nothing)), Int)) == Any[Union{Nothing, Tuple{String, Union{Nothing, String}}}]

0 commit comments

Comments
 (0)