Skip to content

Commit e3b5d1d

Browse files
aviateskvtjnash
andcommitted
inference: propagate variable changes to all exception frames #42081
cherry-picked from #42081 Co-Authored-By: Jameson Nash <[email protected]>
1 parent b735af8 commit e3b5d1d

File tree

3 files changed

+161
-43
lines changed

3 files changed

+161
-43
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,19 +1331,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13311331
n = frame.nstmts
13321332
while frame.pc´´ <= n
13331333
# make progress on the active ip set
1334-
local pc::Int = frame.pc´´ # current program-counter
1334+
local pc::Int = frame.pc´´
13351335
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
13361336
#print(pc,": ",s[pc],"\n")
13371337
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
13381338
if pc == frame.pc´´
1339-
# need to update pc´´ to point at the new lowest instruction in W
1340-
min_pc = _bits_findnext(W.bits, pc + 1)
1341-
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
1339+
# want to update pc´´ to point at the new lowest instruction in W
1340+
frame.pc´´ = pc´
13421341
end
13431342
delete!(W, pc)
13441343
frame.currpc = pc
1345-
frame.cur_hand = frame.handler_at[pc]
1346-
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
1344+
edges = frame.stmt_edges[pc]
1345+
edges === nothing || empty!(edges)
13471346
frame.stmt_info[pc] = nothing
13481347
stmt = frame.src.code[pc]
13491348
changes = s[pc]::VarTable
@@ -1377,7 +1376,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13771376
pc´ = l
13781377
else
13791378
# general case
1380-
frame.handler_at[l] = frame.cur_hand
13811379
changes_else = changes
13821380
if isa(condt, Conditional)
13831381
if condt.elsetype !== Any && condt.elsetype !== changes[slot_id(condt.var)]
@@ -1425,7 +1423,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14251423
end
14261424
elseif hd === :enter
14271425
l = stmt.args[1]::Int
1428-
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
14291426
# propagate type info to exception handler
14301427
old = s[l]
14311428
newstate_catch = stupdate!(old, changes)
@@ -1437,11 +1434,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14371434
s[l] = newstate_catch
14381435
end
14391436
typeassert(s[l], VarTable)
1440-
frame.handler_at[l] = frame.cur_hand
14411437
elseif hd === :leave
1442-
for i = 1:((stmt.args[1])::Int)
1443-
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
1444-
end
14451438
else
14461439
if hd === :(=)
14471440
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
@@ -1467,16 +1460,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14671460
frame.src.ssavaluetypes[pc] = t
14681461
end
14691462
end
1470-
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
1471-
# propagate new type info to exception handler
1472-
# the handling for Expr(:enter) propagates all changes from before the try/catch
1473-
# so this only needs to propagate any changes
1474-
l = frame.cur_hand.first::Int
1475-
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
1476-
if l < frame.pc´´
1477-
frame.pc´´ = l
1463+
if isa(changes, StateUpdate)
1464+
let cur_hand = frame.handler_at[pc], l, enter
1465+
while cur_hand != 0
1466+
enter = frame.src.code[cur_hand]
1467+
l = (enter::Expr).args[1]::Int
1468+
# propagate new type info to exception handler
1469+
# the handling for Expr(:enter) propagates all changes from before the try/catch
1470+
# so this only needs to propagate any changes
1471+
if stupdate1!(s[l]::VarTable, changes::StateUpdate) !== false
1472+
if l < frame.pc´´
1473+
frame.pc´´ = l
1474+
end
1475+
push!(W, l)
1476+
end
1477+
cur_hand = frame.handler_at[cur_hand]
14781478
end
1479-
push!(W, l)
14801479
end
14811480
end
14821481
end
@@ -1489,7 +1488,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14891488
end
14901489

