Skip to content

Commit

Permalink
tidy up compiler implementation (#48930)
Browse files Browse the repository at this point in the history
- remove `update_valid_age!(edge::InferenceState, sv::InferenceState)`
  and replace all the usages with `update_valid_age!(sv, edge.valid_worlds)`:
  this will simplify the incoming `AbsIntState` interface (see #48913)
- remove `Effects(sv::InferenceState)` utility: replace all the usages
  with `sv.ipo_effects`, which is more explictly saying that we are
  looking at IPO-valid effects
- normalize more `li::MethodInstance` to `mi::MethodInstance`
- import `Core.MethodTable`
- fix up `setindex!` return values
  • Loading branch information
aviatesk authored Mar 8, 2023
1 parent eb4b1a7 commit a2912e2
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 56 deletions.
14 changes: 7 additions & 7 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function should_infer_this_call(interp::AbstractInterpreter, sv::InferenceState)
end

function should_infer_for_effects(sv::InferenceState)
effects = Effects(sv)
effects = sv.ipo_effects
return is_terminates(effects) && is_effect_free(effects)
end

Expand Down Expand Up @@ -255,7 +255,7 @@ struct MethodMatches
applicable::Vector{Any}
info::MethodMatchInfo
valid_worlds::WorldRange
mt::Core.MethodTable
mt::MethodTable
fullmatch::Bool
nonoverlayed::Bool
end
Expand All @@ -267,7 +267,7 @@ struct UnionSplitMethodMatches
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{Core.MethodTable}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
nonoverlayed::Bool
end
Expand All @@ -282,15 +282,15 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = Core.MethodTable[]
mts = MethodTable[]
fullmatches = Bool[]
nonoverlayed = true
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::Core.MethodTable
mt = mt::MethodTable
result = findall(sig_n, method_table; limit = max_methods)
if result === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
Expand Down Expand Up @@ -329,7 +329,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::Core.MethodTable
mt = mt::MethodTable
result = findall(atype, method_table; limit = max_methods)
if result === nothing
# this means too many methods matched
Expand Down Expand Up @@ -3081,7 +3081,7 @@ function typeinf_nocycle(interp::AbstractInterpreter, frame::InferenceState)
typeinf_local(interp, caller)
no_active_ips_in_callers = false
end
caller.valid_worlds = intersect(caller.valid_worlds, frame.valid_worlds)
update_valid_age!(caller, frame.valid_worlds)
end
end
return true
Expand Down
12 changes: 7 additions & 5 deletions base/compiler/cicache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ Internally, each `MethodInstance` keep a unique global cache of code instances
that have been created for the given method instance, stratified by world age
ranges. This struct abstracts over access to this cache.
"""
struct InternalCodeCache
end
struct InternalCodeCache end

function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
return cache
end

const GLOBAL_CI_CACHE = InternalCodeCache()
Expand Down Expand Up @@ -49,11 +49,11 @@ WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)

function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance} !== nothing
return ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds)) !== nothing
end

function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))::Union{Nothing, CodeInstance}
r = ccall(:jl_rettype_inferred, Any, (Any, UInt, UInt), mi, first(wvc.worlds), last(wvc.worlds))
if r === nothing
return default
end
Expand All @@ -66,5 +66,7 @@ function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
return r::CodeInstance
end

setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance) =
function setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance)
setindex!(wvc.cache, ci, mi)
return wvc
end
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
MethodInstance, CodeInstance, MethodMatch, PartialOpaque,
MethodInstance, CodeInstance, MethodTable, MethodMatch, PartialOpaque,
TypeofVararg

const getproperty = Core.getfield
Expand Down
24 changes: 10 additions & 14 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,6 @@ end
is_inferred(sv::InferenceState) = is_inferred(sv.result)
is_inferred(result::InferenceResult) = result.result !== nothing

Effects(state::InferenceState) = state.ipo_effects

function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::Effects)
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
end
Expand Down Expand Up @@ -505,14 +503,12 @@ end
_topmod(sv::InferenceState) = _topmod(sv.mod)

# work towards converging the valid age range for sv
function update_valid_age!(sv::InferenceState, worlds::WorldRange)
sv.valid_worlds = intersect(worlds, sv.valid_worlds)
@assert(sv.world in sv.valid_worlds, "invalid age range update")
nothing
function update_valid_age!(sv::InferenceState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert(sv.world in valid_worlds, "invalid age range update")
return valid_worlds
end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.ssavaluetypes
old = ssavaluetypes[ssa_id]
Expand All @@ -538,32 +534,32 @@ function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize
end

function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
update_valid_age!(frame, caller)
update_valid_age!(caller, frame.valid_worlds)
backedge = (caller, currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(caller, frame.linfo)
return frame
end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, li::MethodInstance)
function add_backedge!(caller::InferenceState, mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, li)
push!(edges, mi)
end
return nothing
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, invokesig, li)
push!(edges, invokesig, mi)
end
return nothing
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
edges = get_stmt_edges!(caller)
if edges !== nothing
push!(edges, mt, typ)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ external table, e.g., to override existing method.
"""
struct OverlayMethodTable <: MethodTableView
world::UInt
mt::Core.MethodTable
mt::MethodTable
end

struct MethodMatchKey
Expand Down Expand Up @@ -98,7 +98,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
!isempty(result))
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt, limit::Int)
_min_val = RefValue{UInt}(typemin(UInt))
_max_val = RefValue{UInt}(typemax(UInt))
_ambig = RefValue{Int32}(0)
Expand Down Expand Up @@ -155,7 +155,7 @@ function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,MethodTable}, world::UInt)
min_valid = RefValue{UInt}(typemin(UInt))
max_valid = RefValue{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, UInt, Ptr{Csize_t}, Ptr{Csize_t}),
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -852,15 +852,15 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
if isa(result, InferenceResult)
src = result.src
if is_foldable_nothrow(result.ipo_effects)
effects = result.ipo_effects
if is_foldable_nothrow(effects)
res = result.result
if isa(res, Const) && is_inlineable_constant(res.val)
# use constant calling convention
add_inlining_backedge!(et, mi)
return ConstantCase(quoted(res.val))
end
end
effects = result.ipo_effects
else
cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ end

