Skip to content

Commit 5259e28

Browse files
committed
inference×effects: improve the const-prop' heuristics
This commit improves the heuristics used to judge const-prop' profitability with a new effect property, `:const_prop_profitable_args`. It is meant to supplement our primary const-prop' heuristic, which is based on inlining cost, and to be a general fix for the type stability issues discussed at e.g. #45952 and #46430 (and to eliminate the need for manual `@constprop :aggressive` clutter in such situations).

The new effect property `:const_prop_profitable_args` tracks call arguments that can be considered to shape up the generated code if their constant information is available. Currently this commit exploits the following const-prop' profitabilities:

- `Val(x)`-profitability: as `Val` generally encodes constant information into the type domain, it is generally profitable to constant prop' `x` if the constructed `Val(x)` is used later (e.g. for dispatch). This basically tries to exploit const-prop' profitability in the following kind of case:
  ```julia
  kernel(::Val{1}, args...) = ...
  kernel(::Val{2}, args...) = ...
  function profitable1(x::Int, args...)
      kernel(Val(x), args...)
  end
  ```
  This allows the compiler to perform const-prop' for cases like #45952 even if the primary heuristic based on inlining cost gets confused.
- branching-profitability: a constant branch condition is generally very profitable, as it can shape up the generated code as well as narrow down the return type inference by cutting off the dead branch.
  ```julia
  function profitable2(raise::Bool, args...)
      v = op(args...)
      if v === nothing && raise
          return nothing
      end
      return v
  end
  ```

Currently this commit passes all the test cases and actually improves the targeted type stability, but it is not yet ideal, as it seems to be a bit too aggressive (this commit right now strictly increases the chances of const-propagation). I'd like to further tweak this heuristic to keep the latency in general cases.
1 parent 1eee6ef commit 5259e28

File tree

7 files changed

+200
-40
lines changed

7 files changed

+200
-40
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
210210
all_effects = Effects(all_effects; nothrow=false)
211211
end
212212

213-
rettype = from_interprocedural!(𝕃ₚ, rettype, sv, arginfo, conditionals)
213+
(; rt, effects) = from_interprocedural!(𝕃ₚ, rettype, all_effects, sv, arginfo, conditionals)
214214

215215
# Also considering inferring the compilation signature for this method, so
216216
# it is available to the compiler in case it ends up needing it.
@@ -219,32 +219,32 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
219219
method = match.method
220220
sig = match.spec_types
221221
mi = specialize_method(match; preexisting=true)
222-
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
222+
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi, arginfo, Effects(), sv)
223223
csig = get_compileable_sig(method, sig, match.sparams)
224224
if csig !== nothing && csig !== sig
225225
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)
226226
end
227227
end
228228
end
229229

230-
if call_result_unused(si) && !(rettype === Bottom)
230+
if call_result_unused(si) && !(rt === Bottom)
231231
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
232232
# We're mainly only here because the optimizer might want this code,
233233
# but we ourselves locally don't typically care about it locally
234234
# (beyond checking if it always throws).
235235
# So avoid adding an edge, since we don't want to bother attempting
236236
# to improve our result even if it does change (to always throw),
237237
# and avoid keeping track of a more complex result type.
238-
rettype = Any
238+
rt = Any
239239
end
240-
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
240+
add_call_backedges!(interp, rt, effects, edges, matches, atype, sv)
241241
if !isempty(sv.pclimitations) # remove self, if present
242242
delete!(sv.pclimitations, sv)
243243
for caller in sv.callers_in_cycle
244244
delete!(sv.pclimitations, caller)
245245
end
246246
end
247-
return CallMeta(rettype, all_effects, info)
247+
return CallMeta(rt, effects, info)
248248
end
249249

