Skip to content

Commit 8320fcc

Browse files
authored
skip inferring calls that lead to throw (#35982)
1 parent 3af7ec8 commit 8320fcc

File tree

8 files changed

+91
-9
lines changed

8 files changed

+91
-9
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ end
3636

3737
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
3838
max_methods::Int = InferenceParams(interp).MAX_METHODS)
39+
if sv.currpc in sv.throw_blocks
40+
return Any
41+
end
3942
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
4043
if mt === nothing
4144
add_remark!(interp, sv, "Could not identify method table for call")

base/compiler/inferencestate.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mutable struct InferenceState
3131
n_handlers::Int
3232
# ssavalue sparsity and restart info
3333
ssavalue_uses::Vector{BitSet}
34+
throw_blocks::BitSet
3435

3536
cycle_backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller
3637
callers_in_cycle::Vector{InferenceState}
@@ -80,6 +81,7 @@ mutable struct InferenceState
8081
s_types[1] = s_argtypes
8182

8283
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
84+
throw_blocks = find_throw_blocks(code)
8385

8486
# exception handlers
8587
cur_hand = nothing
@@ -106,7 +108,7 @@ mutable struct InferenceState
106108
nargs, s_types, s_edges,
107109
Union{}, W, 1, n,
108110
cur_hand, handler_at, n_handlers,
109-
ssavalue_uses,
111+
ssavalue_uses, throw_blocks,
110112
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
111113
Vector{InferenceState}(), # callers_in_cycle
112114
#=parent=#nothing,

base/compiler/optimize.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
285285
# known return type
286286
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))
287287

288-
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams)
288+
function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any}, params::OptimizationParams, error_path::Bool = false)
289289
head = ex.head
290290
if is_meta_expr_head(head)
291291
return 0
@@ -320,7 +320,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
320320
return 0
321321
elseif (f === Main.Core.arrayref || f === Main.Core.const_arrayref) && length(ex.args) >= 3
322322
atyp = argextype(ex.args[3], src, sptypes, slottypes)
323-
return isknowntype(atyp) ? 4 : params.inline_nonleaf_penalty
323+
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
324324
end
325325
fidx = find_tfunc(f)
326326
if fidx === nothing
@@ -330,7 +330,11 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
330330
end
331331
return T_FFUNC_COST[fidx]
332332
end
333-
return params.inline_nonleaf_penalty
333+
extyp = line == -1 ? Any : src.ssavaluetypes[line]
334+
if extyp === Union{}
335+
return 0
336+
end
337+
return error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
334338
elseif head === :foreigncall || head === :invoke
335339
# Calls whose "return type" is Union{} do not actually return:
336340
# they are errors. Since these are not part of the typical
@@ -347,7 +351,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
347351
end
348352
a = ex.args[2]
349353
if a isa Expr
350-
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params))
354+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params, error_path))
351355
end
352356
return cost
353357
elseif head === :copyast
@@ -365,10 +369,11 @@ end
365369
function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any}, slottypes::Vector{Any},
366370
params::OptimizationParams, cost_threshold::Integer=params.inline_cost_threshold)
367371
bodycost::Int = 0
372+
throw_blocks = find_throw_blocks(body)
368373
for line = 1:length(body)
369374
stmt = body[line]
370375
if stmt isa Expr
371-
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params)::Int
376+
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, params, line in throw_blocks)::Int
372377
elseif stmt isa GotoNode
373378
# loops are generally always expensive
374379
# but assume that forward jumps are already counted for from

base/compiler/ssair/inlining.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,9 @@ end
997997
function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
998998
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
999999
todo = Any[]
1000+
skip = find_throw_blocks(ir.stmts.inst, RefValue(ir))
10001001
for idx in 1:length(ir.stmts)
1002+
idx in skip && continue
10011003
r = process_simple!(ir, idx, sv.params, sv.world)
10021004
r === nothing && continue
10031005

base/compiler/types.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct OptimizationParams
4545
inline_cost_threshold::Int # number of CPU cycles beyond which it's not worth inlining
4646
inline_nonleaf_penalty::Int # penalty for dynamic dispatch
4747
inline_tupleret_bonus::Int # extra willingness for non-isbits tuple return types
48+
inline_error_path_cost::Int # cost of (un-optimized) calls in blocks that throw
4849

