Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: ~~refine~~ replace recursion non-detection algorithm #23912

Merged
merged 5 commits into from
Oct 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,18 +628,28 @@ function _collect_indices(indsA, A)
copy!(B, CartesianRange(indices(B)), A, CartesianRange(indsA))
end

# define this as a macro so that the call to Inference
# gets inlined into the caller before recursion detection
# gets a chance to see it, so that recursive calls to the caller
# don't trigger the inference limiter
if isdefined(Core, :Inference)
_default_eltype(@nospecialize itrt) = Core.Inference.return_type(first, Tuple{itrt})
macro default_eltype(itrt)
return quote
Core.Inference.return_type(first, Tuple{$(esc(itrt))})
end
end
else
_default_eltype(@nospecialize itr) = Any
macro default_eltype(itrt)
return :(Any)
end
end

_array_for(::Type{T}, itr, ::HasLength) where {T} = Array{T,1}(Int(length(itr)::Integer))
_array_for(::Type{T}, itr, ::HasShape) where {T} = similar(Array{T}, indices(itr))

function collect(itr::Generator)
isz = iteratorsize(itr.iter)
et = _default_eltype(typeof(itr))
et = @default_eltype(typeof(itr))
if isa(isz, SizeUnknown)
return grow_to!(Array{et,1}(0), itr)
else
Expand All @@ -653,12 +663,12 @@ function collect(itr::Generator)
end

_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
grow_to!(_similar_for(c, _default_eltype(typeof(itr)), itr, isz), itr)
grow_to!(_similar_for(c, @default_eltype(typeof(itr)), itr, isz), itr)

