Skip to content

Commit f4e0646

Browse files
committed
wip: report infinite iteration
WIP because I found this is better to be generalized to a control flow analysis that checks for "never reaching exit points". For As an example, `sum(a for a in NeverTeminate(::Int))` will come down to `Base._fold_impl`, which directly uses `iterate(::NeverTeminate, [state])`, against which the current implementation based on the iteration protocol can't report an error. > Adapated from https://github.com/JuliaLang/julia/blob/24d9eab45632bdb3120c9e664503745eb58aa2d6/base/reduce.jl#L53-L65 ```julia function _foldl_impl(op::OP, init, itr) where {OP} # Unroll the while loop once; if init is known, the call to op may # be evaluated at compile time y = iterate(itr) y === nothing && return init v = op(init, y[1]) while true y = iterate(itr, y[2]) y === nothing && break v = op(v, y[1]) end return v end ``` > In a lowered representation ```julia CodeInfo( 1 ─ Core.NewvarNode(:(v)) │ y = Base.iterate(itr) │ %3 = y === Base.nothing └── goto #3 if not %3 2 ─ return init 3 ─ %6 = Base.getindex(y, 1) └── v = (op)(init, %6) 4 ┄ goto #8 if not true 5 ─ %9 = Base.getindex(y, 2) │ y = Base.iterate(itr, %9) │ %11 = y === Base.nothing └── goto #7 if not %11 6 ─ goto #8 7 ─ %14 = v │ %15 = Base.getindex(y, 1) │ v = (op)(%14, %15) └── goto #4 8 ┄ return v ) ``` We can report infinite iteration by checking both `goto #3 if not %3` and `goto #7 if not %11` never happens and thus this function never returns. --- - closes #135 - closes #179
1 parent fdff929 commit f4e0646

File tree

4 files changed

+214
-0
lines changed

4 files changed

+214
-0
lines changed

src/JET.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ import .CC:
4747
abstract_eval_special_value,
4848
abstract_eval_value,
4949
abstract_eval_statement,
50+
typeinf_local,
51+
abstract_iteration,
5052
# typeinfer.jl
5153
typeinf,
5254
_typeinf,
@@ -108,6 +110,7 @@ import .CC:
108110
BasicBlock,
109111
slot_id,
110112
widenconst,
113+
widenconditional,
111114
,
112115
is_throw_call,
113116
tmerge,
@@ -402,6 +405,13 @@ macro jetconfigurable(funcdef)
402405
end
403406
const _JET_CONFIGURATIONS = Dict{Symbol,Symbol}()
404407

408+
macro get!(x, key, default)
409+
return quote
410+
x, key = $(esc(x)), $(esc(key))
411+
haskey(x, key) ? x[key] : (x[key] = $(esc(default)))
412+
end
413+
end
414+
405415
# utils
406416
# -----
407417