250250
struct FailedMethodMatch
@@ -348,15 +348,24 @@ function find_matching_methods(𝕃::AbstractLattice,
348348
end
349349
end
350350

351+
struct InterproceduralResult
352+
rt
353+
effects::Effects
354+
InterproceduralResult(@nospecialize(rt), effects::Effects) = new(rt, effects)
355+
end
356+
351357
"""
352-
from_interprocedural!(𝕃ₚ::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt
358+
from_interprocedural!(𝕃ₚ::AbstractLattice, rt, effects::Effects,
359+
sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> InterproceduralResult
353360
354-
Converts inter-procedural return type `rt` into a local lattice element `newrt`,
355-
that is appropriate in the context of current local analysis frame `sv`, especially:
361+
Converts extended lattice element `rt` and `effects::Effects` that represent inferred
362+
return type and method call effects into new lattice element and `Effects` that are
363+
appropriate in the context of current local analysis frame `sv`, especially:
356364
- unwraps `rt::LimitedAccuracy` and collects its limitations into the current frame `sv`
357365
- converts boolean `rt` to new boolean `newrt` in a way `newrt` can propagate extra conditional
358366
refinement information, e.g. translating `rt::InterConditional` into `newrt::Conditional`
359367
that holds a type constraint information about a variable in `sv`
368+
- recomputes `effects.const_prop_profitable_args` so that they are imposed on call arguments of `sv`
360369
361370
This function _should_ be used wherever we propagate results returned from
362371
`abstract_call_method` or `abstract_call_method_with_const_args`.
@@ -368,7 +377,8 @@ In such cases `maybecondinfo` should be either of:
368377
When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by
369378
`tmerge`ing argument signature type of each method call.
370379
"""
371-
function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
380+
function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), effects::Effects,
381+
sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
372382
rt = collect_limitations!(rt, sv)
373383
if isa(rt, InterMustAlias)
374384
rt = from_intermustalias(rt, arginfo)
@@ -380,7 +390,23 @@ function from_interprocedural!(𝕃ₚ::AbstractLattice, @nospecialize(rt), sv::
380390
end
381391
end
382392
@assert !(rt isa InterConditional || rt isa InterMustAlias) "invalid lattice element returned from inter-procedural context"
383-
return rt
393+
if effects.const_prop_profitable_args !== NO_PROFITABLE_ARGS
394+
argsbits = 0x00
395+
fargs = arginfo.fargs
396+
if fargs !== nothing
397+
for i = 1:length(fargs)
398+
if is_const_prop_profitable_arg(effects, i)
399+
arg = fargs[i]
400+
if is_call_argument(arg, sv) && 1 ≤ slot_id(arg) ≤ 8
401+
argsbits |= 0x01 << (slot_id(arg)-1)
402+
end
403+
end
404+
end
405+
end
406+
const_prop_profitable_args = ConstPropProfitableArgs(argsbits)
407+
effects = Effects(effects; const_prop_profitable_args)
408+
end
409+
return InterproceduralResult(rt, effects)
384410
end
385411

386412
function collect_limitations!(@nospecialize(typ), sv::InferenceState)
@@ -993,9 +1019,8 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
9931019
end
9941020
end
9951021
# try constant prop'
996-
inf_cache = get_inference_cache(interp)
9971022
𝕃ᵢ = typeinf_lattice(interp)
998-
inf_result = cache_lookup(𝕃ᵢ, mi, arginfo.argtypes, inf_cache)
1023+
inf_result = cache_lookup(𝕃ᵢ, mi, arginfo.argtypes, get_inference_cache(interp))
9991024
if inf_result === nothing
10001025
# if there might be a cycle, check to make sure we don't end up
10011026
# calling ourselves here.
@@ -1062,7 +1087,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
10621087
return nothing
10631088
end
10641089
mi = mi::MethodInstance
1065-
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
1090+
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, result.effects, sv)
10661091
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
10671092
return nothing
10681093
end
@@ -1214,7 +1239,7 @@ end
12141239
# where we would spend a lot of time, but are probably unlikely to get an improved
12151240
# result anyway.
12161241
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
1217-
mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState)
1242+
mi::MethodInstance, arginfo::ArgInfo, effects::Effects, sv::InferenceState)
12181243
method = mi.def::Method
12191244
if method.is_for_opaque_closure
12201245
# Not inlining an opaque closure can be very expensive, so be generous
@@ -1239,6 +1264,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
12391264
elseif is_stmt_noinline(flag)
12401265
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
12411266
return false
1267+
elseif any_const_prop_profitable_args(effects, arginfo.argtypes)
1268+
return true
12421269
else
12431270
# Peek at the inferred result for the method to determine if the optimizer
12441271
# was able to cut it down to something simple (inlineable in particular).
@@ -1256,6 +1283,21 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
12561283
return false # the cache isn't inlineable, so this constant-prop' will most likely be unfruitful
12571284
end
12581285