function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
st = start(itr)
if done(itr,st)
return _similar_for(c, _default_eltype(typeof(itr)), itr, isz)
return _similar_for(c, @default_eltype(typeof(itr)), itr, isz)
end
v1, st = next(itr, st)
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
Expand Down
4 changes: 2 additions & 2 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ associative_with_eltype(DT_apply, kv, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv
associative_with_eltype(DT_apply, kv::Generator, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv)
associative_with_eltype(DT_apply, ::Type{Pair{K,V}}) where {K,V} = DT_apply(K, V)()
associative_with_eltype(DT_apply, ::Type) = DT_apply(Any, Any)()
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, _default_eltype(typeof(kv))), kv)
associative_with_eltype(DT_apply::F, kv, t) where {F} = grow_to!(associative_with_eltype(DT_apply, @default_eltype(typeof(kv))), kv)
function associative_with_eltype(DT_apply::F, kv::Generator, t) where F
T = _default_eltype(typeof(kv))
T = @default_eltype(typeof(kv))
if T <: Union{Pair, Tuple{Any, Any}} && _isleaftype(T)
return associative_with_eltype(DT_apply, kv, T)
end
Expand Down
282 changes: 174 additions & 108 deletions base/inference.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ Returns the sum of all elements of `A`, using the Kahan-Babuska-Neumaier compens
summation algorithm for additional accuracy.
"""
function sum_kbn(A)
T = _default_eltype(typeof(A))
T = @default_eltype(typeof(A))
c = r_promote(+, zero(T)::T)
i = start(A)
if done(A, i)
Expand Down
4 changes: 2 additions & 2 deletions base/set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ for sets of arbitrary objects.
"""
Set(itr) = Set{eltype(itr)}(itr)
function Set(g::Generator)
T = _default_eltype(typeof(g))
T = @default_eltype(typeof(g))
(_isleaftype(T) || T === Union{}) || return grow_to!(Set{T}(), g)
return Set{T}(g)
end
Expand Down Expand Up @@ -258,7 +258,7 @@ julia> unique(Real[1, 1.0, 2])
```
"""
function unique(itr)
T = _default_eltype(typeof(itr))
T = @default_eltype(typeof(itr))
out = Vector{T}()
seen = Set{T}()
i = start(itr)
Expand Down
43 changes: 20 additions & 23 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,37 +926,34 @@ end
# vectors/matrices in mixedargs in their orginal order, and such that the result of
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
@inline function capturescalars(f, mixedargs)
let makeargs = _capturescalars(mixedargs...),
parevalf = (passed...) -> f(makeargs(passed...)...),
passedsrcargstup = _capturenonscalars(mixedargs...)
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
parevalf = (passed...) -> f(makeargs(passed...)...)
return (parevalf, passedsrcargstup)
end
end

@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
(nonscalararg, _capturenonscalars(mixedargs...)...)
@inline _capturenonscalars(scalararg, mixedargs...) =
_capturenonscalars(mixedargs...)
@inline _capturenonscalars() = ()
nonscalararg(::SparseVecOrMat) = true
nonscalararg(::Any) = false

@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
let f = _capturescalars(mixedargs...)
(head, tail...) -> (head, f(tail...)...) # pass-through
@inline function _capturescalars()
return (), () -> ()
end
@inline function _capturescalars(arg, mixedargs...)
let (rest, f) = _capturescalars(mixedargs...)
if nonscalararg(arg)
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
else
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
end
end
@inline _capturescalars(scalararg, mixedargs...) =
let f = _capturescalars(mixedargs...)
(tail...) -> (scalararg, f(tail...)...) # add scalararg
end
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
if nonscalararg(arg)
return (arg,), (head,) -> (head,) # pass-through
else
return (), () -> (arg,) # add scalararg
end
# TODO: use the implicit version once inference can handle it
# handle too-many-arguments explicitly
@inline function _capturescalars()
too_many_arguments() = ()
too_many_arguments(tail...) = throw(ArgumentError("too many"))
end
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
# (head,) -> (head,) # pass-through
#@inline _capturescalars(scalararg) =
# () -> (scalararg,) # add scalararg

# NOTE: The following two method definitions work around #19096.
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)
Expand Down
25 changes: 18 additions & 7 deletions src/rtutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,12 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
else if (vt == jl_method_instance_type) {
jl_method_instance_t *li = (jl_method_instance_t*)v;
if (jl_is_method(li->def.method)) {
jl_method_t *m = li->def.method;
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
if (li->specTypes) {
n += jl_printf(out, ".");
n += jl_show_svec(out, ((jl_datatype_t*)jl_unwrap_unionall(li->specTypes))->parameters,
jl_symbol_name(m->name), "(", ")");
n += jl_static_show_func_sig(out, li->specTypes);
}
else {
jl_method_t *m = li->def.method;
n += jl_static_show_x(out, (jl_value_t*)m->module, depth);
n += jl_printf(out, ".%s(?)", jl_symbol_name(m->name));
}
}
Expand Down Expand Up @@ -949,15 +947,15 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
if (ftype == NULL)
return jl_static_show(s, type);
size_t n = 0;
if (jl_nparams(ftype)==0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
if (jl_nparams(ftype) == 0 || ftype == ((jl_datatype_t*)ftype)->name->wrapper) {
n += jl_printf(s, "%s", jl_symbol_name(((jl_datatype_t*)ftype)->name->mt->name));
}
else {
n += jl_printf(s, "(::");
n += jl_static_show(s, ftype);
n += jl_printf(s, ")");
}
// TODO: better way to show method parameters
jl_unionall_t *tvars = (jl_unionall_t*)type;
type = jl_unwrap_unionall(type);
if (!jl_is_datatype(type)) {
n += jl_printf(s, " ");
Expand All @@ -984,6 +982,19 @@ JL_DLLEXPORT size_t jl_static_show_func_sig(JL_STREAM *s, jl_value_t *type)
}
}
n += jl_printf(s, ")");
if (jl_is_unionall(tvars)) {
int first = 1;
n += jl_printf(s, " where {");
while (jl_is_unionall(tvars)) {
if (first)
first = 0;
else
n += jl_printf(s, ", ");
n += jl_static_show(s, (jl_value_t*)tvars->var);
tvars = (jl_unionall_t*)tvars->body;
}
n += jl_printf(s, "}");
}
return n;
}

Expand Down
8 changes: 3 additions & 5 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3339,10 +3339,6 @@ end
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)

# issue #13183
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
@test gg13183(5) == 0

# issue 8932 (llvm return type legalizer error)
struct Vec3_8932
x::Float32
Expand Down Expand Up @@ -5317,7 +5313,8 @@ module UnionOptimizations
using Test

const boxedunions = [Union{}, Union{String, Void}]
const unboxedunions = [Union{Int8, Void}, Union{Int8, Float16, Void},
const unboxedunions = [Union{Int8, Void},
Union{Int8, Float16, Void},
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128},
Union{Char, Date, Int}]

Expand Down Expand Up @@ -5443,6 +5440,7 @@ t4 = vcat(A23567, t2, t3)
@test t4[11:15] == A23567

for U in unboxedunions
Base.unionlen(U) > 5 && continue # larger values cause subtyping to crash
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related? If so, how? Does the better inference perform more/other subtype checks that now trigger the crash?

local U
for N in (1, 2, 3, 4)
A = Array{U}(ntuple(x->0, N)...)
Expand Down
28 changes: 23 additions & 5 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@

# tests for Core.Inference correctness and precision
import Core.Inference: Const, Conditional, ⊑
const isleaftype = Core.Inference._isleaftype

# demonstrate some of the type-size limits
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
let comparison = Tuple{X, X} where X<:Tuple
sig = Tuple{X, X} where X<:comparison
ref = Tuple{X, X} where X
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
end


# issue 9770
@noinline x9770() = false
Expand Down Expand Up @@ -186,7 +200,6 @@ function find_tvar10930(arg)
end
@test find_tvar10930(Vararg{Int}) === 1

const isleaftype = Base._isleaftype

# issue #12474
@generated function f12474(::Any)
Expand Down Expand Up @@ -980,13 +993,13 @@ copy_dims_out(out) = ()
copy_dims_out(out, dim::Int, tail...) = copy_dims_out((out..., dim), tail...)
copy_dims_out(out, dim::Colon, tail...) = copy_dims_out((out..., dim), tail...)
@test Base.return_types(copy_dims_out, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 2 < count_specializations(m) < 15, methods(copy_dims_out))
@test all(m -> 10 < count_specializations(m) < 25, methods(copy_dims_out))

copy_dims_pair(out) = ()
copy_dims_pair(out, dim::Int, tail...) = copy_dims_out(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_out(out => dim, tail...)
copy_dims_pair(out, dim::Int, tail...) = copy_dims_pair(out => dim, tail...)
copy_dims_pair(out, dim::Colon, tail...) = copy_dims_pair(out => dim, tail...)
@test Base.return_types(copy_dims_pair, (Tuple{}, Vararg{Union{Int,Colon}})) == Any[Tuple{}, Tuple{}, Tuple{}]
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_out))
@test all(m -> 5 < count_specializations(m) < 25, methods(copy_dims_pair))

# splatting an ::Any should still allow inference to use types of parameters preceding it
f22364(::Int, ::Any...) = 0
Expand Down Expand Up @@ -1225,3 +1238,8 @@ end
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
@test Core.Inference.limit_type_depth(t, 4) >: t
end

# issue #13183
_false13183 = false
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
@test gg13183(5) == 0