Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 90 additions & 35 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ end
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)

include("compiler/ssair/driver.jl")

mutable struct OptimizationState{Interp<:AbstractInterpreter}
linfo::MethodInstance
src::CodeInfo
Expand All @@ -131,15 +129,13 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
sptypes::Vector{VarState}
slottypes::Vector{Any}
inlining::InliningState{Interp}
cfg::Union{Nothing,CFG}
cfg::CFG
insert_coverage::Bool
end
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
recompute_cfg::Bool=true)
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter)
inlining = InliningState(sv, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, sv.mod,
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
sv.sptypes, sv.slottypes, inlining, sv.cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
# prepare src for running optimization passes if it isn't already
Expand All @@ -162,7 +158,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::Abstrac
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(interp)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
cfg = compute_basic_blocks(src.code)
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, cfg, false)
end
function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
world = get_world_counter(interp)
Expand All @@ -171,6 +168,9 @@ function OptimizationState(linfo::MethodInstance, interp::AbstractInterpreter)
return OptimizationState(linfo, src, interp)
end


include("compiler/ssair/driver.jl")

function ir_to_codeinf!(opt::OptimizationState)
(; linfo, src) = opt
src = ir_to_codeinf!(src, opt.ir::IRCode)
Expand Down Expand Up @@ -534,7 +534,7 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
idx = 1
oldidx = 1
nstmts = length(code)
ssachangemap = labelchangemap = nothing
ssachangemap = labelchangemap = blockchangemap = nothing
prevloc = zero(eltype(ci.codelocs))
while idx <= length(code)
codeloc = codelocs[idx]
Expand All @@ -555,54 +555,93 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
if oldidx < length(labelchangemap)
labelchangemap[oldidx + 1] += 1
end
if blockchangemap === nothing
blockchangemap = fill(0, length(sv.cfg.blocks))
end
blockchangemap[block_for_inst(sv.cfg, oldidx)] += 1
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
if ssavaluetypes[idx] === Union{} && !(code[idx] isa Core.Const)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we want to set up a special and explicit marker type instead of Const to indicate a statement is known to be unreachable, e.g.

# Sketch of an explicit marker type (an alternative to wrapping in `Const`) for a
# statement that is known to be unreachable.
struct Unreachable
    # the wrapped statement; the field is left untyped
    stmt
    # inner constructor: the argument is marked `@nospecialize`
    Unreachable(@nospecialize stmt) = new(stmt)
end

# Type inference should have converted any must-throw terminators to an equivalent w/o control-flow edges
@assert !isterminator(code[idx])

block = block_for_inst(sv.cfg, oldidx)
block_end = last(sv.cfg.blocks[block].stmts) + (idx - oldidx)

# Delete all successors to this basic block
for succ in sv.cfg.blocks[block].succs
preds = sv.cfg.blocks[succ].preds
deleteat!(preds, findfirst(x::Int->x==block, preds)::Int)
end
empty!(sv.cfg.blocks[block].succs)

if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
insert!(codelocs, idx + 1, codelocs[idx])
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, NoCallInfo())
insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
sv.insert_coverage && (labelchangemap[oldidx + 1] += 1)
# Any statements from here to the end of the block have been wrapped in Core.Const(...)
# by type inference (effectively deleting them). Only task left is to replace the block
# terminator with an explicit `unreachable` marker.
if block_end > idx
code[block_end] = ReturnNode()
codelocs[block_end] = codelocs[idx]
ssavaluetypes[block_end] = Union{}
stmtinfo[block_end] = NoCallInfo()
ssaflags[block_end] = IR_FLAG_NOTHROW

# Verify that type-inference did its job
if JLOptions().debug_level == 2
for i = (idx + 1):(block_end - 1)
@assert (code[i] isa Core.Const) || is_meta_expr(code[i])
end
end

