Skip to content

Commit

Permalink
very wip: inference: allow semi-concrete interpret to perform recursi…
Browse files Browse the repository at this point in the history
…ve inference

fix #48679
  • Loading branch information
aviatesk committed Mar 6, 2023
1 parent 7eb9615 commit db75e9a
Show file tree
Hide file tree
Showing 12 changed files with 254 additions and 169 deletions.
190 changes: 99 additions & 91 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

21 changes: 20 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
MethodInstance, CodeInstance, MethodMatch, PartialOpaque,
MethodInstance, CodeInstance, MethodTable, MethodMatch, PartialOpaque,
TypeofVararg

const getproperty = Core.getfield
Expand Down Expand Up @@ -154,6 +154,25 @@ include("compiler/ssair/ir.jl")
include("compiler/abstractlattice.jl")

include("compiler/inferenceresult.jl")

# TODO define the interface for this abstract type
abstract type AbsIntState end
function frame_instance end
function frame_module(sv::AbsIntState)
mi = frame_instance(sv)
def = mi.def
isa(def, Module) && return def
return def.module
end
function frame_parent end
function frame_cached end
function frame_src end
function callers_in_cycle end
# function recur_state end
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
# limitations(sv::AbsIntState) = recur_state(sv).limitations
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle

include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
Expand Down
56 changes: 34 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
end
end

mutable struct InferenceState
struct AbsIntRecursionState
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
callers_in_cycle::Vector{AbsIntState}
end
function AbsIntRecursionState()
return AbsIntRecursionState(IdSet{AbsIntState}(),
IdSet{AbsIntState}(),
Vector{AbsIntState}())
end

mutable struct InferenceState <: AbsIntState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
Expand Down Expand Up @@ -197,25 +208,26 @@ mutable struct InferenceState
end
end

frame_instance(sv::InferenceState) = sv.linfo
frame_parent(sv::InferenceState) = sv.parent
frame_cached(sv::InferenceState) = sv.cached
frame_src(sv::InferenceState) = sv.src
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
Effects(state::InferenceState) = state.ipo_effects

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
end
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)

add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return

struct InferenceLoopState
sig
Expand All @@ -226,13 +238,13 @@ struct InferenceLoopState
end
end

function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
end
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any
end

