Skip to content

Commit 799136d

Browse files
authored
compiler: general refactor (#41633)
Separated from compiler-plugin prototyping.
1 parent 877c0a5 commit 799136d

File tree

7 files changed

+147
-118
lines changed

7 files changed

+147
-118
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 110 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -35,73 +35,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
3535
add_remark!(interp, sv, "Skipped call in throw block")
3636
return CallMeta(Any, false)
3737
end
38-
valid_worlds = WorldRange()
39-
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
40-
splitunions = 1 < unionsplitcost(argtypes) <= InferenceParams(interp).MAX_UNION_SPLITTING
41-
mts = Core.MethodTable[]
42-
fullmatch = Bool[]
43-
if splitunions
44-
split_argtypes = switchtupleunion(argtypes)
45-
applicable = Any[]
46-
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
47-
infos = MethodMatchInfo[]
48-
for arg_n in split_argtypes
49-
sig_n = argtypes_to_type(arg_n)
50-
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
51-
if mt === nothing
52-
add_remark!(interp, sv, "Could not identify method table for call")
53-
return CallMeta(Any, false)
54-
end
55-
mt = mt::Core.MethodTable
56-
matches = findall(sig_n, method_table(interp); limit=max_methods)
57-
if matches === missing
58-
add_remark!(interp, sv, "For one of the union split cases, too many methods matched")
59-
return CallMeta(Any, false)
60-
end
61-
push!(infos, MethodMatchInfo(matches))
62-
for m in matches
63-
push!(applicable, m)
64-
push!(applicable_argtypes, arg_n)
65-
end
66-
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
67-
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
68-
found = false
69-
for (i, mt′) in enumerate(mts)
70-
if mt′ === mt
71-
fullmatch[i] &= thisfullmatch
72-
found = true
73-
break
74-
end
75-
end
76-
if !found
77-
push!(mts, mt)
78-
push!(fullmatch, thisfullmatch)
79-
end
80-
end
81-
info = UnionSplitInfo(infos)
82-
else
83-
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
84-
if mt === nothing
85-
add_remark!(interp, sv, "Could not identify method table for call")
86-
return CallMeta(Any, false)
87-
end
88-
mt = mt::Core.MethodTable
89-
matches = findall(atype, method_table(interp, sv); limit=max_methods)
90-
if matches === missing
91-
# this means too many methods matched
92-
# (assume this will always be true, so we don't compute / update valid age in this case)
93-
add_remark!(interp, sv, "Too many methods matched")
94-
return CallMeta(Any, false)
95-
end
96-
push!(mts, mt)
97-
push!(fullmatch, _any(match->(match::MethodMatch).fully_covers, matches))
98-
info = MethodMatchInfo(matches)
99-
applicable = matches.matches
100-
valid_worlds = matches.valid_worlds
101-
applicable_argtypes = nothing
38+
39+
matches = find_matching_methods(argtypes, atype, method_table(interp, sv), InferenceParams(interp).MAX_UNION_SPLITTING, max_methods)
40+
if isa(matches, FailedMethodMatch)
41+
add_remark!(interp, sv, matches.reason)
42+
return CallMeta(Any, false)
10243
end
44+
45+
(; valid_worlds, applicable, info) = matches
10346
update_valid_age!(sv, valid_worlds)
104-
applicable = applicable::Array{Any,1}
10547
napplicable = length(applicable)
10648
rettype = Bottom
10749
edges = MethodInstance[]
@@ -142,7 +84,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
14284
if edge !== nothing
14385
push!(edges, edge)
14486
end
145-
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
87+
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
14688
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
14789
if const_rt !== rt && const_rt rt
14890
rt = const_rt
@@ -164,7 +106,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
164106
end
165107
# try constant propagation with argtypes for this match
166108
# this is in preparation for inlining, or improving the return result
167-
this_argtypes = applicable_argtypes === nothing ? argtypes : applicable_argtypes[i]
109+
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
168110
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
169111
if const_this_rt !== this_rt && const_this_rt this_rt
170112
this_rt = const_this_rt
@@ -275,7 +217,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
275217
# and avoid keeping track of a more complex result type.
276218
rettype = Any
277219
end
278-
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
220+
add_call_backedges!(interp, rettype, edges, matches, atype, sv)
279221
if !isempty(sv.pclimitations) # remove self, if present
280222
delete!(sv.pclimitations, sv)
281223
for caller in sv.callers_in_cycle
@@ -286,24 +228,110 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
286228
return CallMeta(rettype, info)
287229
end
288230

289-
function add_call_backedges!(interp::AbstractInterpreter,
290-
@nospecialize(rettype),
291-
edges::Vector{MethodInstance},
292-
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
293-
sv::InferenceState)
294-
if rettype === Any
295-
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
296-
# (widen) this type
297-
return
231+
struct FailedMethodMatch
232+
reason::String
233+
end
234+
235+
struct MethodMatches
236+
applicable::Vector{Any}
237+
info::MethodMatchInfo
238+
valid_worlds::WorldRange
239+
mt::Core.MethodTable
240+
fullmatch::Bool
241+
end
242+
243+
struct UnionSplitMethodMatches
244+
applicable::Vector{Any}
245+
applicable_argtypes::Vector{Vector{Any}}
246+
info::UnionSplitInfo
247+
valid_worlds::WorldRange
248+
mts::Vector{Core.MethodTable}
249+
fullmatches::Vector{Bool}
250+
end
251+
252+
function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
253+
union_split::Int, max_methods::Int)
254+
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
255+
if 1 < unionsplitcost(argtypes) <= union_split
256+
split_argtypes = switchtupleunion(argtypes)
257+
infos = MethodMatchInfo[]
258+
applicable = Any[]
259+
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
260+
valid_worlds = WorldRange()
261+
mts = Core.MethodTable[]
262+
fullmatches = Bool[]
263+
for i in 1:length(split_argtypes)
264+
arg_n = split_argtypes[i]::Vector{Any}
265+
sig_n = argtypes_to_type(arg_n)
266+
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
267+
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
268+
mt = mt::Core.MethodTable
269+
matches = findall(sig_n, method_table; limit = max_methods)
270+
if matches === missing
271+
return FailedMethodMatch("For one of the union split cases, too many methods matched")
272+
end
273+
push!(infos, MethodMatchInfo(matches))
274+
for m in matches
275+
push!(applicable, m)
276+
push!(applicable_argtypes, arg_n)
277+
end
278+
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
279+
thisfullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
280+
found = false
281+
for (i, mt′) in enumerate(mts)
282+
if mt′ === mt
283+
fullmatches[i] &= thisfullmatch
284+
found = true
285+
break
286+
end
287+
end
288+
if !found
289+
push!(mts, mt)
290+
push!(fullmatches, thisfullmatch)
291+
end
292+
end
293+
return UnionSplitMethodMatches(applicable,
294+
applicable_argtypes,
295+
UnionSplitInfo(infos),
296+
valid_worlds,
297+
mts,
298+
fullmatches)
299+
else
300+
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
301+
if mt === nothing
302+
return FailedMethodMatch("Could not identify method table for call")
303+
end
304+
mt = mt::Core.MethodTable
305+
matches = findall(atype, method_table; limit = max_methods)
306+
if matches === missing
307+
# this means too many methods matched
308+
# (assume this will always be true, so we don't compute / update valid age in this case)
309+
return FailedMethodMatch("Too many methods matched")
310+
end
311+
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
312+
return MethodMatches(matches.matches,
313+
MethodMatchInfo(matches),
314+
matches.valid_worlds,
315+
mt,
316+
fullmatch)
298317
end
318+
end
319+
320+
function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), edges::Vector{MethodInstance},
321+
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
322+
sv::InferenceState)
323+
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine (widen) this type
324+
rettype === Any && return
299325
for edge in edges
300326
add_backedge!(edge, sv)
301327
end
302-
for (thisfullmatch, mt) in zip(fullmatch, mts)
303-
if !thisfullmatch
304-
# also need an edge to the method table in case something gets
305-
# added that did not intersect with any existing method
306-
add_mt_backedge!(mt, atype, sv)
328+
# also need an edge to the method table in case something gets
329+
# added that did not intersect with any existing method
330+
if isa(matches, MethodMatches)
331+
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
332+
else
333+
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
334+
thisfullmatch || add_mt_backedge!(mt, atype, sv)
307335
end
308336
end
309337
end

