Skip to content

Commit

Permalink
wip: inference: allow semi-concrete interpret to perform recursive in…
Browse files Browse the repository at this point in the history
…ference

TODOs that should be addressed before merging:
- [ ] implement proper recursion detection mechanism for `IRInterpretationState`
- [x] add proper invalidation support
- [x] allow constant inference from semi-concrete interpretation
- [x] propagate callinfo and allow double inlining

fix #48679
  • Loading branch information
aviatesk committed Mar 10, 2023
1 parent 162b9e9 commit 965c082
Show file tree
Hide file tree
Showing 10 changed files with 443 additions and 295 deletions.
276 changes: 146 additions & 130 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

32 changes: 32 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,38 @@ include("compiler/ssair/ir.jl")
include("compiler/abstractlattice.jl")

include("compiler/inferenceresult.jl")

# TODO define the interface for this abstract type
abstract type AbsIntState end
function frame_instance end
function frame_module(sv::AbsIntState)
mi = frame_instance(sv)
def = mi.def
isa(def, Module) && return def
return def.module
end
function frame_parent end
function frame_cached end

struct MethodInfo
propagate_inbounds::Bool
method_for_inference_limit_heuristics::Union{Nothing,Method}
end
function MethodInfo(src::CodeInfo)
return MethodInfo(src.propagate_inbounds,
src.method_for_inference_limit_heuristics::Union{Nothing,Method})
end
method_info(sv::AbsIntState) = method_info_impl(sv)::MethodInfo

propagate_inbounds(sv::AbsIntState) = method_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = method_info(sv).method_for_inference_limit_heuristics

function callers_in_cycle end
# function recur_state end
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
# limitations(sv::AbsIntState) = recur_state(sv).limitations
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle

include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
Expand Down
91 changes: 48 additions & 43 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
end
end

mutable struct InferenceState
struct AbsIntRecursionState
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
callers_in_cycle::Vector{AbsIntState}
end
function AbsIntRecursionState()
return AbsIntRecursionState(IdSet{AbsIntState}(),
IdSet{AbsIntState}(),
Vector{AbsIntState}())
end

mutable struct InferenceState <: AbsIntState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
Expand All @@ -87,6 +98,7 @@ mutable struct InferenceState
slottypes::Vector{Any}
src::CodeInfo
cfg::CFG
method_info::MethodInfo

#= intermediate states for local abstract interpretation =#
currbb::Int
Expand Down Expand Up @@ -135,6 +147,7 @@ mutable struct InferenceState
sptypes = sptypes_from_meth_instance(linfo)
code = src.code::Vector{Any}
cfg = compute_basic_blocks(code)
method_info = MethodInfo(src)

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
Expand Down Expand Up @@ -183,7 +196,7 @@ mutable struct InferenceState
cache !== :no && push!(get_inference_cache(interp), result)

return new(
linfo, world, mod, sptypes, slottypes, src, cfg,
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, valid_worlds, bestguess, ipo_effects,
Expand All @@ -195,23 +208,26 @@ end
is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

frame_instance(sv::InferenceState) = sv.linfo
frame_parent(sv::InferenceState) = sv.parent
frame_cached(sv::InferenceState) = sv.cached
method_info_impl(sv::InferenceState) = sv.method_info
frame_world(sv::InferenceState) = sv.world
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
end
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)

add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return

struct InferenceLoopState
sig
Expand All @@ -222,13 +238,13 @@ struct InferenceLoopState
end
end

function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
end
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any
end

Expand Down Expand Up @@ -347,21 +363,21 @@ end
children before their parents (i.e. ascending the tree from the given
InferenceState). Note that cycles may be visited in any order.
"""
struct InfStackUnwind
inf::InferenceState
struct InfStackUnwind{SV<:AbsIntState}
inf::SV
end
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
if cyclei < length(callers_in_cycle(infstate))
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
infstate = callers_in_cycle(infstate)[cyclei]
else
cyclei = 0
infstate = infstate.parent
infstate = frame_parent(infstate)
end
infstate === nothing && return nothing
(infstate::InferenceState, (infstate, cyclei))
(infstate, (infstate, cyclei))
end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
Expand Down Expand Up @@ -500,12 +516,12 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
return sptypes
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
_topmod(sv::InferenceState) = _topmod(frame_module(sv))

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert(sv.world in valid_worlds, "invalid age range update")
@assert sv.world in valid_worlds "invalid age range update"
return valid_worlds
end

Expand Down Expand Up @@ -543,42 +559,31 @@ end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mi)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mi)
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, mi)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), invokesig, mi)
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mt, typ)
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
stmt_edges = caller.stmt_edges
edges = stmt_edges[currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
edges = stmt_edges[currpc] = []
end
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
edges = frame.stmt_edges[currpc]
edges === nothing || empty!(edges)
return nothing
Expand Down
18 changes: 9 additions & 9 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ struct InliningState{Interp<:AbstractInterpreter}
world::UInt
interp::Interp
end
function InliningState(frame::InferenceState, interp::AbstractInterpreter)
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
return InliningState(et, frame.world, interp)
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
return InliningState(et, sv.world, interp)
end
function InliningState(interp::AbstractInterpreter)
return InliningState(nothing, get_world_counter(interp), interp)
Expand All @@ -150,12 +150,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
cfg::Union{Nothing,CFG}
insert_coverage::Bool
end
function OptimizationState(frame::InferenceState, interp::AbstractInterpreter,
function OptimizationState(sv::InferenceState, interp::AbstractInterpreter,
recompute_cfg::Bool=true)
inlining = InliningState(frame, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
inlining = InliningState(sv, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, interp::AbstractInterpreter)
# prepare src for running optimization passes if it isn't already
Expand Down Expand Up @@ -385,9 +385,9 @@ function argextype(
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]


"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
ir::IRCode, caller::InferenceResult)
Expand Down
Loading

0 comments on commit 965c082

Please sign in to comment.