Expand Down Expand Up @@ -351,21 +363,21 @@ end
children before their parents (i.e. ascending the tree from the given
InferenceState). Note that cycles may be visited in any order.
"""
struct InfStackUnwind
inf::InferenceState
struct InfStackUnwind{SV<:AbsIntState}
inf::SV
end
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
if cyclei < length(callers_in_cycle(infstate))
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
infstate = callers_in_cycle(infstate)[cyclei]
else
cyclei = 0
infstate = infstate.parent
infstate = frame_parent(infstate)
end
infstate === nothing && return nothing
(infstate::InferenceState, (infstate, cyclei))
(infstate, (infstate, cyclei))
end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
Expand Down Expand Up @@ -504,7 +516,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
return sptypes
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
_topmod(sv::InferenceState) = _topmod(frame_module(sv))

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
Expand Down Expand Up @@ -548,10 +560,10 @@ function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, curr
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, li::MethodInstance)
function add_backedge!(caller::InferenceState, mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
push!(edges, mi)
end
return nothing
end
Expand All @@ -565,7 +577,7 @@ function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::T
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
mt::MethodTable
end

struct MethodMatchKey
Expand Down Expand Up @@ -98,7 +98,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
!isempty(result))
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
Expand Down Expand Up @@ -155,7 +155,7 @@ function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
Expand Down
18 changes: 9 additions & 9 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ struct InliningState{Interp<:AbstractInterpreter}
world::UInt
interp::Interp
end
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
return InliningState(params, et, frame.world, interp)
function InliningState(sv::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
return InliningState(params, et, sv.world, interp)
end
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
return InliningState(params, nothing, get_world_counter(interp), interp)
Expand All @@ -151,12 +151,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
cfg::Union{Nothing,CFG}
insert_coverage::Bool
end
function OptimizationState(frame::InferenceState, params::OptimizationParams,
function OptimizationState(sv::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
inlining = InliningState(frame, params, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
inlining = InliningState(sv, params, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
Expand Down Expand Up @@ -387,9 +387,9 @@ function argextype(
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]


"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
Expand Down
59 changes: 46 additions & 13 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,17 @@ function getindex(tpdum::TwoPhaseDefUseMap, idx::Int)
return TwoPhaseVectorView(tpdum.data, nelems, range)
end

struct IRInterpretationState
# TODO add `result::InferenceResult` & `parent::InferenceState` for this
struct IRInterpretationState <: AbsIntState
ir::IRCode
mi::MethodInstance
world::UInt
argtypes_refined::Vector{Bool}
sptypes::Vector{VarState}
tpdum::TwoPhaseDefUseMap
ssa_refined::BitSet
lazydomtree::LazyDomtree
callers_in_cycle::Vector{InferenceState}
function IRInterpretationState(interp::AbstractInterpreter,
ir::IRCode, mi::MethodInstance, world::UInt, argtypes::Vector{Any})
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, mi)
Expand All @@ -114,10 +117,40 @@ struct IRInterpretationState
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
ssa_refined = BitSet()
lazydomtree = LazyDomtree(ir)
return new(ir, mi, world, argtypes_refined, tpdum, ssa_refined, lazydomtree)
callers_in_cycle = Vector{InferenceState}()
return new(ir, mi, world, argtypes_refined, ir.sptypes, tpdum, ssa_refined, lazydomtree, callers_in_cycle)
end
end

frame_instance(sv::IRInterpretationState) = sv.mi
frame_parent(sv::IRInterpretationState) = nothing
frame_cached(sv::IRInterpretationState) = false
frame_src(sv::IRInterpretationState) = retrieve_code_info(sv.mi) # TODO optimize
callers_in_cycle(sv::IRInterpretationState) = sv.callers_in_cycle
# TODO
merge_effects!(::AbstractInterpreter, ::IRInterpretationState, ::Effects) = return
get_max_methods(::IRInterpretationState, ::AbstractInterpreter) = 3
get_max_methods(@nospecialize(f), ::IRInterpretationState, ::AbstractInterpreter) = 3
ssa_def_slot(@nospecialize(arg), ::IRInterpretationState) = nothing
function bail_out_toplevel_call(::AbstractInterpreter, ::InferenceLoopState, ::IRInterpretationState)
return false
end
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::IRInterpretationState)
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, @nospecialize(rt), ::IRInterpretationState)
return rt === Any
end
should_infer_this_call(::AbstractInterpreter, ::IRInterpretationState) = true
const_prop_enabled(::AbstractInterpreter, ::IRInterpretationState, match::MethodMatch) = false

# TODO
update_valid_age!(::IRInterpretationState, ::WorldRange) = return
update_valid_age!(::InferenceState, ::IRInterpretationState) = return
add_backedge!(::IRInterpretationState, ::MethodInstance) = return
add_invoke_backedge!(::IRInterpretationState, @nospecialize(invokesig::Type), ::MethodInstance) = return
add_mt_backedge!(::IRInterpretationState, ::MethodTable, @nospecialize(typ)) = return

function codeinst_to_ir(interp::AbstractInterpreter, code::CodeInstance)
src = @atomic :monotonic code.inferred
mi = code.def
Expand All @@ -129,13 +162,13 @@ function codeinst_to_ir(interp::AbstractInterpreter, code::CodeInstance)
return inflate_ir(src, mi)
end

function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::IRCode, max_methods::Int)
return CallMeta(Any, Effects(), NoCallInfo())
function from_interconditional(::AbstractLattice,
typ, ::IRInterpretationState, ::ArgInfo, maybecondinfo)
@nospecialize typ maybecondinfo
return widenconditional(typ)
end

function collect_limitations!(@nospecialize(typ), ::IRCode)
function collect_limitations!(@nospecialize(typ), ::IRInterpretationState)
@assert !isa(typ, LimitedAccuracy) "semi-concrete eval on recursive call graph"
return typ
end
Expand All @@ -147,7 +180,7 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
if code === nothing
return Pair{Any, Bool}(nothing, false)
end
argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv.ir)
argtypes = collect_argtypes(interp, inst.args[2:end], nothing, irsv)
argtypes === nothing && return Pair{Any, Bool}(Union{}, false)
effects = decode_effects(code.ipo_purity_bits)
if is_foldable(effects) && is_all_const_arg(argtypes, #=start=#1)
Expand All @@ -169,8 +202,10 @@ function concrete_eval_invoke(interp::AbstractInterpreter,
return Pair{Any, Bool}(nothing, is_nothrow(effects))
end

abstract_eval_ssavalue(s::SSAValue, sv::IRInterpretationState) = abstract_eval_ssavalue(s, sv.ir)

function abstract_eval_phi_stmt(interp::AbstractInterpreter, phi::PhiNode, ::Int, irsv::IRInterpretationState)
return abstract_eval_phi(interp, phi, nothing, irsv.ir)
return abstract_eval_phi(interp, phi, nothing, irsv)
end

function propagate_control_effects!(interp::AbstractInterpreter, idx::Int, stmt::GotoIfNot,
Expand Down Expand Up @@ -237,7 +272,7 @@ function reprocess_instruction!(interp::AbstractInterpreter,
if isa(inst, Expr)
head = inst.head
if head === :call || head === :foreigncall || head === :new || head === :splatnew
(; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, ir, irsv.mi)
(; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, irsv)
# All other effects already guaranteed effect free by construction
if is_nothrow(effects)
ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW
Expand All @@ -261,7 +296,6 @@ function reprocess_instruction!(interp::AbstractInterpreter,
head === :gc_preserve_end
return false
else
ccall(:jl_, Cvoid, (Any,), inst)
error("reprocess_instruction!: unhandled expression found")
end
elseif isa(inst, PhiNode)
Expand All @@ -276,8 +310,7 @@ function reprocess_instruction!(interp::AbstractInterpreter,
elseif isa(inst, GlobalRef)
# GlobalRef is not refinable
else
ccall(:jl_, Cvoid, (Any,), inst)
error()
error("reprocess_instruction!: unhandled instruction found")
end
if rt !== nothing && !(optimizer_lattice(interp), typ, rt)
ir.stmts[idx][:type] = rt
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1945,7 +1945,7 @@ function array_elmtype(@nospecialize ary)
return Any
end

@nospecs function _opaque_closure_tfunc(𝕃::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance)
@nospecs function opaque_closure_tfunc(𝕃::AbstractLattice, arg, lb, ub, source, env::Vector{Any}, linfo::MethodInstance)
argt, argt_exact = instanceof_tfunc(arg)
lbt, lb_exact = instanceof_tfunc(lb)
if !lb_exact
Expand Down Expand Up @@ -2307,7 +2307,7 @@ function builtin_nothrow(𝕃::AbstractLattice, @nospecialize(f), argtypes::Vect
end

function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
sv::Union{InferenceState,IRCode,Nothing})
sv::Union{AbsIntState, Nothing})
𝕃ᵢ = typeinf_lattice(interp)
if f === tuple
return tuple_tfunc(𝕃ᵢ, argtypes)
Expand Down Expand Up @@ -2478,7 +2478,7 @@ end
# TODO: this function is a very buggy and poor model of the return_type function
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
# while this assumes that it is an absolutely precise and accurate and exact model of both
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::Union{InferenceState, IRCode})
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
if length(argtypes) == 3
tt = widenslotwrapper(argtypes[3])
if isa(tt, Const) || (isType(tt) && !has_free_typevars(tt))
Expand Down Expand Up @@ -2605,7 +2605,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
types = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
end
mt = ccall(:jl_method_table_for, Any, (Any,), types)
if !isa(mt, Core.MethodTable)
if !isa(mt, MethodTable)
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
end
match, valid_worlds, overlayed = findsup(types, method_table(interp))
Expand Down
Loading

0 comments on commit db75e9a

Please sign in to comment.