Skip to content

Commit 6661563

Browse files
committed
inference×effects: improve the const-prop' heuristics
This commit improves the heuristics to judge const-prop' profitability with new effect property `:const_prop_profitable_args`. This is supposed to supplement our primary const-prop' heuristic based on inlining cost and is supposed to be a general fix for type stability issues discussed at e.g. #45952 and #46430 (and eliminating the need for manual `@constprop :aggressive` clutters in such situations).

The new effect property `:const_prop_profitable_args` tracks call arguments that can be considered to shape up generated code if their constant information is available. Currently this commit exploits the following const-prop' profitabilities:

- `Val(x)`-profitability: as `Val` generally encodes constant information into the type domain, it is generally profitable to constant prop' `x` if the constructed `Val(x)` is used later (e.g. for dispatch). This basically tries to exploit const-prop' profitability in the following kind of case:
  ```julia
  kernel(::Val{1}, args...) = ...
  kernel(::Val{2}, args...) = ...
  function profitable1(x::Int, args...)
      kernel(Val(x), args...)
  end
  ```
  This allows the compiler to perform const-prop' for cases like #45952 even if the primary heuristics based on inlining cost gets confused.
- branching-profitability: constant branch condition is generally very profitable as it can shape up generated code as well as narrow down the return type inference by cutting off the dead branch.
  ```julia
  function profitable2(raise::Bool, args...)
      v = op(args...)
      if v === nothing && raise
          return nothing
      end
      return v
  end
  ```

Currently this commit passes all the test cases and also actually improves target type stabilities, but doesn't work very ideally as it seems to be a bit too aggressive (this commit right now strictly increases the chances of const-propagation). I'd like to further tweak this heuristic to keep the latency in general cases.
1 parent 8181316 commit 6661563

File tree

7 files changed

+191
-29
lines changed

7 files changed

+191
-29
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
211211
all_effects = Effects(all_effects; nothrow=false)
212212
end
213213

214-
rettype = from_interprocedural!(ipo_lattice(interp), rettype, sv, arginfo, conditionals)
214+
(; rt, effects) = from_interprocedural!(ipo_lattice(interp), rettype, all_effects, sv, arginfo, conditionals)
215215

216216
# Also considering inferring the compilation signature for this method, so
217217
# it is available to the compiler in case it ends up needing it.
@@ -220,32 +220,32 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
220220
method = match.method
221221
sig = match.spec_types
222222
mi = specialize_method(match; preexisting=true)
223-
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
223+
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi, arginfo, Effects(), sv)
224224
csig = get_compileable_sig(method, sig, match.sparams)
225225
if csig !== nothing && csig !== sig
226226
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)
227227
end
228228
end
229229
end
230230

231-
if call_result_unused(si) && !(rettype === Bottom)
231+
if call_result_unused(si) && !(rt === Bottom)
232232
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
233233
# We're mainly only here because the optimizer might want this code,
234234
# but we ourselves locally don't typically care about it locally
235235
# (beyond checking if it always throws).
236236
# So avoid adding an edge, since we don't want to bother attempting
237237
# to improve our result even if it does change (to always throw),
238238
# and avoid keeping track of a more complex result type.
239-
rettype = Any
239+
rt = Any
240240
end
241-
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
241+
add_call_backedges!(interp, rt, effects, edges, matches, atype, sv)
242242
if !isempty(sv.pclimitations) # remove self, if present
243243
delete!(sv.pclimitations, sv)
244244
for caller in sv.callers_in_cycle
245245
delete!(sv.pclimitations, caller)
246246
end
247247
end
248-
return CallMeta(rettype, all_effects, info)
248+
return CallMeta(rt, effects, info)
249249
end
250250

251251
struct FailedMethodMatch
@@ -348,15 +348,24 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
348348
end
349349
end
350350