function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue)
@assert idx.id < compact.result_idx
(compact.result[idx.id][:inst] === v) && return
(compact.result[idx.id][:inst] === v) && return compact
# Kill count for current uses
kill_current_uses!(compact, compact.result[idx.id][:inst])
compact.result[idx.id][:inst] = v
Expand All @@ -949,7 +949,7 @@ function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::OldSSAVal
id = idx.id
if id < compact.idx
new_idx = compact.ssa_rename[id]
(compact.result[new_idx][:inst] === v) && return
(compact.result[new_idx][:inst] === v) && return compact
kill_current_uses!(compact, compact.result[new_idx][:inst])
compact.result[new_idx][:inst] = v
count_added_node!(compact, v) && push!(compact.late_fixup, new_idx)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2605,7 +2605,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
types = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
end
mt = ccall(:jl_method_table_for, Any, (Any,), types)
if !isa(mt, Core.MethodTable)
if !isa(mt, MethodTable)
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
end
match, valid_worlds, overlayed = findsup(types, method_table(interp))
Expand Down
28 changes: 14 additions & 14 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function cycle_fix_limited(@nospecialize(typ), sv::InferenceState)
end

function adjust_effects(sv::InferenceState)
ipo_effects = Effects(sv)
ipo_effects = sv.ipo_effects