src/abstractinterpretation.jl

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,120 @@ function CC.abstract_eval_statement(interp::JETInterpreter, @nospecialize(e), vt
428428
return @invoke abstract_eval_statement(interp::AbstractInterpreter, e, vtypes::VarTable, sv::InferenceState)
429429
end
430430

431+
# TODO generalize the following check to "this function will never return"
432+
433+
# in this overload we will work on pre-optimization state
434+
function CC.typeinf_local(interp::JETInterpreter, frame::InferenceState)
435+
@invoke typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
436+
maybe_report_infinite_iterations!(interp, frame)
437+
end
438+
439+
function maybe_report_infinite_iterations!(interp::JETInterpreter, frame::InferenceState)
440+
src = frame.src
441+
infos = maybe_find_iter_infos(src)
442+
if !isnothing(infos)
443+
ssavaluetypes = src.ssavaluetypes::Vector{Any}
444+
@inbounds for (dest, call2cond) in infos
445+
local may_terminate = false
446+
for (_, cond) in call2cond
447+
condt = ssavaluetypes[cond]
448+
t = widenconditional(condt)
449+
if !(isa(t, Const) && t.val === true)
450+
may_terminate = true
451+
break
452+
end
453+
end
454+
may_terminate && continue
455+
456+
@assert ssavaluetypes[dest] === NOT_FOUND # abstract interpretaion should never reach the destination
457+
t = nothing
458+
pc = nothing
459+
for (call, _) in call2cond
460+
ea = ((src.code[call]::Expr).args[2]::Expr).args # x_ = iterate(t′, [state])
461+
t′ = @invoke abstract_eval_value(interp::AbstractInterpreter, ea[2]::Any, frame.stmt_types[call]::VarTable, frame::InferenceState)
462+
@assert isnothing(t) || t === t′
463+
t = t′
464+
465+
if isnothing(pc)
466+
pc = call
467+
end
468+
end
469+
report!(interp, InfiniteIterationErrorReport(interp, frame, t, pc))
470+
end
471+
end
472+
end
473+
474+
function maybe_find_iter_infos(src::CodeInfo)
475+
bbs = compute_basic_blocks(src.code) # XXX basic block construction can be time-consuming
476+
477+
stmts = nothing
478+
479+
for bb in bbs.blocks
480+
maybeinfo = maybe_find_iteration_info_for_block(src, bb)
481+
if !isnothing(maybeinfo)
482+
cond, (call, dest) = maybeinfo
483+
if isnothing(stmts)
484+
stmts = Dict{Int,Vector{Tuple{Int,Int}}}()
485+
end
486+
push!(@get!(stmts, dest, Tuple{Int,Int}[]), (call, cond))
487+
end
488+
end
489+
490+
return stmts
491+
end
492+
493+
# if this basic block comes from the iteration protocol, return the tuple of
494+
# (stmt # of iteration termination check, tuple of (stmt # of `iterate` call, stmt # of the destination))
495+
function maybe_find_iteration_info_for_block(src::CodeInfo, bb::BasicBlock)
496+
# check for there is a sequence pattern of `iterate` call -> `nothing` check -> `Base.not_int` -> `goto #target if not`
497+
stmts = bb.stmts
498+
length(stmts) 4 || return nothing
499+
500+
@inbounds begin
501+
region = src.code[stmts][end-3:end]
502+
503+
terminator = region[end]
504+
isa(terminator, GotoIfNot) || return nothing
505+
506+
preds = region[end-3:end-1]
507+
is_iterate_stmt(preds[1]) || return nothing
508+
is_nothing_check_stmt(preds[2]) || return nothing
509+
is_notint_stmt(preds[3]) || return nothing
510+
511+
return stmts[end-1], (stmts[end-3], terminator.dest)
512+
end
513+
end
514+
515+
function is_iterate_stmt(@nospecialize(x))
516+
@isexpr(x, :(=)) || return false
517+
lhs = x.args[2]
518+
return @isexpr(lhs, :call) && is_global_ref(lhs.args[1], Base, :iterate)
519+
end
520+
521+
function is_nothing_check_stmt(@nospecialize(x))
522+
@isexpr(x, :call) || return false
523+
length(x.args) 3 || return false
524+
is_global_ref(x.args[1], Core, :(===)) || return false
525+
return x.args[3] === nothing
526+
end
527+
528+
function is_notint_stmt(@nospecialize(x))
529+
@isexpr(x, :call) || return false
530+
return is_global_ref(x.args[1], Base, :not_int)
531+
end
532+
533+
is_global_ref(@nospecialize(x), mod::Module, name::Symbol) = isa(x, GlobalRef) && x.mod === mod && x.name === name
534+
535+
# TODO more detailed error report using the context of abstract iteration
536+
function CC.abstract_iteration(interp::JETInterpreter, @nospecialize(itft), @nospecialize(itertype), sv::InferenceState)
537+
ret = @invoke abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), sv::InferenceState)
538+
rts, info = ret
539+
if length(rts) == 1 && first(rts) === Bottom && isnothing(info)
540+
report!(interp, InfiniteIterationErrorReport(interp, sv, itertype))
541+
end
542+
return ret
543+
end
544+
431545
function CC.finish(me::InferenceState, interp::JETInterpreter)
432546
@invoke finish(me::InferenceState, interp::AbstractInterpreter)
433547