351+
struct InterproceduralResult
352+
rt
353+
effects::Effects
354+
InterproceduralResult(@nospecialize(rt), effects::Effects) = new(rt, effects)
355+
end
356+
351357
"""
352-
from_interprocedural!(ipo_lattice::AbstractLattice, rt, sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> newrt
358+
from_interprocedural!(ipo_lattice::AbstractLattice, rt, effects::Effects,
359+
sv::InferenceState, arginfo::ArgInfo, maybecondinfo) -> InterproceduralResult
353360
354-
Converts inter-procedural return type `rt` into a local lattice element `newrt`,
355-
that is appropriate in the context of current local analysis frame `sv`, especially:
361+
Converts extended lattice element `rt` and `effects::Effects` that represent inferred
362+
return type and method call effects into new lattice element and `Effects` that are
363+
appropriate in the context of current local analysis frame `sv`, especially:
356364
- unwraps `rt::LimitedAccuracy` and collects its limitations into the current frame `sv`
357365
- converts boolean `rt` to new boolean `newrt` in a way `newrt` can propagate extra conditional
358366
refinement information, e.g. translating `rt::InterConditional` into `newrt::Conditional`
359367
that holds a type constraint information about a variable in `sv`
368+
- recomputes `effects.const_prop_profitable_args` so that they are imposed on call arguments of `sv`
360369
361370
This function _should_ be used wherever we propagate results returned from
362371
`abstract_call_method` or `abstract_call_method_with_const_args`.
@@ -368,7 +377,8 @@ In such cases `maybecondinfo` should be either of:
368377
When we deal with multiple `MethodMatch`es, it's better to precompute `maybecondinfo` by
369378
`tmerge`ing argument signature type of each method call.
370379
"""
371-
function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt), sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
380+
function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt), effects::Effects,
381+
sv::InferenceState, arginfo::ArgInfo, @nospecialize(maybecondinfo))
372382
rt = collect_limitations!(rt, sv)
373383
if is_lattice_bool(ipo_lattice, rt)
374384
if maybecondinfo === nothing
@@ -378,7 +388,23 @@ function from_interprocedural!(ipo_lattice::AbstractLattice, @nospecialize(rt),
378388
end
379389
end
380390
@assert !(rt isa InterConditional) "invalid lattice element returned from inter-procedural context"
381-
return rt
391+
if effects.const_prop_profitable_args !== NO_PROFITABLE_ARGS
392+
argsbits = 0x00
393+
fargs = arginfo.fargs
394+
if fargs !== nothing
395+
for i = 1:length(fargs)
396+
if is_const_prop_profitable_arg(effects, i)
397+
arg = fargs[i]
398+
if is_call_argument(arg, sv) && 1 ≤ slot_id(arg) ≤ 8
399+
argsbits |= 0x01 << (slot_id(arg)-1)
400+
end
401+
end
402+
end
403+
end
404+
const_prop_profitable_args = ConstPropProfitableArgs(argsbits)
405+
effects = Effects(effects; const_prop_profitable_args)
406+
end
407+
return InterproceduralResult(rt, effects)
382408
end
383409

384410
function collect_limitations!(@nospecialize(typ), sv::InferenceState)
@@ -906,8 +932,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
906932
end
907933
end
908934
# try constant prop'
909-
inf_cache = get_inference_cache(interp)
910-
inf_result = cache_lookup(typeinf_lattice(interp), mi, arginfo.argtypes, inf_cache)
935+
inf_result = cache_lookup(typeinf_lattice(interp), mi, arginfo.argtypes, get_inference_cache(interp))
911936
if inf_result === nothing
912937
# if there might be a cycle, check to make sure we don't end up
913938
# calling ourselves here.
@@ -964,7 +989,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
964989
return nothing
965990
end
966991
mi = mi::MethodInstance
967-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
992+
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, result.effects, sv)
968993
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
969994
return nothing
970995
end
@@ -1128,8 +1153,8 @@ end
11281153
# where we would spend a lot of time, but are probably unlikely to get an improved
11291154
# result anyway.
11301155
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
1131-
match::MethodMatch, mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState)
1132-
method = match.method
1156+
mi::MethodInstance, arginfo::ArgInfo, effects::Effects, sv::InferenceState)
1157+
method = mi.def::Method
11331158
if method.is_for_opaque_closure
11341159
# Not inlining an opaque closure can be very expensive, so be generous
11351160
# with the const-prop-ability. It is quite possible that we can't infer
@@ -1153,6 +1178,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
11531178
elseif is_stmt_noinline(flag)
11541179
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
11551180
return false
1181+
elseif any_const_prop_profitable_args(effects, arginfo.argtypes)
1182+
return true
11561183
else
11571184
code = get(code_cache(interp), mi, nothing)
11581185
if isdefined(code, :inferred)
@@ -1161,7 +1188,6 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
11611188
else
11621189
inferred = code.inferred
11631190
end
1164-
# TODO propagate a specific `CallInfo` that conveys information about this call
11651191
if inlining_policy(interp, inferred, NoCallInfo(), IR_FLAG_NULL, mi, arginfo.argtypes) !== nothing
11661192
return true
11671193
end
@@ -1171,6 +1197,21 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
11711197
return false # the cache isn't inlineable, so this constant-prop' will most likely be unfruitful
11721198
end
11731199