# refine :consistent-cy effect using the return type information
# TODO this adjustment tries to compromise imprecise :consistent-cy information,
Expand Down Expand Up @@ -577,7 +577,7 @@ function store_backedges(frame::MethodInstance, edges::Vector{Any})
if isa(caller, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame)
else
typeassert(caller, Core.MethodTable)
typeassert(caller, MethodTable)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), caller, sig, frame)
end
end
Expand Down Expand Up @@ -792,28 +792,28 @@ function merge_call_chain!(interp::AbstractInterpreter, parent::InferenceState,
end
end

function is_same_frame(interp::AbstractInterpreter, linfo::MethodInstance, frame::InferenceState)
return linfo === frame.linfo
function is_same_frame(interp::AbstractInterpreter, mi::MethodInstance, frame::InferenceState)
return mi === frame.linfo
end

function poison_callstack(infstate::InferenceState, topmost::InferenceState)
push!(infstate.pclimitations, topmost)
nothing
end

# Walk through `linfo`'s upstream call chain, starting at `parent`. If a parent
# frame matching `linfo` is encountered, then there is a cycle in the call graph
# (i.e. `linfo` is a descendant callee of itself). Upon encountering this cycle,
# Walk through `mi`'s upstream call chain, starting at `parent`. If a parent
# frame matching `mi` is encountered, then there is a cycle in the call graph
# (i.e. `mi` is a descendant callee of itself). Upon encountering this cycle,
# we "resolve" it by merging the call chain, which entails unioning each intermediary
# frame's `callers_in_cycle` field and adding the appropriate backedges. Finally,
# we return `linfo`'s pre-existing frame. If no cycles are found, `nothing` is
# we return `mi`'s pre-existing frame. If no cycles are found, `nothing` is
# returned instead.
function resolve_call_cycle!(interp::AbstractInterpreter, linfo::MethodInstance, parent::InferenceState)
function resolve_call_cycle!(interp::AbstractInterpreter, mi::MethodInstance, parent::InferenceState)
frame = parent
uncached = false
while isa(frame, InferenceState)
uncached |= !frame.cached # ensure we never add an uncached frame to a cycle
if is_same_frame(interp, linfo, frame)
if is_same_frame(interp, mi, frame)
if uncached
# our attempt to speculate into a constant call lead to an undesired self-cycle
# that cannot be converged: poison our call-stack (up to the discovered duplicate frame)
Expand All @@ -825,7 +825,7 @@ function resolve_call_cycle!(interp::AbstractInterpreter, linfo::MethodInstance,
return frame
end
for caller in frame.callers_in_cycle
if is_same_frame(interp, linfo, caller)
if is_same_frame(interp, mi, caller)
if uncached
poison_callstack(parent, frame)
return true
Expand Down Expand Up @@ -916,16 +916,16 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
frame.parent = caller
end
typeinf(interp, frame)
update_valid_age!(frame, caller)
update_valid_age!(caller, frame.valid_worlds)
edge = is_inferred(frame) ? mi : nothing
return EdgeCallResult(frame.bestguess, edge, Effects(frame)) # effects are adjusted already within `finish`
return EdgeCallResult(frame.bestguess, edge, frame.ipo_effects) # effects are adjusted already within `finish`
elseif frame === true
# unresolvable cycle
return EdgeCallResult(Any, nothing, Effects())
end
# return the current knowledge about this cycle
frame = frame::InferenceState
update_valid_age!(frame, caller)
update_valid_age!(caller, frame.valid_worlds)
return EdgeCallResult(frame.bestguess, nothing, adjust_effects(frame))
end

Expand Down
14 changes: 7 additions & 7 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ Return an iterator over a list of backedges. Iteration returns `(sig, caller)` e
which will be one of the following:
- `BackedgePair(nothing, caller::MethodInstance)`: a call made by ordinary inferable dispatch
- `BackedgePair(invokesig, caller::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `BackedgePair(specsig, mt::MethodTable)`: an abstract call
- `BackedgePair(invokesig::Type, caller::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `BackedgePair(specsig::Type, mt::MethodTable)`: an abstract call
# Examples
Expand Down Expand Up @@ -305,17 +305,17 @@ const empty_backedge_iter = BackedgeIterator(Any[])

struct BackedgePair
sig # ::Union{Nothing,Type}
caller::Union{MethodInstance,Core.MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,Core.MethodTable}) = new(sig, caller)
caller::Union{MethodInstance,MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller)
end

function iterate(iter::BackedgeIterator, i::Int=1)
backedges = iter.backedges
i > length(backedges) && return nothing
item = backedges[i]
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
isa(item, Core.MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
isa(item, MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
end

#########
Expand Down

0 comments on commit a2912e2

Please sign in to comment.