From 3bc4ea7bc083c7797a1e26a7c578b08cd1b8dc4d Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 9 Feb 2019 19:48:58 -0500 Subject: [PATCH] Get better type info from `partially generated` functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consider the following function: ``` julia> function foo(a, b) ntuple(i->(a+b; i), Val(4)) end foo (generic function with 1 method) ``` (In particular note that the return type of the closure does not depend on the types of `a` and b`). Unfortunately, prior to this change, inference was unable to determine the return type in this situation: ``` julia> code_typed(foo, Tuple{Any, Any}, trace=true) Refused to call generated function with non-concrete argument types ntuple(::getfield(Main, Symbol("##15#16")){_A,_B} where _B where _A, ::Val{4}) [GeneratedNotConcrete] 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##15#16)::Const(##15#16, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##15#16{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##15#16{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::Any └── return %6 ) => Any ``` Looking at the definition of ntuple https://github.com/JuliaLang/julia/blob/abb09f88804c4e74c752a66157e767c9b0f8945d/base/ntuple.jl#L45-L56 we see that it is a generated function an inference thus refuses to invoke it, unless it can prove the concrete type of *all* arguments to the function. As the above example illustrates, this restriction is more stringent than necessary. It is true that we cannot invoke generated functions on arbitrary abstract signatures (because we neither want to the user to have to be able to nor do we trust that users are able to preverse monotonicity - i.e. that the return type of the generated code will always be a subtype of the return type of a more abstract signature). However, if some piece of information is not used (the type of the passed function in this case), there is no problem with calling the generated function (since information that is unnused cannot possibly affect monotnicity). This PR allows us to recognize pieces of information that are *syntactically* unused, and call the generated functions, even if we do not have those pieces of information. As a result, we are now able to infer the return type of the above function: ``` julia> code_typed(foo, Tuple{Any, Any}) 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##3#4)::Const(##3#4, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##3#4{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##3#4{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::NTuple{4,Int64} └── return %6 ) => NTuple{4,Int64} ``` In particular, we use the new frontent `used` flags from the previous commit. One additional complication is that we want to accesss these flags without uncompressing the generator source, so we change the compression scheme to place the flags at a known location. Fixes #31004 --- base/compiler/utilities.jl | 5 +- base/reflection.jl | 71 +++++++++++++++++++++++-- src/dump.c | 26 +++++++-- src/julia.h | 1 + src/method.c | 5 +- stdlib/InteractiveUtils/src/codeview.jl | 2 +- test/compiler/inference.jl | 23 +++++++- 7 files changed, 118 insertions(+), 15 deletions(-) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 43d2bdd9315eb..ff6e89bdbb299 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV if world < min_world(method) || world > max_world(method) return nothing end - if isdefined(method, :generator) && !isdispatchtuple(atypes) - # don't call staged functions on abstract types. - # (see issues #8504, #10230) - # we can't guarantee that their type behavior is monotonic. + if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams) return nothing end if preexisting diff --git a/base/reflection.jl b/base/reflection.jl index c02600fcd8967..6dcf5c2c58189 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -941,8 +941,72 @@ struct CodegenParams emit_function, emitted_function) end +const SLOT_USED = 0x8 +ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1) + +""" + may_invoke_generator(method, atypes, sparams) + +Computes whether or not we may invoke the generator for the given `method` on +the given atypes and sparams. For correctness, all generated function are +required to return monotonic answers. However, since we don't expect users to +be able to successfully implement this criterion, we only call generated +functions on concrete types. The one exception to this is that we allow calling +generators with abstract types if the generator does not use said abstract type +(and thus cannot incorrectly use it to break monotonicity). This function +computes whether we are in either of these cases. +""" +function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector) + # If we have complete information, we may always call the generator + isdispatchtuple(atypes) && return true + + # We don't have complete information, but it is possible that the generator + # syntactically doesn't make use of the information we don't have. Check + # for that. + + # For now, only handle the (common, generated by the frontend case) that the + # generator only has one method + isa(method.generator, Core.GeneratedFunctionStub) || return false + generator_mt = typeof(method.generator.gen).name.mt + length(generator_mt) == 1 || return false + + generator_method = first(MethodList(generator_mt)) + nsparams = length(sparams) + isdefined(generator_method, :source) || return false + code = generator_method.source + nslots = ccall(:jl_ast_nslots, Int, (Any,), code) + at = unwrap_unionall(atypes) + (nslots >= 1 + length(sparams) + length(at.parameters)) || return false + + for i = 1:nsparams + if isa(sparams[i], TypeVar) + if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0 + return false + end + end + end + for i = 1:length(at.parameters) + if !isdispatchelem(at.parameters[i]) + if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0 + return false + end + end + end + return true +end + # give a decent error message if we try to instantiate a staged function on non-leaf types -function func_for_method_checked(m::Method, @nospecialize types) +function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector) + if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams) + error("cannot call @generated function `", m, "` ", + "with abstract argument types: ", types) + end + return m +end + +function func_for_method_checked(m::Method, @nospecialize(types)) + depwarn("The two argument form of `func_for_method_checked` is deprecated. Pass sparams in addition.", + :func_for_method_checked) if isdefined(m, :generator) && !isdispatchtuple(types) error("cannot call @generated function `", m, "` ", "with abstract argument types: ", types) @@ -950,6 +1014,7 @@ function func_for_method_checked(m::Method, @nospecialize types) return m end + """ code_typed(f, types; optimize=true, debuginfo=:default) @@ -978,7 +1043,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple); types = to_tuple_type(types) asts = [] for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) (code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params) code === nothing && error("inference not successful") # inference disabled? debuginfo == :none && remove_linenums!(code) @@ -997,7 +1062,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple)) world = ccall(:jl_get_world_counter, UInt, ()) params = Core.Compiler.Params(world) for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params) ty === nothing && error("inference not successful") # inference disabled? push!(rt, ty) diff --git a/src/dump.c b/src/dump.c index 472e8752d040e..1d563bdd96d1f 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a) } // "magic" string and version header of .ji file -static const int JI_FORMAT_VERSION = 7; +static const int JI_FORMAT_VERSION = 8; static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature static const uint16_t BOM = 0xFEFF; // byte-order marker static void write_header(ios_t *s) @@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) size_t nsyms = jl_array_len(code->slotnames); assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions write_int32(s.s, nsyms); + assert(nsyms == jl_array_len(code->slotflags)); + ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms); + + // N.B.: The layout of everything before this point is explicitly referenced + // by the various jl_ast_ accessors. Make sure to adjust those if you change + // the data layout. + for (i = 0; i < nsyms; i++) { jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i); assert(jl_is_symbol(name)); @@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) // skip codelocs continue; int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field @@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) code->pure = !!(flags & (1 << 0)); size_t nslots = read_int32(&src); + code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); + ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots); + jl_array_t *syms = jl_alloc_vec_any(nslots); code->slotnames = syms; for (i = 0; i < nslots; i++) { @@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) continue; assert(jl_field_isptr(jl_code_info_type, i)); @@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data) } } +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i) +{ + assert(i < jl_ast_nslots(data)); + if (jl_is_code_info(data)) + return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i]; + return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i]; +} + JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names) { size_t i, nargs = jl_array_len(names); @@ -2637,7 +2655,7 @@ JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names) int nslots = jl_load_unaligned_i32(d + 1); assert(nslots >= nargs); (void)nslots; - char *namestr = d + 5; + char *namestr = d + 5 + nslots; for (i = 0; i < nargs; i++) { size_t namelen = strlen(namestr); jl_sym_t *name = jl_symbol_n(namestr, namelen); diff --git a/src/julia.h b/src/julia.h index a9040eaee70e6..47a2705acd8e0 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data); +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i); JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names); JL_DLLEXPORT int jl_is_operator(char *sym); diff --git a/src/method.c b/src/method.c index 7962502d57f1c..f5d07c2c20231 100644 --- a/src/method.c +++ b/src/method.c @@ -383,7 +383,7 @@ STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) { JL_TIMING(STAGED_FUNCTION); - jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes; + jl_value_t *tt = linfo->specTypes; jl_method_t *def = linfo->def.method; jl_value_t *generator = def->generator; assert(generator != NULL); @@ -402,7 +402,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) ptls->world_age = def->min_world; // invoke code generator - ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt)); + jl_tupletype_t *ttdt = (jl_tupletype_t*)jl_unwrap_unionall(tt); + ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt)); if (jl_is_code_info(ex)) { func = (jl_code_info_t*)ex; diff --git a/stdlib/InteractiveUtils/src/codeview.jl b/stdlib/InteractiveUtils/src/codeview.jl index 547ca94cd48af..3089e61e0f4f4 100644 --- a/stdlib/InteractiveUtils/src/codeview.jl +++ b/stdlib/InteractiveUtils/src/codeview.jl @@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe t = to_tuple_type(t) tt = signature_type(f, t) (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector - meth = Base.func_for_method_checked(meth, ti) + meth = Base.func_for_method_checked(meth, ti, env) linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world) # get the code for it return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index eb101b811aa0e..8eb46ace299f7 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t)) tt = Tuple{ft, t.parameters...} precompile(tt) (ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig) - meth = Base.func_for_method_checked(meth, tt) + meth = Base.func_for_method_checked(meth, tt, env) return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, tt, env, world) end @@ -2224,3 +2224,24 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any}) f_with_Type_arg(::Type{T}) where {T} = T @test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type] @test Base.return_types(f_with_Type_arg, (Type{Vector{T}} where T,)) == Any[Type{Vector{T}} where T] + +# Generated functions that only reference some of their arguments +@inline function my_ntuple(f::F, ::Val{N}) where {F,N} + N::Int + (N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N))) + if @generated + quote + @Base.nexprs $N i -> t_i = f(i) + @Base.ncall $N tuple t + end + else + Tuple(f(i) for i = 1:N) + end +end +call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4)) +@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}] +@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1 +@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val}) + +@generated unionall_sig_generated(::Vector{T}, b::Vector{S}) where {T, S} = :($b) +@test length(code_typed(unionall_sig_generated, Tuple{Any, Vector{Int}})) == 1