base/compiler/inferenceresult.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ end
1313
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
1414
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
1515
# so that we can construct cache-correct `InferenceResult`s in the first place.
16-
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override)
16+
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
1717
@assert isa(linfo.def, Method) # ensure the next line works
1818
nargs::Int = linfo.def.nargs
1919
@assert length(given_argtypes) >= (nargs - 1)

base/compiler/optimize.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,11 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
189189
return true
190190
end
191191

192-
# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
192+
# compute inlining cost and sideeffects
193193
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
194-
(; def) = linfo = opt.linfo
195-
nargs = Int(opt.nargs) - 1
194+
(; src, nargs, linfo) = opt
195+
(; def, specTypes) = linfo
196+
nargs = Int(nargs) - 1
196197

197198
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
198199

@@ -214,7 +215,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
214215
end
215216
end
216217
if proven_pure
217-
for fl in opt.src.slotflags
218+
for fl in src.slotflags
218219
if (fl & SLOT_USEDUNDEF) != 0
219220
proven_pure = false
220221
break
@@ -223,7 +224,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
223224
end
224225
end
225226
if proven_pure
226-
opt.src.pure = true
227+
src.pure = true
227228
end
228229

229230
if proven_pure
@@ -236,7 +237,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
236237
if !(isa(result, Const) && !is_inlineable_constant(result.val))
237238
opt.const_api = true
238239
end
239-
force_noinline || (opt.src.inlineable = true)
240+
force_noinline || (src.inlineable = true)
240241
end
241242
end
242243

@@ -245,7 +246,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
245246
# determine and cache inlineability
246247
union_penalties = false
247248
if !force_noinline
248-
sig = unwrap_unionall(linfo.specTypes)
249+
sig = unwrap_unionall(specTypes)
249250
if isa(sig, DataType) && sig.name === Tuple.name
250251
for P in sig.parameters
251252
P = unwrap_unionall(P)
@@ -257,25 +258,25 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
257258
else
258259
force_noinline = true
259260
end
260-
if !opt.src.inlineable && result === Union{}
261+
if !src.inlineable && result === Union{}
261262
force_noinline = true
262263
end
263264
end
264265
if force_noinline
265-
opt.src.inlineable = false
266+
src.inlineable = false
266267
elseif isa(def, Method)
267-
if opt.src.inlineable && isdispatchtuple(linfo.specTypes)
268+
if src.inlineable && isdispatchtuple(specTypes)
268269
# obey @inline declaration if a dispatch barrier would not help
269270
else
270271
bonus = 0
271272
if result Tuple && !isconcretetype(widenconst(result))
272273
bonus = params.inline_tupleret_bonus
273274
end
274-
if opt.src.inlineable
275+
if src.inlineable
275276
# For functions declared @inline, increase the cost threshold 20x
276277
bonus += params.inline_cost_threshold*19
277278
end
278-
opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
279+
src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
279280
end
280281
end
281282

0 commit comments

Comments
 (0)