Skip to content

Commit

Permalink
very wip: inference: allow semi-concrete interpret to perform recursi…
Browse files Browse the repository at this point in the history
…ve inference

fix #48679
  • Loading branch information
aviatesk committed Mar 7, 2023
1 parent 9b9b99f commit 56bea36
Show file tree
Hide file tree
Showing 10 changed files with 281 additions and 196 deletions.
192 changes: 100 additions & 92 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ include("compiler/ssair/ir.jl")
include("compiler/abstractlattice.jl")

include("compiler/inferenceresult.jl")

# TODO define the interface for this abstract type
abstract type AbsIntState end
function frame_instance end
function frame_module(sv::AbsIntState)
mi = frame_instance(sv)
def = mi.def
isa(def, Module) && return def
return def.module
end
function frame_parent end
function frame_cached end
function frame_src end
function callers_in_cycle end
# function recur_state end
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
# limitations(sv::AbsIntState) = recur_state(sv).limitations
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle

include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
Expand Down
86 changes: 44 additions & 42 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
end
end

mutable struct InferenceState
struct AbsIntRecursionState
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
callers_in_cycle::Vector{AbsIntState}
end
function AbsIntRecursionState()
return AbsIntRecursionState(IdSet{AbsIntState}(),
IdSet{AbsIntState}(),
Vector{AbsIntState}())
end

mutable struct InferenceState <: AbsIntState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
Expand Down Expand Up @@ -195,23 +206,25 @@ end
is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

frame_instance(sv::InferenceState) = sv.linfo
frame_parent(sv::InferenceState) = sv.parent
frame_cached(sv::InferenceState) = sv.cached
frame_src(sv::InferenceState) = sv.src
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
end
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)

add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return

struct InferenceLoopState
sig
Expand All @@ -222,13 +235,13 @@ struct InferenceLoopState
end
end

function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
end
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any
end

Expand Down Expand Up @@ -347,21 +360,21 @@ end
children before their parents (i.e. ascending the tree from the given
InferenceState). Note that cycles may be visited in any order.
"""
struct InfStackUnwind
inf::InferenceState
struct InfStackUnwind{SV<:AbsIntState}
inf::SV
end
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
if cyclei < length(callers_in_cycle(infstate))
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
infstate = callers_in_cycle(infstate)[cyclei]
else
cyclei = 0
infstate = infstate.parent
infstate = frame_parent(infstate)
end
infstate === nothing && return nothing
(infstate::InferenceState, (infstate, cyclei))
(infstate, (infstate, cyclei))
end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
Expand Down Expand Up @@ -500,12 +513,12 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
return sptypes
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
_topmod(sv::InferenceState) = _topmod(frame_module(sv))

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert(sv.world in valid_worlds, "invalid age range update")
@assert sv.world in valid_worlds "invalid age range update"
return valid_worlds
end

Expand Down Expand Up @@ -543,42 +556,31 @@ end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mi)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mi)
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, mi)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), invokesig, mi)
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
end
return nothing
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mt, typ)
end

function get_stmt_edges!(caller::InferenceState)
if !isa(caller.linfo.def, Method)
return nothing # don't add backedges to toplevel exprs
end
edges = caller.stmt_edges[caller.currpc]
function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
stmt_edges = caller.stmt_edges
edges = stmt_edges[currpc]
if edges === nothing
edges = caller.stmt_edges[caller.currpc] = []
edges = stmt_edges[currpc] = []
end
return edges
end

function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)
function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
edges = frame.stmt_edges[currpc]
edges === nothing || empty!(edges)
return nothing
Expand Down
18 changes: 9 additions & 9 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ struct InliningState{Interp<:AbstractInterpreter}
world::UInt
interp::Interp
end
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
return InliningState(params, et, frame.world, interp)
function InliningState(sv::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
return InliningState(params, et, sv.world, interp)
end
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
return InliningState(params, nothing, get_world_counter(interp), interp)
Expand All @@ -151,12 +151,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
cfg::Union{Nothing,CFG}
insert_coverage::Bool
end
function OptimizationState(frame::InferenceState, params::OptimizationParams,
function OptimizationState(sv::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
inlining = InliningState(frame, params, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
inlining = InliningState(sv, params, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
Expand Down Expand Up @@ -387,9 +387,9 @@ function argextype(
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]


"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
Expand Down
Loading

0 comments on commit 56bea36

Please sign in to comment.