Skip to content

Commit ab6a268

Browse files
committed
inference×effects: improve the const-prop' heuristics
This commit improves the heuristics used to judge const-prop' profitability with a new effect property `:const_prop_profitable_args`. This is supposed to supplement our primary const-prop' heuristic based on inlining cost and is supposed to be a general fix for the type-stability issues discussed at e.g. #45952 and #46430 (and to eliminate the need for manual `@constprop :aggressive` clutter in such situations).

The new effect property `:const_prop_profitable_args` tracks call arguments that can be considered to shape up generated code if their constant information is available. Currently this commit exploits the following const-prop' profitabilities:

- `Val(x)`-profitability: as `Val` generally encodes constant information into the type domain, it is generally profitable to constant prop' `x` if the constructed `Val(x)` is used later (e.g. for dispatch). This basically tries to exploit const-prop' profitability in the following kind of case:
  ```julia
  kernel(::Val{1}, args...) = ...
  kernel(::Val{2}, args...) = ...
  function profitable1(x::Int, args...)
      kernel(Val(x), args...)
  end
  ```
  This allows the compiler to perform const-prop' for cases like #45952 even if the primary heuristics based on inlining cost get confused.
- branching-profitability: a constant branch condition is generally very profitable as it can shape up generated code as well as narrow down the return type inference by cutting off the dead branch.
  ```julia
  function profitable2(raise::Bool, args...)
      v = op(args...)
      if v === nothing && raise
          return nothing
      end
      return v
  end
  ```

Currently this commit passes all the test cases and also actually improves the target type stabilities, but it doesn't work ideally yet, as it seems to be a bit too aggressive (this commit right now strictly increases the chances of const-propagation). I'd like to further tweak this heuristic to keep the latency in general cases.
1 parent a392e39 commit ab6a268

File tree

8 files changed

+197
-31
lines changed

8 files changed

+197
-31
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
211211
all_effects = Effects(all_effects; nothrow=false)
212212
end
213213

214-
rettype = from_interprocedural!(ipo_lattice(interp), rettype, sv, arginfo, conditionals)
214+
(; rt, effects) = from_interprocedural!(ipo_lattice(interp), rettype, all_effects, sv, arginfo, conditionals)
215215

216216
# Also considering inferring the compilation signature for this method, so
217217
# it is available to the compiler in case it ends up needing it.
@@ -220,32 +220,32 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
220220
method = match.method
221221
sig = match.spec_types
222222
mi = specialize_method(match; preexisting=true)
223-
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, match, mi::MethodInstance, arginfo, sv)
223+
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi::MethodInstance, arginfo, Effects(), sv)
224224
csig = get_compileable_sig(method, sig, match.sparams)
225225
if csig !== nothing && csig !== sig
226226
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)
227227
end
228228
end
229229
end
230230

231-
if call_result_unused(si) && !(rettype === Bottom)
231+
if call_result_unused(si) && !(rt === Bottom)
232232
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
233233
# We're mainly only here because the optimizer might want this code,
234234
# but we ourselves locally don't typically care about it locally
235235
# (beyond checking if it always throws).
236236
# So avoid adding an edge, since we don't want to bother attempting
237237
# to improve our result even if it does change (to always throw),
238238
# and avoid keeping track of a more complex result type.
239-
rettype = Any
239+
rt = Any
240240
end
241-
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
241+
add_call_backedges!(interp, rt, effects, edges, matches, atype, sv)
242242
if !isempty(sv.pclimitations) # remove self, if present
243243
delete!(sv.pclimitations, sv)
244244
for caller in sv.callers_in_cycle
245245
delete!(sv.pclimitations, caller)
246246
end
247247
end
248-
return CallMeta(rettype, all_effects, info)
248+
return CallMeta(rt, effects, info)
249249
end
250250

251251
struct FailedMethodMatch
@@ -348,15 +348,24 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
348348
end
349349
end
350350

