Skip to content

Commit

Permalink
irinterp: reuse more inference routines
Browse files Browse the repository at this point in the history
This is required to handle e.g. `invoke` in irinterp.
  • Loading branch information
aviatesk committed Mar 16, 2023
1 parent 9b21e87 commit b780071
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 31 deletions.
26 changes: 13 additions & 13 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{An
return CallMeta(Any, EFFECTS_UNKNOWN, NoCallInfo())
end

function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo, sv::InferenceState)
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo, sv::AbsIntState)
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
Expand Down Expand Up @@ -1917,7 +1917,7 @@ function invoke_rewrite(xs::Vector{Any})
return newxs
end

function abstract_finalizer(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
function abstract_finalizer(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::AbsIntState)
if length(argtypes) == 3
finalizer_argvec = Any[argtypes[2], argtypes[3]]
call = abstract_call(interp, ArgInfo(nothing, finalizer_argvec), StmtInfo(false), sv, 1)
Expand Down Expand Up @@ -2134,10 +2134,10 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
return unwraptv(T)
end

function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::VarTable, sv::InferenceState)
function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
f = abstract_eval_value(interp, e.args[2], vtypes, sv)
# rt = sp_type_rewrap(e.args[3], sv.linfo, true)
at = Any[ sp_type_rewrap(argt, sv.linfo, false) for argt in e.args[4]::SimpleVector ]
at = Any[ sp_type_rewrap(argt, frame_instance(sv), false) for argt in e.args[4]::SimpleVector ]
pushfirst!(at, f)
# this may be the wrong world for the call,
# but some of the result is likely to be valid anyways
Expand All @@ -2146,7 +2146,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V
nothing
end

function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
rt = Any
head = e.head
if head === :static_parameter
Expand Down Expand Up @@ -2188,7 +2188,7 @@ function abstract_eval_value_expr(interp::AbstractInterpreter, e::Expr, vtypes::
return rt
end

function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
if isa(e, QuoteNode)
return Const(e.value)
elseif isa(e, SSAValue)
Expand Down Expand Up @@ -2217,7 +2217,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
return Const(e)
end

function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
if isa(e, Expr)
return abstract_eval_value_expr(interp, e, vtypes, sv)
else
Expand All @@ -2226,7 +2226,7 @@ function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtyp
end
end

function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
n = length(ea)
argtypes = Vector{Any}(undef, n)
@inbounds for i = 1:n
Expand Down Expand Up @@ -2259,8 +2259,8 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::Infere
return RTEffects(rt, effects)
end

function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing},
sv::InferenceState)
function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
sv::AbsIntState)
ea = e.args
argtypes = collect_argtypes(interp, ea, vtypes, sv)
if argtypes === nothing
Expand All @@ -2270,7 +2270,7 @@ function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{
return abstract_call(interp, arginfo, sv)
end

function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing},
function abstract_eval_statement_expr(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
sv::AbsIntState)
effects = EFFECTS_UNKNOWN
ehead = e.head
Expand Down Expand Up @@ -2471,7 +2471,7 @@ function refine_partial_type(@nospecialize t)
return t
end

function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
abstract_eval_value(interp, e.args[1], vtypes, sv)
mi = frame_instance(sv)
t = sp_type_rewrap(e.args[2], mi, true)
Expand Down Expand Up @@ -2499,7 +2499,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
return RTEffects(t, effects)
end

function abstract_eval_phi(interp::AbstractInterpreter, phi::PhiNode, vtypes::Union{VarTable, Nothing}, sv::AbsIntState)
function abstract_eval_phi(interp::AbstractInterpreter, phi::PhiNode, vtypes::Union{VarTable,Nothing}, sv::AbsIntState)
rt = Union{}
for i in 1:length(phi.values)
isassigned(phi.values, i) || continue
Expand Down
23 changes: 8 additions & 15 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ function propagate_control_effects!(interp::AbstractInterpreter, idx::Int, stmt:
return false
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
si = StmtInfo(true) # TODO better job here?
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
irsv.ir.stmts[irsv.curridx[]][:info] = info
return RTEffects(rt, effects)
end

function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union{Int,Nothing},
@nospecialize(inst), @nospecialize(typ), irsv::IRInterpretationState,
extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet})
Expand Down Expand Up @@ -103,21 +110,7 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
rt = nothing
if isa(inst, Expr)
head = inst.head
if head === :call
argtypes = collect_argtypes(interp, inst.args, nothing, irsv)
if argtypes === nothing
rt = Bottom
else
arginfo = ArgInfo(inst.args, argtypes)
si = StmtInfo(true) # TODO better job here?
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
ir.stmts[idx][:flag] |= flags_for_effects(effects)
if is_foldable(effects) && isa(rt, Const) && is_inlineable_constant(rt.val)
ir.stmts[idx][:inst] = quoted(rt.val)
end
ir.stmts[idx][:info] = info
end
elseif head === :foreigncall || head === :new || head === :splatnew
if head === :call || head === :foreigncall || head === :new || head === :splatnew
(; rt, effects) = abstract_eval_statement_expr(interp, inst, nothing, irsv)
ir.stmts[idx][:flag] |= flags_for_effects(effects)
if is_foldable(effects) && isa(rt, Const) && is_inlineable_constant(rt.val)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ end
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(𝕃, PT, T, T))[1]
end
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::InferenceState)
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
nargs = length(argtypes)
if !isempty(argtypes) && isvarargtype(argtypes[nargs])
nargs - 1 <= 6 || return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
Expand Down Expand Up @@ -2537,7 +2537,7 @@ end

# a simplified model of abstract_call_gf_by_type for applicable
function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
sv::InferenceState, max_methods::Int)
sv::AbsIntState, max_methods::Int)
length(argtypes) < 2 && return CallMeta(Union{}, EFFECTS_UNKNOWN, NoCallInfo())
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
Expand Down Expand Up @@ -2583,7 +2583,7 @@ end
add_tfunc(applicable, 1, INT_INF, @nospecs((𝕃::AbstractLattice, f, args...)->Bool), 40)

# a simplified model of abstract_invoke for Core._hasmethod
function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::AbsIntState)
if length(argtypes) == 3 && !isvarargtype(argtypes[3])
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4761,6 +4761,16 @@ end
issue48679_const(x, false)
end |> only == Type{Float64}

# `invoke` call in irinterp
@noinline _irinterp_invoke(x::Any) = :any
@noinline _irinterp_invoke(x::T) where T = T
Base.@constprop :aggressive Base.@assume_effects :foldable function irinterp_invoke(x::T, b) where T
return @invoke _irinterp_invoke(x::(b ? T : Any))
end
@test Base.return_types((Int,)) do x
irinterp_invoke(x, true)
end |> only == Type{Int}

# recursion detection for semi-concrete interpretation
# avoid direct infinite loop via `concrete_eval_invoke`
Base.@assume_effects :foldable function recur_irinterp1(x, y)
Expand Down

0 comments on commit b780071

Please sign in to comment.