idx += block_end - idx
else
insert!(code, idx + 1, ReturnNode())
insert!(codelocs, idx + 1, codelocs[idx])
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, NoCallInfo())
insert!(ssaflags, idx + 1, IR_FLAG_NOTHROW)
if ssachangemap === nothing
ssachangemap = fill(0, nstmts)
end
if labelchangemap === nothing
labelchangemap = sv.insert_coverage ? fill(0, nstmts) : ssachangemap
end
if oldidx < length(ssachangemap)
ssachangemap[oldidx + 1] += 1
sv.insert_coverage && (labelchangemap[oldidx + 1] += 1)
end
if blockchangemap === nothing
blockchangemap = fill(0, length(sv.cfg.blocks))
end
blockchangemap[block] += 1
idx += 1
end
idx += 1
oldidx = last(sv.cfg.blocks[block].stmts)
end
end
idx += 1
oldidx += 1
end

cfg = sv.cfg
if ssachangemap !== nothing && labelchangemap !== nothing
renumber_ir_elements!(code, ssachangemap, labelchangemap)
cfg = nothing # recompute CFG
end
if blockchangemap !== nothing
renumber_cfg_stmts!(sv.cfg, blockchangemap)
end

for i = 1:length(code)
code[i] = process_meta!(meta, code[i])
end
strip_trailing_junk!(ci, code, stmtinfo)
strip_trailing_junk!(ci, sv.cfg, code, stmtinfo)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, codelocs, ssaflags)
if cfg === nothing
cfg = compute_basic_blocks(code)
end
# NOTE: at this point `argtypes` still holds the types of the slots; it will be modified
# to contain the types of the call arguments only once `slot2reg` converts this `IRCode`
# to the SSA form and eliminates slots (see below)
argtypes = sv.slottypes
return IRCode(stmts, cfg, linetable, argtypes, meta, sv.sptypes)
return IRCode(stmts, sv.cfg, linetable, argtypes, meta, sv.sptypes)
end

function process_meta!(meta::Vector{Expr}, @nospecialize stmt)
Expand Down Expand Up @@ -763,8 +802,8 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
return maxcost
end

function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int})
return renumber_ir_elements!(body, ssachangemap, ssachangemap)
function renumber_ir_elements!(body::Vector{Any}, cfg::Union{CFG,Nothing}, ssachangemap::Vector{Int})
return renumber_ir_elements!(body, cfg, ssachangemap, ssachangemap)
end

function cumsum_ssamap!(ssachangemap::Vector{Int})
Expand Down Expand Up @@ -847,3 +886,19 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
end
end
end

# Shift the statement ranges (and the block index) stored in `cfg` after extra statements
# have been inserted into the instruction stream.
#
# `blockchangemap` has one entry per basic block; `convert_to_ircode` bumps entry `b` once
# for every statement it inserts into block `b`. The vector is handed to `cumsum_ssamap!`
# and is therefore modified in place by this call.
# NOTE(review): the indexing below presumes that `cumsum_ssamap!` leaves running totals in
# the vector (entry `i` = statements inserted in blocks `1:i`) -- confirm against its
# definition.
#
# Mutates `cfg.blocks` and `cfg.index`; returns `nothing`.
function renumber_cfg_stmts!(cfg::CFG, blockchangemap::Vector{Int})
    # The helper's result tells us whether there is anything to shift at all.
    any_change = cumsum_ssamap!(blockchangemap)
    any_change || return nothing

    for i = 1:length(cfg.blocks)
        old_range = cfg.blocks[i].stmts
        # The start of block `i` moves by the offset of the preceding block (block 1 never
        # moves); its end moves by the block's own offset.
        start_shift = (i > 1) ? blockchangemap[i - 1] : 0
        new_range = StmtRange(first(old_range) + start_shift,
                              last(old_range) + blockchangemap[i])
        cfg.blocks[i] = BasicBlock(cfg.blocks[i], new_range)
        # `cfg.index` may have fewer entries than there are blocks, hence the bounds guard.
        if i <= length(cfg.index)
            cfg.index[i] = cfg.index[i] + blockchangemap[i]
        end
    end
    return nothing