351+
struct InterproceduralResult
352+
rt
353+
effects::Effects
354+
InterproceduralResult(@nospecialize(rt), effects::Effects) = new(rt, effects)
355+
end
356+
351357
"""
352-
from_interprocedural!(ipo_lattice::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt
358+
from_interprocedural!(ipo_lattice::AbstractLattice, rt, effects::Effects,
359+
sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> InterproceduralResult
353360
354-
Converts inter-procedural return type `rt` into a local lattice element `newrt`,
355-
that is appropriate in the context of current local analysis frame `sv`, especially:
361+
Converts extended lattice element `rt` and `effects::Effects` that represent inferred
362+
return type and method call effects into new lattice element and `Effects` that are
363+
appropriate in the context of current local analysis frame `sv`, especially:
356364
- unwraps `rt::LimitedAccuracy` and collects its limitations into the current frame `sv`
357365
- converts boolean `rt` to new boolean `newrt` in a way `newrt` can propagate extra conditional
358366
refinement information, e.g. translating `rt::InterConditional` into `newrt::Conditional`
359367
that holds a type constraint information about a variable in `sv`
368+
- recomputes `effects.const_prop_profitable_args` so that they are imposed on call arguments of `sv`
360369
361370
This function _should_ be used wherever we propagate results returned from
362371
`abstract_call_method` or `abstract_call_method_with_const_args`.
@@ -368,7 +377,8 @@ In such cases `maybecondinfo` should be either of:
368377
When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by
369378
`tmerge`ing argument signature type of each method call.
370379
"""
371-
function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
380+
function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt), effects::Effects,
381+
sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
372382
rt = collect_limitations!(rt, sv)
373383
if is_lattice_bool(ipo_lattice, rt)
374384
if maybecondinfo === nothing
@@ -378,7 +388,23 @@ function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt),
378388
end
379389
end
380390
@assert !(rt isa InterConditional) "invalid lattice element returned from inter-procedural context"
381-
return rt
391+
if effects.const_prop_profitable_args !== NO_PROFITABLE_ARGS
392+
argsbits = 0x00
393+
fargs = arginfo.fargs
394+
if fargs !== nothing
395+
for i = 1:length(fargs)
396+
if is_const_prop_profitable_arg(effects, i)
397+
arg = fargs[i]
398+
if is_call_argument(arg, sv) && 1 ≤ slot_id(arg) ≤ 8
399+
argsbits |= 0x01 << (slot_id(arg)-1)
400+
end
401+
end
402+
end
403+
end
404+
const_prop_profitable_args = ConstPropProfitableArgs(argsbits)
405+
effects = Effects(effects; const_prop_profitable_args)
406+
end
407+
return InterproceduralResult(rt, effects)
382408
end
383409