14911490
pc´ > n && break # can't proceed with the fast-path fall-through
1492-
frame.handler_at[pc´] = frame.cur_hand
14931491
newstate = stupdate!(s[pc´], changes)
14941492
if isa(stmt, GotoNode) && frame.pc´´ < pc´
14951493
# if we are processing a goto node anyways,
@@ -1500,7 +1498,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15001498
s[pc´] = newstate
15011499
end
15021500
push!(W, pc´)
1503-
pc = frame.pc´´
1501+
break
15041502
elseif newstate !== nothing
15051503
s[pc´] = newstate
15061504
pc = pc´
@@ -1510,6 +1508,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15101508
break
15111509
end
15121510
end
1511+
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
15131512
end
15141513
frame.dont_work_on_me = false
15151514
nothing

base/compiler/inferencestate.jl

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ mutable struct InferenceState
2828
pc´´::LineNum
2929
nstmts::Int
3030
# current exception handler info
31-
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
32-
handler_at::Vector{Any}
33-
n_handlers::Int
31+
handler_at::Vector{LineNum}
3432
# ssavalue sparsity and restart info
3533
ssavalue_uses::Vector{BitSet}
3634
throw_blocks::BitSet
@@ -57,8 +55,9 @@ mutable struct InferenceState
5755
function InferenceState(result::InferenceResult, src::CodeInfo,
5856
cached::Bool, interp::AbstractInterpreter)
5957
linfo = result.linfo
58+
def = linfo.def
6059
code = src.code::Array{Any,1}
61-
toplevel = !isa(linfo.def, Method)
60+
toplevel = !isa(def, Method)
6261

6362
sp = sptypes_from_meth_instance(linfo::MethodInstance)
6463

@@ -87,30 +86,21 @@ mutable struct InferenceState
8786
throw_blocks = find_throw_blocks(code)
8887

8988
# exception handlers
90-
cur_hand = nothing
91-
handler_at = Any[ nothing for i=1:n ]
92-
n_handlers = 0
93-
94-
W = BitSet()
95-
push!(W, 1) #initial pc to visit
96-
97-
if !toplevel
98-
meth = linfo.def
99-
inmodule = meth.module
100-
else
101-
inmodule = linfo.def::Module
102-
end
89+
ip = BitSet()
90+
handler_at = compute_trycatch(src.code, ip)
91+
push!(ip, 1)
10392

