Skip to content

Commit

Permalink
inference: enable :call inference in irinterp (#49191)
Browse files Browse the repository at this point in the history
* inference: enable `:call` inference in irinterp

Built on top of #48913, this commit enables `:call` inference in irinterp.
In a case when some regression is detected, we can simply revert this
commit rather than reverting the whole refactoring from #48913.

* fix irinterp lattice

Now `LimitedAccuracy` can appear in irinterp, so we should include
`InferenceLattice` for `[typeinf|ipo]_lattice` for irinterp.
  • Loading branch information
aviatesk authored May 4, 2023
1 parent 5032a1a commit c0e12cd
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 37 deletions.
13 changes: 3 additions & 10 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# TODO (#48913) remove this overload to enable interprocedural call inference from irinterp
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
sv::IRInterpretationState, max_methods::Int)
return CallMeta(Any, Effects(), NoCallInfo())
end

function collect_limitations!(@nospecialize(typ), ::IRInterpretationState)
@assert !isa(typ, LimitedAccuracy) "irinterp is unable to handle heavy recursion"
return typ
Expand Down Expand Up @@ -147,15 +140,15 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
# Handled at the very end
return false
elseif isa(inst, PiNode)
rt = tmeet(optimizer_lattice(interp), argextype(inst.val, ir), widenconst(inst.typ))
rt = tmeet(typeinf_lattice(interp), argextype(inst.val, ir), widenconst(inst.typ))
elseif inst === nothing
return false
elseif isa(inst, GlobalRef)
# GlobalRef is not refinable
else
error("reprocess_instruction!: unhandled instruction found")
end
if rt !== nothing && !(optimizer_lattice(interp), typ, rt)
if rt !== nothing && !(typeinf_lattice(interp), typ, rt)
ir.stmts[idx][:type] = rt
return true
end
Expand Down Expand Up @@ -323,7 +316,7 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
end
inst = ir.stmts[idx][:inst]::ReturnNode
rt = argextype(inst.val, ir)
ultimate_rt = tmerge(optimizer_lattice(interp), ultimate_rt, rt)
ultimate_rt = tmerge(typeinf_lattice(interp), ultimate_rt, rt)
end
end

Expand Down
8 changes: 6 additions & 2 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,12 @@ typeinf_lattice(::AbstractInterpreter) = InferenceLattice(BaseInferenceLattice.i
ipo_lattice(::AbstractInterpreter) = InferenceLattice(IPOResultLattice.instance)
optimizer_lattice(::AbstractInterpreter) = OptimizerLattice(SimpleInferenceLattice.instance)

typeinf_lattice(interp::NativeInterpreter) = interp.irinterp ? optimizer_lattice(interp) : InferenceLattice(BaseInferenceLattice.instance)
ipo_lattice(interp::NativeInterpreter) = interp.irinterp ? optimizer_lattice(interp) : InferenceLattice(IPOResultLattice.instance)
typeinf_lattice(interp::NativeInterpreter) = interp.irinterp ?
OptimizerLattice(InferenceLattice(SimpleInferenceLattice.instance)) :
InferenceLattice(BaseInferenceLattice.instance)
ipo_lattice(interp::NativeInterpreter) = interp.irinterp ?
InferenceLattice(SimpleInferenceLattice.instance) :
InferenceLattice(IPOResultLattice.instance)
optimizer_lattice(interp::NativeInterpreter) = OptimizerLattice(SimpleInferenceLattice.instance)

"""
Expand Down
49 changes: 24 additions & 25 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4784,31 +4784,30 @@ fhasmethod(::Integer, ::Int32) = 3
@test only(Base.return_types(()) do; Val(hasmethod(sin, Tuple{Int, Vararg{Int}})); end) == Val{false}
@test only(Base.return_types(()) do; Val(hasmethod(sin, Tuple{Int, Int, Vararg{Int}})); end) === Val{false}

# TODO (#48913) enable interprocedural call inference from irinterp
# # interprocedural call inference from irinterp
# @noinline Base.@assume_effects :total issue48679_unknown_any(x) = Base.inferencebarrier(x)

# @noinline _issue48679(y::Union{Nothing,T}) where {T} = T::Type
# Base.@constprop :aggressive function issue48679(x, b)
# if b
# x = issue48679_unknown_any(x)
# end
# return _issue48679(x)
# end
# @test Base.return_types((Float64,)) do x
# issue48679(x, false)
# end |> only == Type{Float64}

# Base.@constprop :aggressive @noinline _issue48679_const(b, y::Union{Nothing,T}) where {T} = b ? nothing : T::Type
# Base.@constprop :aggressive function issue48679_const(x, b)
# if b
# x = issue48679_unknown_any(x)
# end
# return _issue48679_const(b, x)
# end
# @test Base.return_types((Float64,)) do x
# issue48679_const(x, false)
# end |> only == Type{Float64}
# interprocedural call inference from irinterp
@noinline Base.@assume_effects :total issue48679_unknown_any(x) = Base.inferencebarrier(x)

@noinline _issue48679(y::Union{Nothing,T}) where {T} = T::Type
Base.@constprop :aggressive function issue48679(x, b)
if b
x = issue48679_unknown_any(x)
end
return _issue48679(x)
end
@test Base.return_types((Float64,)) do x
issue48679(x, false)
end |> only == Type{Float64}

Base.@constprop :aggressive @noinline _issue48679_const(b, y::Union{Nothing,T}) where {T} = b ? nothing : T::Type
Base.@constprop :aggressive function issue48679_const(x, b)
if b
x = issue48679_unknown_any(x)
end
return _issue48679_const(b, x)
end
@test Base.return_types((Float64,)) do x
issue48679_const(x, false)
end |> only == Type{Float64}

# `invoke` call in irinterp
@noinline _irinterp_invoke(x::Any) = :any
Expand Down

0 comments on commit c0e12cd

Please sign in to comment.