384410
function collect_limitations!(@nospecialize(typ), sv::InferenceState)
@@ -906,8 +932,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
906932
end
907933
end
908934
# try constant prop'
909-
inf_cache = get_inference_cache(interp)
910-
inf_result = cache_lookup(typeinf_lattice(interp), mi, arginfo.argtypes, inf_cache)
935+
inf_result = cache_lookup(typeinf_lattice(interp), mi, arginfo.argtypes, get_inference_cache(interp))
911936
if inf_result === nothing
912937
# if there might be a cycle, check to make sure we don't end up
913938
# calling ourselves here.
@@ -964,7 +989,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
964989
return nothing
965990
end
966991
mi = mi::MethodInstance
967-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
992+
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, result.effects, sv)
968993
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
969994
return nothing
970995
end
@@ -1127,10 +1152,9 @@ end
11271152
# This is a heuristic to avoid trying to const prop through complicated functions
11281153
# where we would spend a lot of time, but are probably unlikely to get an improved
11291154
# result anyway.
1130-
function const_prop_methodinstance_heuristic(
1131-
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
1132-
(; argtypes)::ArgInfo, sv::InferenceState)
1133-
method = match.method
1155+
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
1156+
mi::MethodInstance, arginfo::ArgInfo, effects::Effects, sv::InferenceState)
1157+
method = mi.def::Method
11341158
if method.is_for_opaque_closure
11351159
# Not inlining an opaque closure can be very expensive, so be generous
11361160
# with the const-prop-ability. It is quite possible that we can't infer
@@ -1154,6 +1178,8 @@ function const_prop_methodinstance_heuristic(
11541178
elseif is_stmt_noinline(flag)
11551179
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
11561180
return false
1181+
elseif any_const_prop_profitable_args(effects, arginfo.argtypes)
1182+
return true
11571183
else
11581184
code = get(code_cache(interp), mi, nothing)
11591185
if isdefined(code, :inferred)
@@ -1162,7 +1188,7 @@ function const_prop_methodinstance_heuristic(
11621188
else
11631189
inferred = code.inferred
11641190
end
1165-
if inlining_policy(interp, inferred, IR_FLAG_NULL, mi, argtypes) !== nothing
1191+
if inlining_policy(interp, inferred, IR_FLAG_NULL, mi, arginfo.argtypes) !== nothing
11661192
return true
11671193
end
11681194
end
@@ -1171,6 +1197,21 @@ function const_prop_methodinstance_heuristic(
11711197
return false # the cache isn't inlineable, so this constant-prop' will most likely be unfruitful
11721198
end
11731199

1200+
# check if constant information is available on any call argument that has been analyzed as
1201+
# const-prop' profitable
1202+
function any_const_prop_profitable_args(effects::Effects, argtypes::Vector{Any})
1203+
if effects.const_prop_profitable_args === NO_PROFITABLE_ARGS
1204+
return false
1205+
end
1206+
for i in 1:length(argtypes)
1207+
ai = widenconditional(argtypes[i])
1208+
if isa(ai, Const) && is_const_prop_profitable_arg(effects, i)
1209+
return true
1210+
end
1211+
end
1212+
return false
1213+
end
1214+
11741215
# This is only for use with `Conditional`.
11751216
# In general, usage of this is wrong.
11761217
ssa_def_slot(@nospecialize(arg), sv::IRCode) = nothing
@@ -1711,8 +1752,9 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
17111752
end
17121753
end
17131754
effects = Effects(effects; nonoverlayed=!overlayed)
1755+
(; rt, effects) = from_interprocedural!(ipo_lattice(interp), rt, effects, sv, arginfo, sig)
17141756
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
1715-
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
1757+
return CallMeta(rt, effects, InvokeCallInfo(match, const_result))
17161758
end
17171759

17181760
function invoke_rewrite(xs::Vector{Any})
@@ -1835,6 +1877,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
18351877
val = _pure_eval_call(f, arginfo)
18361878
return CallMeta(val === nothing ? Type : val, EFFECTS_TOTAL, MethodResultPure())
18371879
end
1880+
elseif la == 2 && istoptype(f, :Val)
1881+
# `Val` generally encodes constant information into the type domain, so there is
1882+
# generally a high profitability for constant propagation if the argument of the
1883+
# `Val` constructor is a call argument
1884+
fargs = arginfo.fargs
1885+
if fargs !== nothing
1886+
arg = arginfo.fargs[2]
1887+
if is_call_argument(arg, sv) && !isempty(sv.ssavalue_uses[sv.currpc])
1888+
if 1 ≤ slot_id(arg) ≤ 8
1889+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(arg)-1))
1890+
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
1891+
end
1892+
end
1893+
end
18381894
end
18391895
atype = argtypes_to_type(argtypes)
18401896
return abstract_call_gf_by_type(interp, f, arginfo, si, atype, sv, max_methods)
@@ -1869,7 +1925,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
18691925
effects = Effects(effects; nothrow=false)
18701926
end
18711927
end
1872-
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
1928+
(; rt, effects) = from_interprocedural!(ipo, rt, effects, sv, arginfo, match.spec_types)
18731929
edge !== nothing && add_backedge!(sv, edge)
18741930
return CallMeta(rt, effects, info)
18751931
end
@@ -2233,7 +2289,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
22332289
override.terminates_globally ? true : effects.terminates,
22342290
override.notaskstate ? true : effects.notaskstate,
22352291
override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
2236-
effects.nonoverlayed)
2292+
effects.nonoverlayed, effects.const_prop_profitable_args)
22372293
end
22382294
return RTEffects(t, effects)
22392295
end
@@ -2520,6 +2576,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
25202576
@goto branch
25212577
elseif isa(stmt, GotoIfNot)
25222578
condx = stmt.cond
2579+
if is_call_argument(condx, frame)
2580+
# if this condition object is a call argument, there will be a high
2581+
# profitability for constant-propagating it, since it can shape up
2582+
# the generated code by cutting off the dead branch entirely
2583+
if 1 ≤ slot_id(condx) ≤ 8
2584+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(condx)-1))
2585+
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
2586+
end
2587+
end
25232588
condt = abstract_eval_value(interp, condx, currstate, frame)
25242589
if condt === Bottom
25252590
ssavaluetypes[currpc] = Bottom