src/reports.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,17 @@ let s = sprint(showerror, DivideError())
508508
global get_msg(::Type{DivideErrorReport}, interp, sv::InferenceState) = s
509509
end
510510

511+
@reportdef struct InfiniteIterationErrorReport <: InferenceErrorReport
512+
@nospecialize(typ)
513+
end
514+
# if provided, use program counter where infinite iteration(s) is found
515+
function InfiniteIterationErrorReport(interp, sv::InferenceState, @nospecialize(typ), pc = nothing)
516+
vst = VirtualFrame[get_virtual_frame(interp, sv, isnothing(pc) ? get_currpc(sv) : pc)]
517+
msg = "iterate(::$typ) won't terminate"
518+
sig = get_sig(interp, sv, isnothing(pc) ? get_stmt(sv) : get_stmt(sv, pc))
519+
return InfiniteIterationErrorReport(vst, msg, sig, typ)
520+
end
521+
511522
# TODO we may want to hoist `InvalidConstXXX` errors into top-level errors
512523

513524
@reportdef struct InvalidConstantRedefinition <: InferenceErrorReport

test/test_abstractinterpretation.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,3 +908,82 @@ end
908908
@test isa(r, GeneratorErrorReport) && r.err == "invalid argument"
909909
end
910910
end
911+
912+
@testset "infinite iteration" begin
913+
m = @fixturedef begin
914+
struct NeverTerminate
915+
val::Int
916+
end
917+
Base.iterate(nv::NeverTerminate, state = 0) =
918+
state>nv.val ? (state, state+1) : (state, state+1)
919+
end
920+
921+
let # iteration protocol
922+
interp, frame = @eval m $analyze_call((Int,)) do n
923+
for a in NeverTerminate(n)
924+
println(a)
925+
end
926+
end
927+
@test any(interp.reports) do r
928+
isa(r, InfiniteIterationErrorReport) &&
929+
r.typ === m.NeverTerminate &&
930+
any(r.vst) do vf
931+
vf.line == (@__LINE__)-8
932+
end
933+
end
934+
end
935+
936+
let # iteration protocol, nested
937+
interp, frame = @eval m $analyze_call((Int,)) do n
938+
for a in 1:n
939+
for b in NeverTerminate(a)
940+
println(b)
941+
end
942+
end
943+
end
944+
@test any(interp.reports) do r
945+
isa(r, InfiniteIterationErrorReport) &&
946+
r.typ === m.NeverTerminate &&
947+
any(r.vst) do vf
948+
vf.line == (@__LINE__)-9
949+
end
950+
end
951+
end
952+
953+
let # iteration protocol, on container type
954+
interp, frame = @eval m $analyze_call((Int,)) do n
955+
sum((a for a in NeverTerminate(n))...)
956+
end
957+
@test any(interp.reports) do r
958+
isa(r, InfiniteIterationErrorReport)
959+
end
960+
end
961+
962+
let # complicated control flow, no false positive
963+
interp, frame = analyze_call((Char,Tuple{Char,Char})) do x, itr
964+
# adapated from https://github.com/JuliaLang/julia/blob/24d9eab45632bdb3120c9e664503745eb58aa2d6/base/operators.jl#L1278-L1297
965+
anymissing = false
966+
for y in itr
967+
v = (y == x)
968+
if ismissing(v)
969+
anymissing = true
970+
elseif v
971+
return true
972+
end
973+
end
974+
return anymissing ? missing : false
975+
end
976+
@test isempty(interp.reports)
977+
end
978+
979+
let # general case of "this function never return"
980+
# NOTE comes down to `Base._foldl_impl`
981+
interp, frame = @eval m $analyze_call((Int,)) do n
982+
sum(a for a in NeverTerminate(n))
983+
end
984+
@test_broken any(interp.reports) do r
985+
isa(r, InfiniteIterationErrorReport) &&
986+
r.typ === m.NeverTerminate
987+
end
988+
end
989+
end

0 commit comments

Comments
 (0)