Skip to content

Commit

Permalink
very wip: inference: allow semi-concrete interpret to perform recursi…
Browse files Browse the repository at this point in the history
…ve inference

fix #48679
  • Loading branch information
aviatesk committed Mar 7, 2023
1 parent 608e26d commit 9a7aecf
Show file tree
Hide file tree
Showing 14 changed files with 261 additions and 176 deletions.
190 changes: 99 additions & 91 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ Internally, each `MethodInstance` keep a unique global cache of code instances
that have been created for the given method instance, stratified by world age
ranges. This struct abstracts over access to this cache.
"""
struct InternalCodeCache
end
struct InternalCodeCache end

function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
return cache
end

const GLOBAL_CI_CACHE = InternalCodeCache()
Expand Down Expand Up @@ -49,11 +49,11 @@ WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds)) !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance}
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))
if r === nothing
return default
end
Expand All @@ -66,5 +66,7 @@ function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
return r::CodeInstance
end

setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance) =
function setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance)
setindex!(wvc.cache, ci, mi)
return wvc
end
21 changes: 20 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
MethodInstance, CodeInstance, MethodMatch, PartialOpaque,
MethodInstance, CodeInstance, MethodTable, MethodMatch, PartialOpaque,
TypeofVararg

const getproperty = Core.getfield
Expand Down Expand Up @@ -154,6 +154,25 @@ include("compiler/ssair/ir.jl")
include("compiler/abstractlattice.jl")

include("compiler/inferenceresult.jl")

# TODO define the interface for this abstract type
abstract type AbsIntState end
function frame_instance end
function frame_module(sv::AbsIntState)
mi = frame_instance(sv)
def = mi.def
isa(def, Module) && return def
return def.module
end
function frame_parent end
function frame_cached end
function frame_src end
function callers_in_cycle end
# function recur_state end
# pclimitations(sv::AbsIntState) = recur_state(sv).pclimitations
# limitations(sv::AbsIntState) = recur_state(sv).limitations
# callers_in_cycle(sv::AbsIntState) = recur_state(sv).callers_in_cycle

include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
Expand Down
56 changes: 34 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,18 @@ function append!(bsbmp::BitSetBoundedMinPrioritySet, itr)
end
end

mutable struct InferenceState
struct AbsIntRecursionState
pclimitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
limitations::IdSet{AbsIntState} # causes of precision restrictions (LimitedAccuracy) on return
callers_in_cycle::Vector{AbsIntState}
end
function AbsIntRecursionState()
return AbsIntRecursionState(IdSet{AbsIntState}(),
IdSet{AbsIntState}(),
Vector{AbsIntState}())
end

mutable struct InferenceState <: AbsIntState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
Expand Down Expand Up @@ -195,25 +206,26 @@ end
is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

frame_instance(sv::InferenceState) = sv.linfo
frame_parent(sv::InferenceState) = sv.parent
frame_cached(sv::InferenceState) = sv.cached
frame_src(sv::InferenceState) = sv.src
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
Effects(state::InferenceState) = state.ipo_effects

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end

merge_effects!(interp::AbstractInterpreter, caller::InferenceState, callee::InferenceState) =
merge_effects!(interp, caller, Effects(callee))
merge_effects!(interp::AbstractInterpreter, caller::IRCode, effects::Effects) = nothing

is_effect_overridden(sv::InferenceState, effect::Symbol) = is_effect_overridden(sv.linfo, effect)
is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
end
is_effect_overridden(method::Method, effect::Symbol) = is_effect_overridden(decode_effects_override(method.purity), effect)
is_effect_overridden(override::EffectsOverride, effect::Symbol) = getfield(override, effect)

add_remark!(::AbstractInterpreter, sv::Union{InferenceState, IRCode}, remark) = return
add_remark!(::AbstractInterpreter, ::AbsIntState, remark) = return

struct InferenceLoopState
sig
Expand All @@ -224,13 +236,13 @@ struct InferenceLoopState
end
end

function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
return isa(sv, InferenceState) && sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
function bail_out_toplevel_call(::AbstractInterpreter, state::InferenceLoopState, sv::InferenceState)
return sv.restrict_abstract_call_sites && !isdispatchtuple(state.sig)
end
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_call(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any && !is_foldable(state.effects)
end
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, sv::Union{InferenceState, IRCode})
function bail_out_apply(::AbstractInterpreter, state::InferenceLoopState, ::InferenceState)
return state.rt === Any
end

Expand Down Expand Up @@ -349,21 +361,21 @@ end
children before their parents (i.e. ascending the tree from the given
InferenceState). Note that cycles may be visited in any order.
"""
struct InfStackUnwind
inf::InferenceState
struct InfStackUnwind{SV<:AbsIntState}
inf::SV
end
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
function iterate(unw::InfStackUnwind{SV}, (infstate, cyclei)::Tuple{SV, Int}) where SV<:AbsIntState
# iterate through the cycle before walking to the parent
if cyclei < length(infstate.callers_in_cycle)
if cyclei < length(callers_in_cycle(infstate))
cyclei += 1
infstate = infstate.callers_in_cycle[cyclei]
infstate = callers_in_cycle(infstate)[cyclei]
else
cyclei = 0
infstate = infstate.parent
infstate = frame_parent(infstate)
end
infstate === nothing && return nothing
(infstate::InferenceState, (infstate, cyclei))
(infstate, (infstate, cyclei))
end

