Skip to content

Commit 2f98d45

Browse files
LiozouN5N3
authored andcommitted
Allow exact redefinition for types with recursive supertype reference (#55380)
This PR allows redefining a type when the new type is exactly identical to the previous one (like #17618, #20592 and #21024), even if the type has a reference to itself in its supertype. That particular case used to error (issue #54757), whereas with this PR: ```julia julia> struct Rec <: AbstractVector{Rec} end julia> struct Rec <: AbstractVector{Rec} end # this used to error julia> ``` Fix #54757 by implementing the solution proposed there. Hence, this should also fix downstream Revise bug timholy/Revise.jl#813. --------- Co-authored-by: N5N3 <2642243996@qq.com>
1 parent ca5506d commit 2f98d45

File tree

4 files changed

+136
-0
lines changed

4 files changed

+136
-0
lines changed

src/builtins.c

+3
Original file line numberDiff line numberDiff line change
@@ -2197,6 +2197,9 @@ static int equiv_type(jl_value_t *ta, jl_value_t *tb)
21972197
JL_GC_PUSH2(&a, &b);
21982198
a = jl_rewrap_unionall((jl_value_t*)dta->super, dta->name->wrapper);
21992199
b = jl_rewrap_unionall((jl_value_t*)dtb->super, dtb->name->wrapper);
2200+
// if tb recursively refers to itself in its supertype, assume that it refers to ta
2201+
// before checking whether the supertypes are equal
2202+
b = jl_substitute_datatype(b, dtb, dta);
22002203
if (!jl_types_equal(a, b))
22012204
goto no;
22022205
JL_TRY {

src/jltypes.c

+112
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,118 @@ jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u)
16071607
return t;
16081608
}
16091609

1610+
// Create a copy of type expression t where any occurrence of data type x is replaced by y.
1611+
// If x does not occur in t, return t without any copy.
1612+
// For example, jl_substitute_datatype(Foo{Bar}, Foo{T}, Qux{S}) is Qux{Bar}, with T and S
1613+
// free type variables.
1614+
// To substitute type variables, use jl_substitute_var instead.
1615+
jl_value_t *jl_substitute_datatype(jl_value_t *t, jl_datatype_t * x, jl_datatype_t * y)
1616+
{
1617+
if jl_is_datatype(t) {
1618+
jl_datatype_t *typ = (jl_datatype_t*)t;
1619+
// For datatypes call itself recursively on the parameters to form new parameters.
1620+
// Then, if typename(t) == typename(x), rewrap the wrapper of y around the new
1621+
// parameters. Otherwise, do the same around the wrapper of t.
1622+
// This ensures that the types and supertype are properly set.
1623+
// Start by check whether there is a parameter that needs replacing.
1624+
long i_firstnewparam = -1;
1625+
size_t nparams = jl_svec_len(typ->parameters);
1626+
jl_value_t *firstnewparam = NULL;
1627+
JL_GC_PUSH1(&firstnewparam);
1628+
for (size_t i = 0; i < nparams; i++) {
1629+
jl_value_t *param = NULL;
1630+
JL_GC_PUSH1(&param);
1631+
param = jl_svecref(typ->parameters, i);
1632+
firstnewparam = jl_substitute_datatype(param, x, y);
1633+
if (param != firstnewparam) {
1634+
i_firstnewparam = i;
1635+
JL_GC_POP();
1636+
break;
1637+
}
1638+
JL_GC_POP();
1639+
}
1640+
// If one of the parameters needs to be updated, or if the type name is that to
1641+
// substitute, create a new datataype
1642+
if (i_firstnewparam != -1 || typ->name == x->name) {
1643+
jl_datatype_t *uw = typ->name == x->name ? y : typ; // substitution occurs here
1644+
jl_value_t *wrapper = uw->name->wrapper;
1645+
jl_datatype_t *w = (jl_datatype_t*)jl_unwrap_unionall(wrapper);
1646+
jl_svec_t *sv = jl_alloc_svec_uninit(jl_svec_len(uw->parameters));
1647+
JL_GC_PUSH1(&sv);
1648+
jl_value_t **vals = jl_svec_data(sv);
1649+
// no JL_GC_PUSHARGS(vals, ...) since GC is already aware of sv
1650+
for (long i = 0; i < i_firstnewparam; i++) { // copy the identical parameters
1651+
vals[i] = jl_svecref(typ->parameters, i); // value
1652+
}
1653+
if (i_firstnewparam != -1) { // insert the first non-identical parameter
1654+
vals[i_firstnewparam] = firstnewparam;
1655+
}
1656+
for (size_t i = i_firstnewparam+1; i < nparams; i++) { // insert the remaining parameters
1657+
vals[i] = jl_substitute_datatype(jl_svecref(typ->parameters, i), x, y);
1658+
}
1659+
if (jl_is_tuple_type(wrapper)) {
1660+
// special case for tuples, since the wrapper (Tuple) does not have as
1661+
// many parameters as t (it only has a Vararg instead).
1662+
t = jl_apply_tuple_type(sv, 0);
1663+
} else {
1664+
t = jl_instantiate_type_in_env((jl_value_t*)w, (jl_unionall_t*)wrapper, vals);
1665+
}
1666+
JL_GC_POP();
1667+
}
1668+
JL_GC_POP();
1669+
}
1670+
else if jl_is_unionall(t) { // recursively call itself on body and var bounds
1671+
jl_unionall_t* ut = (jl_unionall_t*)t;
1672+
jl_value_t *lb = NULL;
1673+
jl_value_t *ub = NULL;
1674+
jl_value_t *body = NULL;
1675+
JL_GC_PUSH3(&lb, &ub, &body);
1676+
lb = jl_substitute_datatype(ut->var->lb, x, y);
1677+
ub = jl_substitute_datatype(ut->var->ub, x, y);
1678+
body = jl_substitute_datatype(ut->body, x, y);
1679+
if (lb != ut->var->lb || ub != ut->var->ub) {
1680+
jl_tvar_t *newtvar = jl_new_typevar(ut->var->name, lb, ub);
1681+
JL_GC_PUSH1(&newtvar);
1682+
body = jl_substitute_var(body, ut->var, (jl_value_t*)newtvar);
1683+
t = jl_new_struct(jl_unionall_type, newtvar, body);
1684+
JL_GC_POP();
1685+
}
1686+
else if (body != ut->body) {
1687+
t = jl_new_struct(jl_unionall_type, ut->var, body);
1688+
}
1689+
JL_GC_POP();
1690+
}
1691+
else if jl_is_uniontype(t) { // recursively call itself on a and b
1692+
jl_uniontype_t *u = (jl_uniontype_t*)t;
1693+
jl_value_t *a = NULL;
1694+
jl_value_t *b = NULL;
1695+
JL_GC_PUSH2(&a, &b);
1696+
a = jl_substitute_datatype(u->a, x, y);
1697+
b = jl_substitute_datatype(u->b, x, y);
1698+
if (a != u->a || b != u->b) {
1699+
t = jl_new_struct(jl_uniontype_type, a, b);
1700+
}
1701+
JL_GC_POP();
1702+
}
1703+
else if jl_is_vararg(t) { // recursively call itself on T
1704+
jl_vararg_t *vt = (jl_vararg_t*)t;
1705+
if (vt->T) { // vt->T could be NULL
1706+
jl_value_t *rT = NULL;
1707+
JL_GC_PUSH1(&rT);
1708+
rT = jl_substitute_datatype(vt->T, x, y);
1709+
if (rT != vt->T) {
1710+
jl_task_t *ct = jl_current_task;
1711+
t = jl_gc_alloc(ct->ptls, sizeof(jl_vararg_t), jl_vararg_type);
1712+
jl_set_typetagof((jl_vararg_t *)t, jl_vararg_tag, 0);
1713+
((jl_vararg_t *)t)->T = rT;
1714+
((jl_vararg_t *)t)->N = vt->N;
1715+
}
1716+
JL_GC_POP();
1717+
}
1718+
}
1719+
return t;
1720+
}
1721+
16101722
static jl_value_t *lookup_type_stack(jl_typestack_t *stack, jl_datatype_t *tt, size_t ntp,
16111723
jl_value_t **iparams)
16121724
{

src/julia_internal.h

+1
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ jl_unionall_t *jl_rename_unionall(jl_unionall_t *u);
769769
JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
770770
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u);
771771
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall_(jl_value_t *t, jl_value_t *u);
772+
jl_value_t* jl_substitute_datatype(jl_value_t *t, jl_datatype_t * x, jl_datatype_t * y);
772773
int jl_count_union_components(jl_value_t *v);
773774
JL_DLLEXPORT jl_value_t *jl_nth_union_component(jl_value_t *v JL_PROPAGATES_ROOT, int i) JL_NOTSAFEPOINT;
774775
int jl_find_union_component(jl_value_t *haystack, jl_value_t *needle, unsigned *nth) JL_NOTSAFEPOINT;

test/core.jl

+20
Original file line numberDiff line numberDiff line change
@@ -5611,6 +5611,26 @@ end
56115611
x::Array{T} where T<:Integer
56125612
end
56135613

5614+
# issue #54757, type redefinitions with recursive reference in supertype
5615+
struct T54757{A>:Int,N} <: AbstractArray{Tuple{X,Tuple{Vararg},Union{T54757{Union{X,Integer}},T54757{A,N}},Vararg{Y,N}} where {X,Y<:T54757}, N}
5616+
x::A
5617+
y::Union{A,T54757{A,N}}
5618+
z::T54757{A}
5619+
end
5620+
5621+
struct T54757{A>:Int,N} <: AbstractArray{Tuple{X,Tuple{Vararg},Union{T54757{Union{X,Integer}},T54757{A,N}},Vararg{Y,N}} where {X,Y<:T54757}, N}
5622+
x::A
5623+
y::Union{A,T54757{A,N}}
5624+
z::T54757{A}
5625+
end
5626+
5627+
@test_throws ErrorException struct T54757{A>:Int,N} <: AbstractArray{Tuple{X,Tuple{Vararg},Union{T54757{Union{X,Integer}},T54757{A}},Vararg{Y,N}} where {X,Y<:T54757}, N}
5628+
x::A
5629+
y::Union{A,T54757{A,N}}
5630+
z::T54757{A}
5631+
end
5632+
5633+
56145634
let a = Vector{Core.TypeofBottom}(undef, 2)
56155635
@test a[1] == Union{}
56165636
@test a == [Union{}, Union{}]

0 commit comments

Comments
 (0)