1200+
# check if constant information is available on any call argument that has been analyzed as
1201+
# const-prop' profitable
1202+
function any_const_prop_profitable_args(effects::Effects, argtypes::Vector{Any})
1203+
if effects.const_prop_profitable_args === NO_PROFITABLE_ARGS
1204+
return false
1205+
end
1206+
for i in 1:length(argtypes)
1207+
ai = widenconditional(argtypes[i])
1208+
if isa(ai, Const) && is_const_prop_profitable_arg(effects, i)
1209+
return true
1210+
end
1211+
end
1212+
return false
1213+
end
1214+
11741215
# This is only for use with `Conditional`.
11751216
# In general, usage of this is wrong.
11761217
ssa_def_slot(@nospecialize(arg), sv::IRCode) = nothing
@@ -1711,8 +1752,9 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
17111752
end
17121753
end
17131754
effects = Effects(effects; nonoverlayed=!overlayed)
1755+
(; rt, effects) = from_interprocedural!(ipo_lattice(interp), rt, effects, sv, arginfo, sig)
17141756
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
1715-
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
1757+
return CallMeta(rt, effects, InvokeCallInfo(match, const_result))
17161758
end
17171759

17181760
function invoke_rewrite(xs::Vector{Any})
@@ -1824,6 +1866,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
18241866
val = _pure_eval_call(f, arginfo)
18251867
return CallMeta(val === nothing ? Type : val, EFFECTS_TOTAL, MethodResultPure())
18261868
end
1869+
elseif la == 2 && istoptype(f, :Val)
1870+
# `Val` generally encodes constant information into the type domain, so there is
1871+
# generally a high profitability for constant propagation if the argument of the
1872+
# `Val` constructor is a call argument
1873+
fargs = arginfo.fargs
1874+
if fargs !== nothing
1875+
arg = arginfo.fargs[2]
1876+
if is_call_argument(arg, sv) && !isempty(sv.ssavalue_uses[sv.currpc])
1877+
if 1 ≤ slot_id(arg) ≤ 8
1878+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(arg)-1))
1879+
merge_effects!(interp, sv, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
1880+
end
1881+
end
1882+
end
18271883
end
18281884
atype = argtypes_to_type(argtypes)
18291885
return abstract_call_gf_by_type(interp, f, arginfo, si, atype, sv, max_methods)
@@ -1858,7 +1914,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
18581914
effects = Effects(effects; nothrow=false)
18591915
end
18601916
end
1861-
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
1917+
(; rt, effects) = from_interprocedural!(ipo, rt, effects, sv, arginfo, match.spec_types)
18621918
edge !== nothing && add_backedge!(sv, edge)
18631919
return CallMeta(rt, effects, info)
18641920
end
@@ -2226,7 +2282,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
22262282
override.terminates_globally ? true : effects.terminates,
22272283
override.notaskstate ? true : effects.notaskstate,
22282284
override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
2229-
effects.nonoverlayed)
2285+
effects.nonoverlayed, effects.const_prop_profitable_args)
22302286
end
22312287
return RTEffects(t, effects)
22322288
end
@@ -2513,6 +2569,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
25132569
@goto branch
25142570
elseif isa(stmt, GotoIfNot)
25152571
condx = stmt.cond
2572+
if is_call_argument(condx, frame)
2573+
# if this condition object is a call argument, there will be a high
2574+
# profitability for constant-propagating it, since it can shape up
2575+
# the generated code by cutting off the dead branch entirely
2576+
if 1 ≤ slot_id(condx) ≤ 8
2577+
const_prop_profitable_args = ConstPropProfitableArgs(0x01 << (slot_id(condx)-1))
2578+
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; const_prop_profitable_args))
2579+
end
2580+
end
25162581
condt = abstract_eval_value(interp, condx, currstate, frame)
25172582
if condt === Bottom
25182583
ssavaluetypes[currpc] = Bottom

