diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 43d2bdd9315eb..ff6e89bdbb299 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV if world < min_world(method) || world > max_world(method) return nothing end - if isdefined(method, :generator) && !isdispatchtuple(atypes) - # don't call staged functions on abstract types. - # (see issues #8504, #10230) - # we can't guarantee that their type behavior is monotonic. + if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams) return nothing end if preexisting diff --git a/base/reflection.jl b/base/reflection.jl index c02600fcd8967..6dcf5c2c58189 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -941,8 +941,72 @@ struct CodegenParams emit_function, emitted_function) end +const SLOT_USED = 0x8 +ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1) + +""" + may_invoke_generator(method, atypes, sparams) + +Computes whether or not we may invoke the generator for the given `method` on +the given atypes and sparams. For correctness, all generated function are +required to return monotonic answers. However, since we don't expect users to +be able to successfully implement this criterion, we only call generated +functions on concrete types. The one exception to this is that we allow calling +generators with abstract types if the generator does not use said abstract type +(and thus cannot incorrectly use it to break monotonicity). This function +computes whether we are in either of these cases. +""" +function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector) + # If we have complete information, we may always call the generator + isdispatchtuple(atypes) && return true + + # We don't have complete information, but it is possible that the generator + # syntactically doesn't make use of the information we don't have. Check + # for that. + + # For now, only handle the (common, generated by the frontend case) that the + # generator only has one method + isa(method.generator, Core.GeneratedFunctionStub) || return false + generator_mt = typeof(method.generator.gen).name.mt + length(generator_mt) == 1 || return false + + generator_method = first(MethodList(generator_mt)) + nsparams = length(sparams) + isdefined(generator_method, :source) || return false + code = generator_method.source + nslots = ccall(:jl_ast_nslots, Int, (Any,), code) + at = unwrap_unionall(atypes) + (nslots >= 1 + length(sparams) + length(at.parameters)) || return false + + for i = 1:nsparams + if isa(sparams[i], TypeVar) + if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0 + return false + end + end + end + for i = 1:length(at.parameters) + if !isdispatchelem(at.parameters[i]) + if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0 + return false + end + end + end + return true +end + # give a decent error message if we try to instantiate a staged function on non-leaf types -function func_for_method_checked(m::Method, @nospecialize types) +function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector) + if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams) + error("cannot call @generated function `", m, "` ", + "with abstract argument types: ", types) + end + return m +end + +function func_for_method_checked(m::Method, @nospecialize(types)) + depwarn("The two argument form of `func_for_method_checked` is deprecated. Pass sparams in addition.", + :func_for_method_checked) if isdefined(m, :generator) && !isdispatchtuple(types) error("cannot call @generated function `", m, "` ", "with abstract argument types: ", types) @@ -950,6 +1014,7 @@ function func_for_method_checked(m::Method, @nospecialize types) return m end + """ code_typed(f, types; optimize=true, debuginfo=:default) @@ -978,7 +1043,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple); types = to_tuple_type(types) asts = [] for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) (code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params) code === nothing && error("inference not successful") # inference disabled? debuginfo == :none && remove_linenums!(code) @@ -997,7 +1062,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple)) world = ccall(:jl_get_world_counter, UInt, ()) params = Core.Compiler.Params(world) for x in _methods(f, types, -1, world) - meth = func_for_method_checked(x[3], types) + meth = func_for_method_checked(x[3], types, x[2]) ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params) ty === nothing && error("inference not successful") # inference disabled? push!(rt, ty) diff --git a/src/dump.c b/src/dump.c index 472e8752d040e..1d563bdd96d1f 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a) } // "magic" string and version header of .ji file -static const int JI_FORMAT_VERSION = 7; +static const int JI_FORMAT_VERSION = 8; static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature static const uint16_t BOM = 0xFEFF; // byte-order marker static void write_header(ios_t *s) @@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) size_t nsyms = jl_array_len(code->slotnames); assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions write_int32(s.s, nsyms); + assert(nsyms == jl_array_len(code->slotflags)); + ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms); + + // N.B.: The layout of everything before this point is explicitly referenced + // by the various jl_ast_ accessors. Make sure to adjust those if you change + // the data layout. + for (i = 0; i < nsyms; i++) { jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i); assert(jl_is_symbol(name)); @@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) // skip codelocs continue; int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field @@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) code->pure = !!(flags & (1 << 0)); size_t nslots = read_int32(&src); + code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); + ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots); + jl_array_t *syms = jl_alloc_vec_any(nslots); code->slotnames = syms; for (i = 0; i < nslots; i++) { @@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) } size_t nf = jl_datatype_nfields(jl_code_info_type); - for (i = 0; i < nf - 5; i++) { + for (i = 0; i < nf - 6; i++) { if (i == 1) continue; assert(jl_field_isptr(jl_code_info_type, i)); @@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data) } } +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i) +{ + assert(i < jl_ast_nslots(data)); + if (jl_is_code_info(data)) + return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i]; + return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i]; +} + JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names) { size_t i, nargs = jl_array_len(names); @@ -2637,7 +2655,7 @@ JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names) int nslots = jl_load_unaligned_i32(d + 1); assert(nslots >= nargs); (void)nslots; - char *namestr = d + 5; + char *namestr = d + 5 + nslots; for (i = 0; i < nargs; i++) { size_t namelen = strlen(namestr); jl_sym_t *name = jl_symbol_n(namestr, namelen); diff --git a/src/julia.h b/src/julia.h index a9040eaee70e6..47a2705acd8e0 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data) JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data); JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data); +JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i); JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names); JL_DLLEXPORT int jl_is_operator(char *sym); diff --git a/src/method.c b/src/method.c index 7962502d57f1c..f5d07c2c20231 100644 --- a/src/method.c +++ b/src/method.c @@ -383,7 +383,7 @@ STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) { JL_TIMING(STAGED_FUNCTION); - jl_tupletype_t *tt = (jl_tupletype_t*)linfo->specTypes; + jl_value_t *tt = linfo->specTypes; jl_method_t *def = linfo->def.method; jl_value_t *generator = def->generator; assert(generator != NULL); @@ -402,7 +402,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) ptls->world_age = def->min_world; // invoke code generator - ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt)); + jl_tupletype_t *ttdt = (jl_tupletype_t*)jl_unwrap_unionall(tt); + ex = jl_call_staged(linfo->def.method, generator, linfo->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt)); if (jl_is_code_info(ex)) { func = (jl_code_info_t*)ex; diff --git a/stdlib/InteractiveUtils/src/codeview.jl b/stdlib/InteractiveUtils/src/codeview.jl index 547ca94cd48af..3089e61e0f4f4 100644 --- a/stdlib/InteractiveUtils/src/codeview.jl +++ b/stdlib/InteractiveUtils/src/codeview.jl @@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe t = to_tuple_type(t) tt = signature_type(f, t) (ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector - meth = Base.func_for_method_checked(meth, ti) + meth = Base.func_for_method_checked(meth, ti, env) linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world) # get the code for it return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index eb101b811aa0e..8eb46ace299f7 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t)) tt = Tuple{ft, t.parameters...} precompile(tt) (ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig) - meth = Base.func_for_method_checked(meth, tt) + meth = Base.func_for_method_checked(meth, tt, env) return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, tt, env, world) end @@ -2224,3 +2224,24 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any}) f_with_Type_arg(::Type{T}) where {T} = T @test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type] @test Base.return_types(f_with_Type_arg, (Type{Vector{T}} where T,)) == Any[Type{Vector{T}} where T] + +# Generated functions that only reference some of their arguments +@inline function my_ntuple(f::F, ::Val{N}) where {F,N} + N::Int + (N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N))) + if @generated + quote + @Base.nexprs $N i -> t_i = f(i) + @Base.ncall $N tuple t + end + else + Tuple(f(i) for i = 1:N) + end +end +call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4)) +@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}] +@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1 +@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val}) + +@generated unionall_sig_generated(::Vector{T}, b::Vector{S}) where {T, S} = :($b) +@test length(code_typed(unionall_sig_generated, Tuple{Any, Vector{Int}})) == 1