1286+
# check if constant information is available on any call argument that has been analyzed as
1287+
# const-prop' profitable
1288+
function any_const_prop_profitable_args(effects::Effects, argtypes::Vector{Any})
1289+
if effects.const_prop_profitable_args === NO_PROFITABLE_ARGS
1290+
return false
1291+
end
1292+
for i in 1:length(argtypes)
1293+
ai = widenconditional(argtypes[i])
1294+
if isa(ai, Const) && is_const_prop_profitable_arg(effects, i)
1295+
return true
1296+
end
1297+
end
1298+
return false
1299+
end
1300+
12591301
# This is only for use with `Conditional`.
12601302
# In general, usage of this is wrong.
12611303
ssa_def_slot(@nospecialize(arg), sv::IRCode) = nothing
@@ -1901,11 +1943,10 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
19011943
(; rt, effects, const_result, edge) = const_call_result
19021944
end
19031945
end
1904-
rt = from_interprocedural!(𝕃ₚ, rt, sv, arginfo, sig)
19051946
effects = Effects(effects; nonoverlayed=!overlayed)
1906-
info = InvokeCallInfo(match, const_result)
1947+
(; rt, effects) = from_interprocedural!(𝕃ₚ, rt, effects, sv, arginfo, sig)
19071948
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
1908-
return CallMeta(rt, effects, info)
1949+
return CallMeta(rt, effects, InvokeCallInfo(match, const_result))
19091950
end
19101951

19111952
function invoke_rewrite(xs::Vector{Any})
@@ -2015,6 +2056,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
20152056
return CallMeta(typename_static(argtypes[2]), EFFECTS_TOTAL, MethodResultPure())
20162057
elseif f === Core._hasmethod
20172058
return _hasmethod_tfunc(interp, argtypes, sv)
2059+
elseif la == 2 && istoptype(f, :Val)
2060+
# `Val` generally encodes constant information into the type domain, so there is
2061+
# generally a high profitability for constant propagation if the argument of the
2062+
# `Val` constructor is a call argument
2063+
fargs = arginfo.fargs
2064+
if fargs !== nothing
2065+
arg = arginfo.fargs[2]
2066+
if is_call_argument(arg, sv) && !isempty(sv.ssavalue_uses[sv.currpc])
2067+
if 1 ≤ slot_id(arg) ≤ 8
2068+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(arg)-1))
2069+
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
2070+
end
2071+
end
2072+
end
20182073
end
20192074
atype = argtypes_to_type(argtypes)
20202075
return abstract_call_gf_by_type(interp, f, arginfo, si, atype, sv, max_methods)
@@ -2048,7 +2103,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
20482103
effects = Effects(effects; nothrow=false)
20492104
end
20502105
end
2051-
rt = from_interprocedural!(𝕃ₚ, rt, sv, arginfo, match.spec_types)
2106+
(; rt, effects) = from_interprocedural!(𝕃ₚ, rt, effects, sv, arginfo, match.spec_types)
20522107
info = OpaqueClosureCallInfo(match, const_result)
20532108
edge !== nothing && add_backedge!(sv, edge)
20542109
return CallMeta(rt, effects, info)
@@ -2481,8 +2536,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
24812536
override.terminates_globally ? true : effects.terminates,
24822537
override.notaskstate ? true : effects.notaskstate,
24832538
override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
2484-
effects.nonoverlayed,
2485-
effects.noinbounds)
2539+
effects.nonoverlayed, effects.noinbounds, effects.const_prop_profitable_args)
24862540
end
24872541
return RTEffects(t, effects)
24882542
end
@@ -2865,6 +2919,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
28652919
@goto branch
28662920
elseif isa(stmt, GotoIfNot)
28672921
condx = stmt.cond
2922+
if is_call_argument(condx, frame)
2923+
# if this condition object is a call argument, there will be a high
2924+
# profitability for constant-propagating it, since it can shape up
2925+
# the generated code by cutting off the dead branch entirely
2926+
if 1 ≤ slot_id(condx) ≤ 8
2927+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(condx)-1))
2928+
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
2929+
end
2930+
end
28682931
condt = abstract_eval_value(interp, condx, currstate, frame)
28692932
if condt === Bottom
28702933
ssavaluetypes[currpc] = Bottom