base/compiler/effects.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
struct ConstPropProfitableArgs
2+
argsbits::UInt8
3+
end
4+
15
"""
26
effects::Effects
37
@@ -63,6 +67,7 @@ struct Effects
6367
notaskstate::Bool
6468
inaccessiblememonly::UInt8
6569
nonoverlayed::Bool
70+
const_prop_profitable_args::ConstPropProfitableArgs
6671
noinbounds::Bool
6772
function Effects(
6873
consistent::UInt8,
@@ -72,6 +77,7 @@ struct Effects
7277
notaskstate::Bool,
7378
inaccessiblememonly::UInt8,
7479
nonoverlayed::Bool,
80+
const_prop_profitable_args::ConstPropProfitableArgs = NO_PROFITABLE_ARGS,
7581
noinbounds::Bool = true)
7682
return new(
7783
consistent,
@@ -81,6 +87,7 @@ struct Effects
8187
notaskstate,
8288
inaccessiblememonly,
8389
nonoverlayed,
90+
const_prop_profitable_args,
8491
noinbounds)
8592
end
8693
end
@@ -98,6 +105,9 @@ const EFFECT_FREE_IF_INACCESSIBLEMEMONLY = 0x01 << 1
98105
# :inaccessiblememonly bits
99106
const INACCESSIBLEMEM_OR_ARGMEMONLY = 0x01 << 1
100107

108+
# :const_prop_profitable_args bits
109+
const NO_PROFITABLE_ARGS = ConstPropProfitableArgs(0x00)
110+
101111
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, true, true, true, ALWAYS_TRUE, true)
102112
const EFFECTS_THROWS = Effects(ALWAYS_TRUE, ALWAYS_TRUE, false, true, true, ALWAYS_TRUE, true)
103113
const EFFECTS_UNKNOWN = Effects(ALWAYS_FALSE, ALWAYS_FALSE, false, false, false, ALWAYS_FALSE, true) # unknown mostly, but it's not overlayed at least (e.g. it's not a call)
@@ -111,6 +121,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
111121
notaskstate::Bool = e.notaskstate,
112122
inaccessiblememonly::UInt8 = e.inaccessiblememonly,
113123
nonoverlayed::Bool = e.nonoverlayed,
124+
const_prop_profitable_args::ConstPropProfitableArgs = e.const_prop_profitable_args,
114125
noinbounds::Bool = e.noinbounds)
115126
return Effects(
116127
consistent,
@@ -120,6 +131,7 @@ function Effects(e::Effects = EFFECTS_UNKNOWN′;
120131
notaskstate,
121132
inaccessiblememonly,
122133
nonoverlayed,
134+
const_prop_profitable_args,
123135
noinbounds)
124136
end
125137

@@ -132,6 +144,7 @@ function merge_effects(old::Effects, new::Effects)
132144
merge_effectbits(old.notaskstate, new.notaskstate),
133145
merge_effectbits(old.inaccessiblememonly, new.inaccessiblememonly),
134146
merge_effectbits(old.nonoverlayed, new.nonoverlayed),
147+
merge_effectbits(old.const_prop_profitable_args, new.const_prop_profitable_args),
135148
merge_effectbits(old.noinbounds, new.noinbounds))
136149
end
137150

@@ -142,6 +155,7 @@ function merge_effectbits(old::UInt8, new::UInt8)
142155
return old | new
143156
end
144157
merge_effectbits(old::Bool, new::Bool) = old & new
158+
merge_effectbits(old::ConstPropProfitableArgs, new::ConstPropProfitableArgs) = ConstPropProfitableArgs(old.argsbits | new.argsbits)
145159

146160
is_consistent(effects::Effects) = effects.consistent === ALWAYS_TRUE
147161
is_effect_free(effects::Effects) = effects.effect_free === ALWAYS_TRUE
@@ -177,14 +191,17 @@ is_effect_free_if_inaccessiblememonly(effects::Effects) = !iszero(effects.effect
177191

178192
is_inaccessiblemem_or_argmemonly(effects::Effects) = effects.inaccessiblememonly === INACCESSIBLEMEM_OR_ARGMEMONLY
179193

194+
is_const_prop_profitable_arg(effects::Effects, arg::Int) = !iszero(effects.const_prop_profitable_args.argsbits & (0x01 << (arg-1)))
195+
180196
function encode_effects(e::Effects)
181-
return ((e.consistent % UInt32) << 0) |
182-
((e.effect_free % UInt32) << 3) |
183-
((e.nothrow % UInt32) << 5) |
184-
((e.terminates % UInt32) << 6) |
185-
((e.notaskstate % UInt32) << 7) |
186-
((e.inaccessiblememonly % UInt32) << 8) |
187-
((e.nonoverlayed % UInt32) << 10)
197+
return ((e.consistent % UInt32) << 0) |
198+
((e.effect_free % UInt32) << 3) |
199+
((e.nothrow % UInt32) << 5) |
200+
((e.terminates % UInt32) << 6) |
201+
((e.notaskstate % UInt32) << 7) |
202+
((e.inaccessiblememonly % UInt32) << 8) |
203+
((e.nonoverlayed % UInt32) << 10) |
204+
((e.const_prop_profitable_args.argsbits % UInt32) << 11)
188205
end
189206

190207
function decode_effects(e::UInt32)
@@ -195,7 +212,8 @@ function decode_effects(e::UInt32)
195212
_Bool((e >> 6) & 0x01),
196213
_Bool((e >> 7) & 0x01),
197214
UInt8((e >> 8) & 0x03),
198-
_Bool((e >> 10) & 0x01))
215+
_Bool((e >> 10) & 0x01),
216+
ConstPropProfitableArgs(UInt8((e >> 11) & 0x7f)))
199217
end
200218

201219
struct EffectsOverride

base/compiler/inferencestate.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,5 @@ function narguments(sv::InferenceState)
550550
nargs = length(sv.result.argtypes) - isva
551551
return nargs
552552
end
553+
is_call_argument(@nospecialize(x), sv::InferenceState) =
554+
isa(x, SlotNumber) && slot_id(x) ≤ narguments(sv)

base/compiler/utilities.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,23 @@ _topmod(m::Module) = ccall(:jl_base_relative_to, Any, (Any,), m)::Module
5050

5151
function istopfunction(@nospecialize(f), name::Symbol)
5252
tn = typeof(f).name
53-
if tn.mt.name === name
53+
mn = tn.mt.name
54+
if mn === name
5455
top = _topmod(tn.module)
5556
return isdefined(top, name) && isconst(top, name) && f === getglobal(top, name)
5657
end
5758
return false
5859
end
5960

61+
function istoptype(@nospecialize(T), name::Symbol)
62+
t = unwrap_unionall(T)
63+
if isa(t, DataType) && t.name.name === name
64+
top = _topmod(t.name.module)
65+
return isdefined(top, name) && isconst(top, name) && T === getglobal(top, name)
66+
end
67+
return false
68+
end
69+
6070
#######
6171
# AST #
6272
#######

test/broadcast.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,3 +1128,10 @@ end
11281128
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle
11291129
@test Base.infer_effects(BroadcastStyle, (DefaultArrayStyle{1},DefaultArrayStyle{2},)) |>
11301130
Core.Compiler.is_foldable
1131+
1132+
function f44330(x; isreal=true)
1133+
y = similar(x)
1134+
y .= x
1135+
isreal ? real(y) : y
1136+
end
1137+
@inferred f44330(randn(ComplexF64, 1))

0 commit comments

Comments
 (0)