From 5cf10211fe9097a2551b1d1bc0a2ce670fa4df7f Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 26 Jan 2024 10:37:39 +0800 Subject: [PATCH] Subtype: Fix some diagonal rule related false alarm (#53034) close #33137 close #53021 --------- Co-authored-by: Jameson Nash --- src/jltypes.c | 80 ++++++++++++++++++++++++++++++------------------- src/julia.h | 2 ++ src/subtype.c | 6 ++-- test/core.jl | 15 ++++++++-- test/subtype.jl | 19 ++++++------ 5 files changed, 77 insertions(+), 45 deletions(-) diff --git a/src/jltypes.c b/src/jltypes.c index 5bd5ec31185d9..21c3cebbb8de7 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -556,6 +556,43 @@ static void isort_union(jl_value_t **a, size_t len) JL_NOTSAFEPOINT } } +static int simple_subtype(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion) +{ + if (a == jl_bottom_type || b == (jl_value_t*)jl_any_type) + return 1; + if (jl_egal(a, b)) + return 1; + if (hasfree == 0) { + int mergeable = isUnion; + if (!mergeable) // issue #24521: don't merge Type{T} where typeof(T) varies + mergeable = !(jl_is_type_type(a) && jl_is_type_type(b) && + jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b))); + return mergeable && jl_subtype(a, b); + } + if (jl_is_typevar(a)) { + jl_value_t *na = ((jl_tvar_t*)a)->ub; + hasfree &= jl_has_free_typevars(na); + return simple_subtype(na, b, hasfree, isUnion); + } + if (jl_is_typevar(b)) { + jl_value_t *nb = ((jl_tvar_t*)b)->lb; + // This branch is not valid if `b` obeys diagonal rule, + // as it might normalize `Union` into a single `TypeVar`, e.g. + // Tuple{Union{Int,T},T} where {T>:Int} != Tuple{T,T} where {T>:Int} + if (is_leaf_bound(nb)) + return 0; + hasfree &= jl_has_free_typevars(nb) << 1; + return simple_subtype(a, nb, hasfree, isUnion); + } + if (b==(jl_value_t*)jl_datatype_type || b==(jl_value_t*)jl_typeofbottom_type) { + // This branch is not valid for `Union`/`UnionAll`, e.g. + // (Type{Union{Int,T2} where {T2<:T1}} where {T1}){Int} == Type{Int64} + // (Type{Union{Int,T1}} where {T1}){Int} == Type{Int64} + return jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b; + } + return 0; +} + JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n) { if (n == 0) @@ -580,13 +617,9 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n) int has_free = temp[i] != NULL && jl_has_free_typevars(temp[i]); for (j = 0; j < nt; j++) { if (j != i && temp[i] && temp[j]) { - if (temp[i] == jl_bottom_type || - temp[j] == (jl_value_t*)jl_any_type || - jl_egal(temp[i], temp[j]) || - (!has_free && !jl_has_free_typevars(temp[j]) && - jl_subtype(temp[i], temp[j]))) { + int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1); + if (simple_subtype(temp[i], temp[j], has_free2, 1)) temp[i] = NULL; - } } } } @@ -608,17 +641,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n) return tu; } -// note: this is turned off as `Union` doesn't do such normalization. -// static int simple_subtype(jl_value_t *a, jl_value_t *b) -// { -// if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b) -// return 1; -// if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->lb)) -// return 1; -// return 0; -// } - -static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree) +static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion) { int subab = 0, subba = 0; if (jl_egal(a, b)) { @@ -630,9 +653,9 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree) else if (b == jl_bottom_type || a == (jl_value_t*)jl_any_type) { subba = 1; } - else if (hasfree) { - // subab = simple_subtype(a, b); - // subba = simple_subtype(b, a); + else if (hasfree != 0) { + subab = simple_subtype(a, b, hasfree, isUnion); + subba = simple_subtype(b, a, hasfree, isUnion); } else if (jl_is_type_type(a) && jl_is_type_type(b) && jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b))) { @@ -664,10 +687,11 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b) // first remove cross-redundancy and check if `a >: b` or `a <: b`. for (i = 0; i < nta; i++) { if (temp[i] == NULL) continue; - int hasfree = jl_has_free_typevars(temp[i]); + int has_free = jl_has_free_typevars(temp[i]); for (j = nta; j < nt; j++) { if (temp[j] == NULL) continue; - int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j])); + int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1); + int subs = simple_subtype2(temp[i], temp[j], has_free2, 0); int subab = subs & 1, subba = subs >> 1; if (subab) { temp[i] = NULL; @@ -697,15 +721,9 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b) size_t jmax = i < nta ? nta : nt; for (j = jmin; j < jmax; j++) { if (j != i && temp[i] && temp[j]) { - if (temp[i] == jl_bottom_type || - temp[j] == (jl_value_t*)jl_any_type || - jl_egal(temp[i], temp[j]) || - (!has_free && !jl_has_free_typevars(temp[j]) && - // issue #24521: don't merge Type{T} where typeof(T) varies - !(jl_is_type_type(temp[i]) && jl_is_type_type(temp[j]) && jl_typeof(jl_tparam0(temp[i])) != jl_typeof(jl_tparam0(temp[j]))) && - jl_subtype(temp[i], temp[j]))) { + int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1); + if (simple_subtype(temp[i], temp[j], has_free2, 0)) temp[i] = NULL; - } } } } @@ -769,7 +787,7 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi) int hasfree = jl_has_free_typevars(temp[i]); for (j = nta; j < nt; j++) { if (temp[j] == NULL) continue; - int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j])); + int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]), 0); int subab = subs & 1, subba = subs >> 1; if (subba && !subab) { stemp[i] = -1; diff --git a/src/julia.h b/src/julia.h index 12c920cc6899d..3e794e85030fe 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1496,6 +1496,8 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT JL_DLLEXPORT int jl_subtype(jl_value_t *a, jl_value_t *b); +int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT; + STATIC_INLINE int jl_is_kind(jl_value_t *v) JL_NOTSAFEPOINT { return (v==(jl_value_t*)jl_uniontype_type || v==(jl_value_t*)jl_datatype_type || diff --git a/src/subtype.c b/src/subtype.c index 50b62c9b6e3da..600df9da8757e 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -805,7 +805,7 @@ static int subtype_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int R, int pa // check that a type is concrete or quasi-concrete (Type{T}). // this is used to check concrete typevars: // issubtype is false if the lower bound of a concrete type var is not concrete. -static int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT +int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT { if (v == jl_bottom_type) return 1; @@ -1997,7 +1997,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su if (var_occurs_invariant(body, (jl_tvar_t*)b)) return 0; } - if (nparams_expanded_x > npy && jl_is_typevar(b) && concrete_min(a1) > 1) { + if (nparams_expanded_x > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b) && concrete_min(a1) > 1) { // diagonal rule for 2 or more elements: they must all be concrete on the LHS *subtype = 0; return 1; @@ -2008,7 +2008,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su } for (; i < nparams_expanded_x; i++) { jl_value_t *a = (vx != JL_VARARG_NONE && i >= npx - 1) ? vxt : jl_tparam(x, i); - if (i > npy && jl_is_typevar(b)) { // i == npy implies a == a1 + if (i > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b)) { // i == npy implies a == a1 // diagonal rule: all the later parameters are also constrained to be type-equal to the first jl_value_t *a2 = a; jl_value_t *au = jl_unwrap_unionall(a); diff --git a/test/core.jl b/test/core.jl index 67e0102bc164a..050127a5621c5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -239,8 +239,8 @@ k11840(::Type{Union{Tuple{Int32}, Tuple{Int64}}}) = '2' # issue #20511 f20511(x::DataType) = 0 f20511(x) = 1 -Type{Integer} # cache this -@test f20511(Union{Integer,T} where T <: Unsigned) == 1 +Type{AbstractSet} # cache this +@test f20511(Union{AbstractSet,Set{T}} where T) == 1 # join @test typejoin(Int8,Int16) === Signed @@ -8101,3 +8101,14 @@ end # #52433 @test_throws ErrorException Core.Intrinsics.pointerref(Ptr{Vector{Int64}}(C_NULL), 1, 0) + +# #53034 (Union normalization for typevar elimination) +@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Int} +@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Integer} +# #53034 (Union normalization for Type elimination) +@test Int isa Type{Union{Int,T2} where {T2<:T1}} where {T1} +@test Int isa Type{Union{Int,T1}} where {T1} +@test Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}}} where {T1} +@test Int isa Union{Union, Type{Union{Int,T1}}} where {T1} +@test_broken Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}} where {T1}} +@test_broken Int isa Union{Union, Type{Union{Int,T1}} where {T1}} diff --git a/test/subtype.jl b/test/subtype.jl index edc38c8556f3c..c8197dbddbf6d 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -146,6 +146,14 @@ function test_diagonal() @test isequal_type(Ref{Tuple{T, T} where Int<:T<:Int}, Ref{Tuple{S, S}} where Int<:S<:Int) + # issue #53021 + @test Tuple{X, X} where {X<:Union{}} <: Tuple{X, X, Vararg{Any}} where {Int<:X<:Int} + @test Tuple{Integer, X, Vararg{X}} where {X<:Int} <: Tuple{Any, Vararg{X}} where {X>:Int} + @test Tuple{Any, X, Vararg{X}} where {X<:Int} <: Tuple{Vararg{X}} where X>:Integer + @test Tuple{Integer, Integer, Any, Vararg{Any}} <: Tuple{Vararg{X}} where X>:Integer + # issue #53019 + @test Tuple{T,T} where {T<:Int} <: Tuple{T,T} where {T>:Int} + let A = Tuple{Int,Int8,Vector{Integer}}, B = Tuple{T,T,Vector{T}} where T>:Integer, C = Tuple{T,T,Vector{Union{Integer,T}}} where T @@ -1260,14 +1268,7 @@ let a = Tuple{Tuple{T2,4},T6} where T2 where T6, end let a = Tuple{T3,Int64,Tuple{T3}} where T3, b = Tuple{S3,S3,S4} where S4 where S3 - I1 = typeintersect(a, b) - I2 = typeintersect(b, a) - @test I1 <: I2 - @test I2 <: I1 - @test_broken I1 <: a - @test I2 <: a - @test I1 <: b - @test I2 <: b + @testintersect(a, b, Tuple{Int64, Int64, Tuple{Int64}}) end let a = Tuple{T1,Val{T2},T2} where T2 where T1, b = Tuple{Float64,S1,S2} where S2 where S1 @@ -2445,7 +2446,7 @@ abstract type P47654{A} end @test_broken typeintersect(Type{Tuple{Array{T,1} where T}}, UnionAll) != Union{} #issue 33137 - @test_broken (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T + @test (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T # issue 24333 @test (Type{Union{Ref,Cvoid}} <: Type{Union{T,Cvoid}} where T)