base/compiler/effects.jl

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
struct ConstPropProfitableArgs
2+
argsbits::UInt8
3+
end
4+
15
"""
26
effects::Effects
37
@@ -64,6 +68,7 @@ struct Effects
6468
inaccessiblememonly::UInt8
6569
nonoverlayed::Bool
6670
noinbounds::Bool
71+
const_prop_profitable_args::ConstPropProfitableArgs
6772
function Effects(
6873
consistent::UInt8,
6974
effect_free::UInt8,
@@ -72,7 +77,8 @@ struct Effects
7277
notaskstate::Bool,
7378
inaccessiblememonly::UInt8,
7479
nonoverlayed::Bool,
75-
noinbounds::Bool)
80+
noinbounds::Bool,
81+
const_prop_profitable_args::ConstPropProfitableArgs = NO_PROFITABLE_ARGS)
7682
return new(
7783
consistent,
7884
effect_free,
@@ -81,7 +87,8 @@ struct Effects
8187
notaskstate,
8288
inaccessiblememonly,
8389
nonoverlayed,
84-
noinbounds)
90+
noinbounds,
91+
const_prop_profitable_args)
8592
end
8693
end
8794

@@ -98,10 +105,13 @@ const EFFECT_FREE_IF_INACCESSIBLEMEMONLY = 0x01 << 1
98105
# :inaccessiblememonly bits
99106
const INACCESSIBLEMEM_OR_ARGMEMONLY = 0x01 << 1
100107

101-
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, true, true, true, ALWAYS_TRUE, true, true)
102-
const EFFECTS_THROWS = Effects(ALWAYS_TRUE, ALWAYS_TRUE, false, true, true, ALWAYS_TRUE, true, true)
103-
const EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, true, false) # unknown mostly, but it's not overlayed at least (e.g. it's not a call)
104-
const _EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, false, false) # unknown really
108+
# :const_prop_profitable_args bits
109+
const NO_PROFITABLE_ARGS = ConstPropProfitableArgs(0x00)
110+
111+
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, true, true, true, ALWAYS_TRUE, true, true, NO_PROFITABLE_ARGS)
112+
const EFFECTS_THROWS = Effects(ALWAYS_TRUE, ALWAYS_TRUE, false, true, true, ALWAYS_TRUE, true, true, NO_PROFITABLE_ARGS)
113+
const EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, true, false, NO_PROFITABLE_ARGS) # unknown mostly, but it's not overlayed at least (e.g. it's not a call)
114+
const _EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, false, false, NO_PROFITABLE_ARGS) # unknown really
105115

106116
function Effects(e::Effects = _EFFECTS_UNKNOWN;
107117
consistent::UInt8 = e.consistent,
@@ -111,7 +121,8 @@ function Effects(e::Effects = _EFFECTS_UNKNOWN;
111121
notaskstate::Bool = e.notaskstate,
112122
inaccessiblememonly::UInt8 = e.inaccessiblememonly,
113123
nonoverlayed::Bool = e.nonoverlayed,
114-
noinbounds::Bool = e.noinbounds)
124+
noinbounds::Bool = e.noinbounds,
125+
const_prop_profitable_args::ConstPropProfitableArgs = e.const_prop_profitable_args)
115126
return Effects(
116127
consistent,
117128
effect_free,
@@ -120,7 +131,8 @@ function Effects(e::Effects = _EFFECTS_UNKNOWN;
120131
notaskstate,
121132
inaccessiblememonly,
122133
nonoverlayed,
123-
noinbounds)
134+
noinbounds,
135+
const_prop_profitable_args)
124136
end
125137

126138
function merge_effects(old::Effects, new::Effects)
@@ -132,7 +144,8 @@ function merge_effects(old::Effects, new::Effects)
132144
merge_effectbits(old.notaskstate, new.notaskstate),
133145
merge_effectbits(old.inaccessiblememonly, new.inaccessiblememonly),
134146
merge_effectbits(old.nonoverlayed, new.nonoverlayed),
135-
merge_effectbits(old.noinbounds, new.noinbounds))
147+
merge_effectbits(old.noinbounds, new.noinbounds),
148+
merge_effectbits(old.const_prop_profitable_args, new.const_prop_profitable_args))
136149
end
137150

138151
function merge_effectbits(old::UInt8, new::UInt8)
@@ -142,6 +155,7 @@ function merge_effectbits(old::UInt8, new::UInt8)
142155
return old | new
143156
end
144157
merge_effectbits(old::Bool, new::Bool) = old & new
158+
merge_effectbits(old::ConstPropProfitableArgs, new::ConstPropProfitableArgs) = ConstPropProfitableArgs(old.argsbits | new.argsbits)
145159

146160
is_consistent(effects::Effects) = effects.consistent === ALWAYS_TRUE
147161
is_effect_free(effects::Effects) = effects.effect_free === ALWAYS_TRUE
@@ -177,15 +191,18 @@ is_effect_free_if_inaccessiblememonly(effects::Effects) = !iszero(effects.effect
177191

178192
is_inaccessiblemem_or_argmemonly(effects::Effects) = effects.inaccessiblememonly === INACCESSIBLEMEM_OR_ARGMEMONLY
179193

194+
is_const_prop_profitable_arg(effects::Effects, arg::Int) = !iszero(effects.const_prop_profitable_args.argsbits & (0x01 << (arg-1)))
195+
180196
function encode_effects(e::Effects)
181-
return ((e.consistent % UInt32) << 0) |
182-
((e.effect_free % UInt32) << 3) |
183-
((e.nothrow % UInt32) << 5) |
184-
((e.terminates % UInt32) << 6) |
185-
((e.notaskstate % UInt32) << 7) |
186-
((e.inaccessiblememonly % UInt32) << 8) |
187-
((e.nonoverlayed % UInt32) << 10)|
188-
((e.noinbounds % UInt32) << 11)
197+
return ((e.consistent % UInt32) << 0) |
198+
((e.effect_free % UInt32) << 3) |
199+
((e.nothrow % UInt32) << 5) |
200+
((e.terminates % UInt32) << 6) |
201+
((e.notaskstate % UInt32) << 7) |
202+
((e.inaccessiblememonly % UInt32) << 8) |
203+
((e.nonoverlayed % UInt32) << 10) |
204+
((e.noinbounds % UInt32) << 11) |
205+
((e.const_prop_profitable_args.argsbits % UInt32) << 12)
189206
end
190207

191208
function decode_effects(e::UInt32)
@@ -197,7 +214,8 @@ function decode_effects(e::UInt32)
197214
_Bool((e >> 7) & 0x01),
198215
UInt8((e >> 8) & 0x03),
199216
_Bool((e >> 10) & 0x01),
200-
_Bool((e >> 11) & 0x01))
217+
_Bool((e >> 11) & 0x01),
218+
ConstPropProfitableArgs(UInt8((e >> 12) & 0x7f)))
201219
end
202220

203221
struct EffectsOverride

base/compiler/inferencestate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,3 +620,5 @@ function narguments(sv::InferenceState, include_va::Bool=true)
620620
end
621621
return nargs
622622
end
623+
is_call_argument(@nospecialize(x), sv::InferenceState) =
624+
isa(x, SlotNumber) && slot_id(x) ≤ narguments(sv)

base/compiler/utilities.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,23 @@ _topmod(m::Module) = ccall(:jl_base_relative_to, Any, (Any,), m)::Module
5050

5151
function istopfunction(@nospecialize(f), name::Symbol)
5252
tn = typeof(f).name
53-
if tn.mt.name === name
53+
mn = tn.mt.name
54+
if mn === name
5455
top = _topmod(tn.module)
5556
return isdefined(top, name) && isconst(top, name) && f === getglobal(top, name)
5657
end
5758
return false
5859
end
5960

61+
function istoptype(@nospecialize(T), name::Symbol)
62+
t = unwrap_unionall(T)
63+
if isa(t, DataType) && t.name.name === name
64+
top = _topmod(t.name.module)
65+
return isdefined(top, name) && isconst(top, name) && T === getglobal(top, name)
66+
end
67+
return false
68+
end
69+
6070
#######
6171
# AST #
6272
#######

0 commit comments

Comments
 (0)