end
8 changes: 7 additions & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ function rename_uses!(ir::IRCode, ci::CodeInfo, idx::Int, @nospecialize(stmt), r
return fixemup!(stmt::UnoptSlot->true, stmt::UnoptSlot->renames[slot_id(stmt)], ir, ci, idx, stmt)
end

function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{CallInfo})
function strip_trailing_junk!(ci::CodeInfo, cfg::CFG, code::Vector{Any}, info::Vector{CallInfo})
# Remove `nothing`s at the end, we don't handle them well
# (we expect the last instruction to be a terminator)
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
Expand All @@ -207,6 +207,12 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, info::Vector{Call
push!(codelocs, 0)
push!(info, NoCallInfo())
push!(ssaflags, IR_FLAG_NOTHROW)

# Update CFG to include appended terminator
old_range = cfg.blocks[end].stmts
new_range = StmtRange(first(old_range), last(old_range) + 1)
cfg.blocks[end] = BasicBlock(cfg.blocks[end], new_range)
(length(cfg.index) == length(cfg.blocks)) && (cfg.index[end] += 1)
end
nothing
end
Expand Down
47 changes: 44 additions & 3 deletions base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ end
function verify_ir(ir::IRCode, print::Bool=true,
allow_frontend_forms::Bool=false,
𝕃ₒ::AbstractLattice = SimpleInferenceLattice.instance)
# Verify CFG graph. Must be well formed to construct domtree
if !(length(ir.cfg.blocks) - 1 <= length(ir.cfg.index) <= length(ir.cfg.blocks))
@verify_error "CFG index length ($(length(ir.cfg.index))) does not correspond to # of blocks $(length(ir.cfg.blocks))"
error("")
end
if length(ir.stmts.stmt) != length(ir.stmts)
@verify_error "IR stmt length is invalid $(length(ir.stmts.stmt)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.type) != length(ir.stmts)
@verify_error "IR type length is invalid $(length(ir.stmts.type)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.info) != length(ir.stmts)
@verify_error "IR info length is invalid $(length(ir.stmts.info)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.line) != length(ir.stmts)
@verify_error "IR line length is invalid $(length(ir.stmts.line)) / $(length(ir.stmts))"
error("")
end
if length(ir.stmts.flag) != length(ir.stmts)
@verify_error "IR flag length is invalid $(length(ir.stmts.flag)) / $(length(ir.stmts))"
error("")
end
# For now require compact IR
# @assert isempty(ir.new_nodes)
# Verify CFG
Expand Down Expand Up @@ -125,6 +150,18 @@ function verify_ir(ir::IRCode, print::Bool=true,
error("")
end
end
if !(1 <= first(block.stmts) <= length(ir.stmts))
@verify_error "First statement of BB $idx ($(first(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))"
error("")
end
if !(1 <= last(block.stmts) <= length(ir.stmts))
@verify_error "Last statement of BB $idx ($(last(block.stmts))) out of bounds for IR (length=$(length(ir.stmts)))"
error("")
end
if idx <= length(ir.cfg.index) && last(block.stmts) + 1 != ir.cfg.index[idx]
@verify_error "End of BB $idx ($(last(block.stmts))) is not one less than CFG index ($(ir.cfg.index[idx]))"
error("")
end
end
# Verify statements
domtree = construct_domtree(ir.cfg.blocks)
Expand All @@ -145,7 +182,7 @@ function verify_ir(ir::IRCode, print::Bool=true,
end
elseif isa(terminator, GotoNode)
if length(block.succs) != 1 || block.succs[1] != terminator.label
@verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator"
@verify_error "Block $idx successors ($(block.succs)), does not match GotoNode terminator ($(terminator.label))"
error("")
end
elseif isa(terminator, GotoIfNot)
Expand All @@ -167,8 +204,8 @@ function verify_ir(ir::IRCode, print::Bool=true,
if length(block.succs) != 1 || block.succs[1] != idx + 1
# As a special case, we allow extra statements in the BB of an :enter
# statement, until we can do proper CFG manipulations during compaction.
for idx in first(block.stmts):last(block.stmts)
stmt = ir[SSAValue(idx)][:stmt]
for stmt_idx in first(block.stmts):last(block.stmts)
stmt = ir[SSAValue(stmt_idx)][:stmt]
if isexpr(stmt, :enter)
terminator = stmt
@goto enter_check
Expand All @@ -188,6 +225,10 @@ function verify_ir(ir::IRCode, print::Bool=true,
end
end
end
if length(ir.stmts) != last(ir.cfg.blocks[end].stmts)
@verify_error "End of last BB $(last(ir.cfg.blocks[end].stmts)) does not match last IR statement $(length(ir.stmts))"
error("")
end
lastbb = 0
is_phinode_block = false
firstidx = 1
Expand Down
31 changes: 21 additions & 10 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,9 @@ function finish(me::InferenceState, interp::AbstractInterpreter)
# annotate fulltree with type information,
# either because we are the outermost code, or we might use this later
doopt = (me.cached || me.parent !== nothing)
recompute_cfg = type_annotate!(interp, me, doopt)
type_annotate!(interp, me, doopt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can eliminate anything related to any_unreachable within type_annotate! now.

if doopt && may_optimize(interp)
me.result.src = OptimizationState(me, interp, recompute_cfg)
me.result.src = OptimizationState(me, interp)
else
me.result.src = me.src::CodeInfo # stash a convenience copy of the code (e.g. for reflection)
end
Expand Down Expand Up @@ -713,20 +713,31 @@ function type_annotate!(interp::AbstractInterpreter, sv::InferenceState, run_opt
# 3. mark unreached statements for a bulk code deletion (see issue #7836)
# 4. widen slot wrappers (`Conditional` and `MustAlias`) and remove `NOT_FOUND` from `ssavaluetypes`
# NOTE because of this, `was_reached` will no longer be available after this point
# 5. eliminate GotoIfNot if either branch target is unreachable
# 5. eliminate GotoIfNot if either or both branches are statically unreachable
changemap = nothing # initialized if there is any dead region
for i = 1:nstmt
expr = stmts[i]
if was_reached(sv, i)
if run_optimizer
if isa(expr, GotoIfNot) && widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool
if isa(expr, GotoIfNot)
# 5: replace this live GotoIfNot with:
# - GotoNode if the fallthrough target is unreachable
# - no-op if the branch target is unreachable
if !was_reached(sv, i+1)
expr = GotoNode(expr.dest)
elseif !was_reached(sv, expr.dest)
expr = nothing
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
if widenconst(argextype(expr.cond, src, sv.sptypes)) === Bool
block = block_for_inst(sv.cfg, i)
if !was_reached(sv, i+1)
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif !was_reached(sv, expr.dest)
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
end
elseif ssavaluetypes[i] === Bottom
block = block_for_inst(sv.cfg, i)
cfg_delete_edge!(sv.cfg, block, block + 1)
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
end
end
end
Expand Down
3 changes: 3 additions & 0 deletions test/compiler/interpreter_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down Expand Up @@ -61,6 +62,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down Expand Up @@ -98,6 +100,7 @@ let m = Meta.@lower 1 + 1
]
nstmts = length(src.code)
src.ssavaluetypes = Any[ Any for _ = 1:nstmts ]
src.ssaflags = fill(UInt8(0x00), nstmts)
src.codelocs = fill(Int32(1), nstmts)
src.inferred = true
Core.Compiler.verify_ir(Core.Compiler.inflate_ir(src))
Expand Down
1 change: 1 addition & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ function each_stmt_a_bb(stmts, preds, succs)
empty!(ir.stmts.line); append!(ir.stmts.line, [Int32(0) for _ = 1:length(stmts)])
empty!(ir.stmts.info); append!(ir.stmts.info, [NoCallInfo() for _ = 1:length(stmts)])
empty!(ir.cfg.blocks); append!(ir.cfg.blocks, [BasicBlock(StmtRange(i, i), preds[i], succs[i]) for i = 1:length(stmts)])
empty!(ir.cfg.index); append!(ir.cfg.index, [i for i = 2:length(stmts)])
Core.Compiler.verify_ir(ir)
return ir
end
Expand Down
Loading