93+
mod = isa(def, Method) ? def.module : def
10494
valid_worlds = WorldRange(src.min_world,
10595
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
96+
10697
frame = new(
10798
InferenceParams(interp), result, linfo,
108-
sp, slottypes, inmodule, 0,
99+
sp, slottypes, mod, 0,
109100
IdSet{InferenceState}(), IdSet{InferenceState}(),
110101
src, get_world_counter(interp), valid_worlds,
111102
nargs, s_types, s_edges, stmt_info,
112-
Union{}, W, 1, n,
113-
cur_hand, handler_at, n_handlers,
103+
Union{}, ip, 1, n, handler_at,
114104
ssavalue_uses, throw_blocks,
115105
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
116106
Vector{InferenceState}(), # callers_in_cycle
@@ -124,6 +114,90 @@ mutable struct InferenceState
124114
end
125115
end
126116

117+
function compute_trycatch(code::Vector{Any}, ip::BitSet)
118+
# The goal initially is to record the frame like this for the state at exit:
119+
# 1: (enter 3) # == 0
120+
# 3: (expr) # == 1
121+
# 3: (leave 1) # == 1
122+
# 4: (expr) # == 0
123+
# then we can find all trys by walking backwards from :enter statements,
124+
# and all catches by looking at the statement after the :enter
125+
n = length(code)
126+
empty!(ip)
127+
ip.offset = 0 # for _bits_findnext
128+
push!(ip, n + 1)
129+
handler_at = fill(0, n)
130+
131+
# start from all :enter statements and record the location of the try
132+
for pc = 1:n
133+
stmt = code[pc]
134+
if isexpr(stmt, :enter)
135+
l = stmt.args[1]::Int
136+
handler_at[pc + 1] = pc
137+
push!(ip, pc + 1)
138+
handler_at[l] = pc
139+
push!(ip, l)
140+
end
141+
end
142+
143+
# now forward those marks to all :leave statements
144+
pc´´ = 0
145+
while true
146+
# make progress on the active ip set
147+
pc = _bits_findnext(ip.bits, pc´´)::Int
148+
pc > n && break
149+
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
150+
pc´ = pc + 1 # next program-counter (after executing instruction)
151+
if pc == pc´´
152+
pc´´ = pc´
153+
end
154+
delete!(ip, pc)
155+
cur_hand = handler_at[pc]
156+
@assert cur_hand != 0 "unbalanced try/catch"
157+
stmt = code[pc]
158+
if isa(stmt, GotoNode)
159+
pc´ = stmt.label
160+
elseif isa(stmt, GotoIfNot)
161+
l = stmt.dest::Int
162+
if handler_at[l] != cur_hand
163+
@assert handler_at[l] == 0 "unbalanced try/catch"
164+
handler_at[l] = cur_hand
165+
if l < pc´´
166+
pc´´ = l
167+
end
168+
push!(ip, l)
169+
end
170+
elseif isa(stmt, ReturnNode)
171+
@assert !isdefined(stmt, :val) "unbalanced try/catch"
172+
break
173+
elseif isa(stmt, Expr)
174+
head = stmt.head
175+
if head === :enter
176+
cur_hand = pc
177+
elseif head === :leave
178+
l = stmt.args[1]::Int
179+
for i = 1:l
180+
cur_hand = handler_at[cur_hand]
181+
end
182+
cur_hand == 0 && break
183+
end
184+
end
185+
186+
pc´ > n && break # can't proceed with the fast-path fall-through
187+
if handler_at[pc´] != cur_hand
188+
@assert handler_at[pc´] == 0 "unbalanced try/catch"
189+
handler_at[pc´] = cur_hand
190+
elseif !in(pc´, ip)
191+
break # already visited
192+
end
193+
pc = pc´
194+
end
195+
end
196+
197+
@assert first(ip) == n + 1
198+
return handler_at
199+
end
200+
127201
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
128202

129203
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)

test/compiler/inference.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3002,3 +3002,48 @@ Base.return_types((Union{Int,Nothing},)) do x
30023002
end
30033003
x
30043004
end == [Int]
3005+
3006+
# issue #42022
3007+
let x = Tuple{Int,Any}[
3008+
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
3009+
#= 2=# (0, Expr(:enter, 18))
3010+
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
3011+
#= 4=# (2, Expr(:enter, 12))
3012+
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
3013+
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
3014+
#= 7=# (4, Expr(:leave, 2))
3015+
#= 8=# (0, Core.ReturnNode(1))
3016+
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
3017+
#=10=# (4, Expr(:leave, 1))
3018+
#=11=# (2, Core.GotoNode(16))
3019+
#=12=# (4, Expr(:leave, 1))
3020+
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
3021+
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
3022+
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
3023+
#=16=# (2, Expr(:leave, 1))
3024+
#=17=# (0, Core.GotoNode(22))
3025+
#=18=# (2, Expr(:leave, 1))
3026+
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
3027+
#=20=# (0, nothing)
3028+
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
3029+
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
3030+
]
3031+
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
3032+
@test handler_at == first.(x)
3033+
end
3034+
3035+
@test only(Base.return_types((Bool,)) do y
3036+
x = 1
3037+
try
3038+
x = 2.0
3039+
try
3040+
x = '3'
3041+
y ? (return 1) : throw()
3042+
catch ex1
3043+
rethrow()
3044+
end
3045+
catch ex2
3046+
nothing
3047+
end
3048+
return x
3049+
end) === Union{Int, Float64, Char}

0 commit comments

Comments
 (0)