base/compiler/effects.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
struct ConstPropProfitableArgs
2+
argsbits::UInt8
3+
end
4+
15
"""
26
effects::Effects
37
@@ -63,6 +67,7 @@ struct Effects
6367
notaskstate::Bool
6468
inaccessiblememonly::UInt8
6569
nonoverlayed::Bool
70+
const_prop_profitable_args::ConstPropProfitableArgs
6671
noinbounds::Bool
6772
function Effects(
6873
consistent::UInt8,
@@ -72,6 +77,7 @@ struct Effects
7277
notaskstate::Bool,
7378
inaccessiblememonly::UInt8,
7479
nonoverlayed::Bool,
80+
const_prop_profitable_args::ConstPropProfitableArgs = NO_PROFITABLE_ARGS,
7581
noinbounds::Bool = true)
7682
return new(
7783
consistent,
@@ -81,6 +87,7 @@ struct Effects
8187
notaskstate,
8288
inaccessiblememonly,
8389
nonoverlayed,
90+
const_prop_profitable_args,
8491
noinbounds)
8592
end
8693
end
@@ -98,6 +105,9 @@ const EFFECT_FREE_IF_INACCESSIBLEMEMONLY = 0x01 << 1
98105
# :inaccessiblememonly bits
99106
const INACCESSIBLEMEM_OR_ARGMEMONLY = 0x01 << 1
100107

108+
# :const_prop_profitable_args bits
109+
const NO_PROFITABLE_ARGS = ConstPropProfitableArgs(0x00)
110+
101111
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, true, true, true, ALWAYS_TRUE, true)
102112
const EFFECTS_THROWS = Effects(ALWAYS_TRUE, ALWAYS_TRUE, false, true, true, ALWAYS_TRUE, true)
103113
const EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, true) # unknown mostly, but it's not overlayed at least (e.g. it's not a call)
@@ -111,6 +121,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
111121
notaskstate::Bool = e.notaskstate,
112122
inaccessiblememonly::UInt8 = e.inaccessiblememonly,
113123
nonoverlayed::Bool = e.nonoverlayed,
124+
const_prop_profitable_args::ConstPropProfitableArgs = e.const_prop_profitable_args,
114125
noinbounds::Bool = e.noinbounds)
115126
return Effects(
116127
consistent,
@@ -120,6 +131,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
120131
notaskstate,
121132
inaccessiblememonly,
122133
nonoverlayed,
134+
const_prop_profitable_args,
123135
noinbounds)
124136
end
125137

@@ -132,6 +144,7 @@ function merge_effects(old::Effects, new::Effects)
132144
merge_effectbits(old.notaskstate, new.notaskstate),
133145
merge_effectbits(old.inaccessiblememonly, new.inaccessiblememonly),
134146
merge_effectbits(old.nonoverlayed, new.nonoverlayed),
147+
merge_effectbits(old.const_prop_profitable_args, new.const_prop_profitable_args),
135148
merge_effectbits(old.noinbounds, new.noinbounds))
136149
end
137150

@@ -142,6 +155,7 @@ function merge_effectbits(old::UInt8, new::UInt8)
142155
return old | new
143156
end
144157
merge_effectbits(old::Bool, new::Bool) = old & new
158+
merge_effectbits(old::ConstPropProfitableArgs, new::ConstPropProfitableArgs) = ConstPropProfitableArgs(old.argsbits | new.argsbits)
145159

146160
is_consistent(effects::Effects) = effects.consistent === ALWAYS_TRUE
147161
is_effect_free(effects::Effects) = effects.effect_free === ALWAYS_TRUE
@@ -177,14 +191,17 @@ is_effect_free_if_inaccessiblememonly(effects::Effects) = !iszero(effects.effect
177191

178192
is_inaccessiblemem_or_argmemonly(effects::Effects) = effects.inaccessiblememonly === INACCESSIBLEMEM_OR_ARGMEMONLY
179193

194+
is_const_prop_profitable_arg(effects::Effects, arg::Int) = !iszero(effects.const_prop_profitable_args.argsbits & (0x01 << (arg-1)))
195+
180196
function encode_effects(e::Effects)
181-
return ((e.consistent % UInt32) << 0) |
182-
((e.effect_free % UInt32) << 3) |
183-
((e.nothrow % UInt32) << 5) |
184-
((e.terminates % UInt32) << 6) |
185-
((e.notaskstate % UInt32) << 7) |
186-
((e.inaccessiblememonly % UInt32) << 8) |
187-
((e.nonoverlayed % UInt32) << 10)
197+
return ((e.consistent % UInt32) << 0) |
198+
((e.effect_free % UInt32) << 3) |
199+
((e.nothrow % UInt32) << 5) |
200+
((e.terminates % UInt32) << 6) |
201+
((e.notaskstate % UInt32) << 7) |
202+
((e.inaccessiblememonly % UInt32) << 8) |
203+
((e.nonoverlayed % UInt32) << 10) |
204+
((e.const_prop_profitable_args.argsbits % UInt32) << 11)
188205
end
189206

190207
function decode_effects(e::UInt32)
@@ -195,7 +212,8 @@ function decode_effects(e::UInt32)
195212
_Bool((e >> 6) & 0x01),
196213
_Bool((e >> 7) & 0x01),
197214
UInt8((e >> 8) & 0x03),
198-
_Bool((e >> 10) & 0x01))
215+
_Bool((e >> 10) & 0x01),
216+
ConstPropProfitableArgs(UInt8((e >> 11) & 0x7f)))
199217
end
200218

201219
struct EffectsOverride

base/compiler/inferencestate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,5 @@ function narguments(sv::InferenceState)
550550
nargs = length(sv.result.argtypes) - isva
551551
return nargs
552552
end
553+
is_call_argument(@nospecialize(x), sv::InferenceState) =
554+
isa(x, SlotNumber) && slot_id(x) ≤ narguments(sv)

base/compiler/utilities.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,23 @@ _topmod(m::Module) = ccall(:jl_base_relative_to, Any, (Any,), m)::Module
5050

5151
function istopfunction(@nospecialize(f), name::Symbol)
5252
tn = typeof(f).name
53-
if tn.mt.name === name
53+
mn = tn.mt.name
54+
if mn === name
5455
top = _topmod(tn.module)
5556
return isdefined(top, name) && isconst(top, name) && f === getglobal(top, name)
5657
end
5758
return false
5859
end
5960

61+
function istoptype(@nospecialize(T), name::Symbol)
62+
t = unwrap_unionall(T)
63+
if isa(t, DataType) && t.name.name === name
64+
top = _topmod(t.name.module)
65+
return isdefined(top, name) && isconst(top, name) && T === getglobal(top, name)
66+
end
67+
return false
68+
end
69+
6070
#######
6171
# AST #
6272
#######

test/abstractarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,3 +1815,7 @@ end
18151815
a, b = zeros(2, 2, 2), zeros(2, 2)
18161816
@test_broken IRUtils.fully_eliminated(_has_offset_axes, Base.typesof(a, a, b, b))
18171817
end
1818+
1819+
# type stable [x;;] (https://github.com/JuliaLang/julia/issues/45952)
1820+
f45952(x) = [x;;]
1821+
@inferred f45952(1.0)

0 commit comments

Comments
 (0)