4950
# Duplicating for now because optimizer inlining requires it.
5051
# Keno assures me this will be removed in the near future
@@ -57,6 +58,7 @@ struct OptimizationParams
5758
inline_cost_threshold::Int = 100,
5859
inline_nonleaf_penalty::Int = 1000,
5960
inline_tupleret_bonus::Int = 400,
61+
inline_error_path_cost::Int = 20,
6062
max_methods::Int = 3,
6163
tuple_splat::Int = 32,
6264
union_splitting::Int = 4,
@@ -66,6 +68,7 @@ struct OptimizationParams
6668
inline_cost_threshold,
6769
inline_nonleaf_penalty,
6870
inline_tupleret_bonus,
71+
inline_error_path_cost,
6972
max_methods,
7073
tuple_splat,
7174
union_splitting,

base/compiler/utilities.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,72 @@ function find_ssavalue_uses(e::Expr, uses::Vector{BitSet}, line::Int)
219219
end
220220
end
221221

222+
function is_throw_call(e::Expr)
223+
if e.head === :call
224+
f = e.args[1]
225+
if isa(f, GlobalRef)
226+
ff = abstract_eval_global(f.mod, f.name)
227+
if isa(ff, Const) && ff.val === Core.throw
228+
return true
229+
end
230+
end
231+
end
232+
return false
233+
end
234+
235+
function find_throw_blocks(code::Vector{Any}, ir = RefValue{IRCode}())
236+
stmts = BitSet()
237+
n = length(code)
238+
try_depth = 0
239+
for i in n:-1:1
240+
s = code[i]
241+
if isa(s, Expr)
242+
if s.head === :enter
243+
try_depth -= 1
244+
elseif s.head === :leave
245+
try_depth += (s.args[1]::Int)
246+
elseif s.head === :gotoifnot
247+
tgt = s.args[2]::Int
248+
if i+1 in stmts && tgt in stmts
249+
push!(stmts, i)
250+
end
251+
elseif s.head === :return
252+
elseif is_throw_call(s) || s.head === :unreachable
253+
if try_depth == 0
254+
push!(stmts, i)
255+
end
256+
elseif i+1 in stmts
257+
push!(stmts, i)
258+
end
259+
elseif isa(s, ReturnNode)
260+
if try_depth == 0 && !isdefined(s, :val)
261+
push!(stmts, i)
262+
end
263+
elseif isa(s, GotoNode)
264+
tgt = s.label
265+
if isassigned(ir)
266+
tgt = first(ir[].cfg.blocks[tgt].stmts)
267+
end
268+
if tgt in stmts
269+
push!(stmts, i)
270+
end
271+
elseif isa(s, GotoIfNot)
272+
if i+1 in stmts
273+
tgt = s.dest::Int
274+
if isassigned(ir)
275+
tgt = first(ir[].cfg.blocks[tgt].stmts)
276+
end
277+
if tgt in stmts
278+
push!(stmts, i)
279+
end
280+
end
281+
elseif i+1 in stmts
282+
push!(stmts, i)
283+
end
284+
end
285+
return stmts
286+
end
287+
222288
# using a function to ensure we can infer this
223289
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id :
224290
isa(s, Argument) ? (s::Argument).n : (s::TypedSlot).id

base/ntuple.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ julia> ntuple(i -> 2*i, 4)
1414
(2, 4, 6, 8)
1515
```
1616
"""
17-
function ntuple(f::F, n::Integer) where F
17+
@inline function ntuple(f::F, n::Integer) where F
18+
# marked inline since this benefits from constant propagation of `n`
1819
t = n == 0 ? () :
1920
n == 1 ? (f(1),) :
2021
n == 2 ? (f(1), f(2)) :

test/stacktraces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ end
8484

8585
module inlined_test
8686
using Test
87-
@inline g(x) = (y = throw("a"); y) # the inliner does not insert the proper markers when inlining a single expression
88-
@inline h(x) = (y = g(x); y) # this test could be extended to check for that if we switch to linear representation
87+
@inline g(x) = (x == 3 && throw("a"); x)
88+
@inline h(x) = (x == 3 && g(x); x)
8989
f(x) = (y = h(x); y)
9090
trace = (try; f(3); catch; stacktrace(catch_backtrace()); end)[1:3]
9191
can_inline = Bool(Base.JLOptions().can_inline)

0 commit comments

Comments
 (0)