From e782f54bd10a1a7a17fa3ea8376466b142b1797e Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 9 Feb 2019 19:48:58 -0500 Subject: [PATCH] Get better type info from `partially generated` functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consider the following function: ``` julia> function foo(a, b) ntuple(i->(a+b; i), Val(4)) end foo (generic function with 1 method) ``` (In particular note that the return type of the closure does not depend on the types of `a` and b`). Unfortunately, prior to this change, inference was unable to determine the return type in this situation: ``` julia> code_typed(foo, Tuple{Any, Any}, trace=true) Refused to call generated function with non-concrete argument types ntuple(::getfield(Main, Symbol("##15#16")){_A,_B} where _B where _A, ::Val{4}) [GeneratedNotConcrete] 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##15#16)::Const(##15#16, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##15#16{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##15#16{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::Any └── return %6 ) => Any ``` Looking at the definition of ntuple https://github.com/JuliaLang/julia/blob/abb09f88804c4e74c752a66157e767c9b0f8945d/base/ntuple.jl#L45-L56 we see that it is a generated function an inference thus refuses to invoke it, unless it can prove the concrete type of *all* arguments to the function. As the above example illustrates, this restriction is more stringent than necessary. It is true that we cannot invoke generated functions on arbitrary abstract signatures (because we neither want to the user to have to be able to nor do we trust that users are able to preverse monotonicity - i.e. that the return type of the generated code will always be a subtype of the return type of a more abstract signature). However, if some piece of information is not used (the type of the passed function in this case), there is no problem with calling the generated function (since information that is unnused cannot possibly affect monotnicity). This PR allows us to recognize pieces of information that are *syntactically* unused, and call the generated functions, even if we do not have those pieces of information. As a result, we are now able to infer the return type of the above function: ``` julia> code_typed(foo, Tuple{Any, Any}) 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##3#4)::Const(##3#4, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##3#4{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##3#4{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::NTuple{4,Int64} └── return %6 ) => NTuple{4,Int64} ``` In particular, we use the new frontent `used` flags from the previous commit. One additional complication is that we want to accesss these flags without uncompressing the generator source, so we change the compression scheme to place the flags at a known location. Fixes #31004 --- base/compiler/utilities.jl | 5 +- base/reflection.jl | 62 +++++++++++++++++++++++-- src/dump.c | 24 ++++++++-- src/julia.h | 1 + stdlib/InteractiveUtils/src/codeview.jl | 2 +- test/compiler/inference.jl | 20 +++++++- 6 files changed, 101 insertions(+), 13 deletions(-) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 43d2bdd9315ebc..ff6e89bdbb299e 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV if world < min_world(method) || world > max_world(method) return nothing end - if isdefined(method, :generator) && !isdispatchtuple(atypes) - # don't call staged functions on abstract types. - # (see issues #8504, #10230) - # we can't guarantee that their type behavior is monotonic. + if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams) return nothing end if preexisting diff --git a/base/reflection.jl b/base/reflection.jl index c02600fcd89671..469a9b5f472562 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -941,9 +941,63 @@ struct CodegenParams emit_function, emitted_function) end +const SLOT_USED = 0x8 +ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1) + +""" + may_invoke_generator(method, atypes, sparams) + +Computes whether or not we may invoke the generator for the given `method` on +the given atypes and sparams. For correctness, all generated function are +required to return monotonic answers. However, since we don't expect users to +be able to successfully implement this criterion, we only call generated +functions on concrete types. The one exception to this is that we allow calling +generators with abstract types if the generator does not use said abstract type +(and thus cannot incorrectly use it to break monotonicity). This function +computes whether we are in either of these cases. +""" +function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector) + # If we have complete information, we may always call the generator + isdispatchtuple(atypes) && return true + + # We don't have complete information, but it is possible that the generator + # syntactically doesn't make use of the information we don't have. Check + # for that. + + # For now, only handle the (common, generated by the frontend case) that the + # generator only has one method + isa(method.generator, Core.GeneratedFunctionStub) || return false + generator_mt = typeof(method.generator.gen).name.mt + length(generator_mt) == 1 || return false + + generator_method = first(MethodList(generator_mt)) + nsparams = length(sparams) + isdefined(generator_method, :source) || return false + code = generator_method.source + nslots = ccall(:jl_ast_nslots, Int, (Any,), code) + at = unwrap_unionall(atypes) + (nslots >= 1 + length(sparams) + length(at.parameters)) || return false + + for i = 1:nsparams + if isa(sparams[i], TypeVar) + if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0 + return false + end + end + end + for i = 1:length(at.parameters) + if !isdispatchelem(at.parameters[i]) + if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0 + return false + end + end + end + return true +end + # give a decent error message if we try to instantiate a staged function on non-leaf types -function func_for_method_checked(m::Method, @nospecialize types) - if isdefined(m, :generator) && !isdispatchtuple(types) +function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector) + if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams) error("cannot call @generated function `", m, "` ", "with abstract argument types: ", types) end @@ -978,7 +1032,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple); types = to_tuple_type(types) asts = [] for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) (code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params) code === nothing && error("inference not successful") # inference disabled? debuginfo == :none && remove_linenums!(code) @@ -997,7 +1051,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple)) world = ccall(:jl_get_world_counter, UInt, ()) params = Core.Compiler.Params(world) for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params) ty === nothing && error("inference not successful") # inference disabled? push!(rt, ty) diff --git a/src/dump.c b/src/dump.c index 472e8752d040e9..04bc292f816d43 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a) } // "magic" string and version header of .ji file -static const int JI_FORMAT_VERSION = 7; +static const int JI_FORMAT_VERSION = 8; static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature static const uint16_t BOM = 0xFEFF; // byte-order marker static void write_header(ios_t *s) @@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) size_t nsyms = jl_array_len(code->slotnames); assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions write_int32(s.s, nsyms); + assert(nsyms == jl_array_len(code->slotflags)); + ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms); + + // N.B.: The layout of everything before this point is explicitly referenced + // by the various jl_ast_ accessors. Make sure to adjust those if you change + // the data layout. + for (i = 0; i < nsyms; i++) { jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i); assert(jl_is_symbol(name)); @@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) // skip codelocs continue; int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field @@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) code->pure = !!(flags & (1 << 0)); size_t nslots = read_int32(&src); + code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); + ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots); + jl_array_t *syms = jl_alloc_vec_any(nslots); code->slotnames = syms; for (i = 0; i < nslots; i++) { @@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) continue; assert(jl_field_isptr(jl_code_info_type, i)); @@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data) } } +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i) +{ + assert(i < jl_ast_nslots(data)); + if (jl_is_code_info(data)) + return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i]; + return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i]; +} + JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names) { size_t i, nargs = jl_array_len(names); diff --git a/src/julia.h b/src/julia.h index a9040eaee70e6f..47a2705acd8e00 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data); +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i); JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names); JL_DLLEXPORT int jl_is_operator(char *sym); diff --git a/stdlib/InteractiveUtils/src/codeview.jl b/stdlib/InteractiveUtils/src/codeview.jl index 547ca94cd48af7..3089e61e0f4f4a 100644 --- a/stdlib/InteractiveUtils/src/codeview.jl +++ b/stdlib/InteractiveUtils/src/codeview.jl @@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe t = to_tuple_type(t) tt = signature_type(f, t) (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector - meth = Base.func_for_method_checked(meth, ti) + meth = Base.func_for_method_checked(meth, ti, env) linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world) # get the code for it return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9f6ef35c8fd6df..5420c9980362a7 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t)) tt = Tuple{ft, t.parameters...} precompile(tt) (ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig) - meth = Base.func_for_method_checked(meth, tt) + meth = Base.func_for_method_checked(meth, tt, env) return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, tt, env, world) end @@ -2216,3 +2216,21 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any}) f_with_Type_arg(::Type{T}) where {T} = T @test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type] + +# Generated functions that only reference some of their arguments +@inline function my_ntuple(f::F, ::Val{N}) where {F,N} + N::Int + (N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N))) + if @generated + quote + @Base.nexprs $N i -> t_i = f(i) + @Base.ncall $N tuple t + end + else + Tuple(f(i) for i = 1:N) + end +end +call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4)) +@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}] +@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1 +@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val})