function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
Expand Down Expand Up @@ -502,7 +514,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
return sptypes
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
_topmod(sv::InferenceState) = _topmod(frame_module(sv))

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
Expand Down Expand Up @@ -546,10 +558,10 @@ function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, curr
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, li::MethodInstance)
function add_backedge!(caller::InferenceState, mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
push!(edges, mi)
end
return nothing
end
Expand All @@ -563,7 +575,7 @@ function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::T
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
mt::MethodTable
end

struct MethodMatchKey
Expand Down Expand Up @@ -98,7 +98,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
!isempty(result))
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
Expand Down Expand Up @@ -155,7 +155,7 @@ function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
Expand Down
18 changes: 9 additions & 9 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ struct InliningState{Interp<:AbstractInterpreter}
world::UInt
interp::Interp
end
function InliningState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(frame.stmt_edges[1]::Vector{Any}, frame.valid_worlds)
return InliningState(params, et, frame.world, interp)
function InliningState(sv::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
et = EdgeTracker(sv.stmt_edges[1]::Vector{Any}, sv.valid_worlds)
return InliningState(params, et, sv.world, interp)
end
function InliningState(params::OptimizationParams, interp::AbstractInterpreter)
return InliningState(params, nothing, get_world_counter(interp), interp)
Expand All @@ -151,12 +151,12 @@ mutable struct OptimizationState{Interp<:AbstractInterpreter}
cfg::Union{Nothing,CFG}
insert_coverage::Bool
end
function OptimizationState(frame::InferenceState, params::OptimizationParams,
function OptimizationState(sv::InferenceState, params::OptimizationParams,
interp::AbstractInterpreter, recompute_cfg::Bool=true)
inlining = InliningState(frame, params, interp)
cfg = recompute_cfg ? nothing : frame.cfg
return OptimizationState(frame.linfo, frame.src, nothing, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining, cfg, frame.insert_coverage)
inlining = InliningState(sv, params, interp)
cfg = recompute_cfg ? nothing : sv.cfg
return OptimizationState(sv.linfo, sv.src, nothing, sv.stmt_info, frame_module(sv),
sv.sptypes, sv.slottypes, inlining, cfg, sv.insert_coverage)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams,
interp::AbstractInterpreter)
Expand Down Expand Up @@ -387,9 +387,9 @@ function argextype(
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]


"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ end

function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue)
@assert idx.id < compact.result_idx
(compact.result[idx.id][:inst] === v) && return
(compact.result[idx.id][:inst] === v) && return compact
# Kill count for current uses
kill_current_uses!(compact, compact.result[idx.id][:inst])
compact.result[idx.id][:inst] = v
Expand All @@ -949,7 +949,7 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::OldSSAVal
id = idx.id
if id < compact.idx
new_idx = compact.ssa_rename[id]
(compact.result[new_idx][:inst] === v) && return
(compact.result[new_idx][:inst] === v) && return compact
kill_current_uses!(compact, compact.result[new_idx][:inst])
compact.result[new_idx][:inst] = v
count_added_node!(compact, v) && push!(compact.late_fixup, new_idx)
Expand Down
Loading

0 comments on commit 9a7aecf

Please sign in to comment.