Skip to content

Commit 1df7e97

Browse files
committed
Get better type info from partially generated functions
Consider the following function: ``` julia> function foo(a, b) ntuple(i->(a+b; i), Val(4)) end foo (generic function with 1 method) ``` (In particular note that the return type of the closure does not depend on the types of `a` and b`). Unfortunately, prior to this change, inference was unable to determine the return type in this situation: ``` julia> code_typed(foo, Tuple{Any, Any}, trace=true) Refused to call generated function with non-concrete argument types ntuple(::getfield(Main, Symbol("##15#16")){_A,_B} where _B where _A, ::Val{4}) [GeneratedNotConcrete] 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##15#16)::Const(##15#16, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##15#16{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##15#16{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::Any └── return %6 ) => Any ``` Looking at the definition of ntuple https://github.com/JuliaLang/julia/blob/abb09f88804c4e74c752a66157e767c9b0f8945d/base/ntuple.jl#L45-L56 we see that it is a generated function an inference thus refuses to invoke it, unless it can prove the concrete type of *all* arguments to the function. As the above example illustrates, this restriction is more stringent than necessary. It is true that we cannot invoke generated functions on arbitrary abstract signatures (because we neither want to the user to have to be able to nor do we trust that users are able to preverse monotonicity - i.e. that the return type of the generated code will always be a subtype of the return type of a more abstract signature). However, if some piece of information is not used (the type of the passed function in this case), there is no problem with calling the generated function (since information that is unnused cannot possibly affect monotnicity). This PR allows us to recognize pieces of information that are *syntactically* unused, and call the generated functions, even if we do not have those pieces of information. As a result, we are now able to infer the return type of the above function: ``` julia> code_typed(foo, Tuple{Any, Any}) 1-element Array{Any,1}: CodeInfo( 1 ─ %1 = Main.:(##3#4)::Const(##3#4, false) │ %2 = Core.typeof(a)::DataType │ %3 = Core.typeof(b)::DataType │ %4 = Core.apply_type(%1, %2, %3)::Type{##3#4{_A,_B}} where _B where _A │ %5 = %new(%4, a, b)::##3#4{_A,_B} where _B where _A │ %6 = Main.ntuple(%5, $(QuoteNode(Val{4}())))::NTuple{4,Int64} └── return %6 ) => NTuple{4,Int64} ``` In particular, we use the new frontent `used` flags from the previous commit. One additional complication is that we want to accesss these flags without uncompressing the generator source, so we change the compression scheme to place the flags at a known location. Fixes #31004
1 parent 635b8c5 commit 1df7e97

File tree

6 files changed

+101
-13
lines changed

6 files changed

+101
-13
lines changed

base/compiler/utilities.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,7 @@ function code_for_method(method::Method, @nospecialize(atypes), sparams::SimpleV
115115
if world < min_world(method) || world > max_world(method)
116116
return nothing
117117
end
118-
if isdefined(method, :generator) && !isdispatchtuple(atypes)
119-
# don't call staged functions on abstract types.
120-
# (see issues #8504, #10230)
121-
# we can't guarantee that their type behavior is monotonic.
118+
if isdefined(method, :generator) && !may_invoke_generator(method, atypes, sparams)
122119
return nothing
123120
end
124121
if preexisting

base/reflection.jl

+58-4
Original file line numberDiff line numberDiff line change
@@ -941,9 +941,63 @@ struct CodegenParams
941941
emit_function, emitted_function)
942942
end
943943

944+
const SLOT_USED = 0x8
945+
ast_slotflag(@nospecialize(code), i) = ccall(:jl_ast_slotflag, UInt8, (Any, Csize_t), code, i - 1)
946+
947+
"""
948+
may_invoke_generator(method, atypes, sparams)
949+
950+
Computes whether or not we may invoke the generator for the given `method` on
951+
the given atypes and sparams. For correctness, all generated function are
952+
required to return monotonic answers. However, since we don't expect users to
953+
be able to successfully implement this criterion, we only call generated
954+
functions on concrete types. The one exception to this is that we allow calling
955+
generators with abstract types if the generator does not use said abstract type
956+
(and thus cannot incorrectly use it to break monotonicity). This function
957+
computes whether we are in either of these cases.
958+
"""
959+
function may_invoke_generator(method::Method, @nospecialize(atypes), sparams::SimpleVector)
960+
# If we have complete information, we may always call the generator
961+
isdispatchtuple(atypes) && return true
962+
963+
# We don't have complete information, but it is possible that the generator
964+
# syntactically doesn't make use of the information we don't have. Check
965+
# for that.
966+
967+
# For now, only handle the (common, generated by the frontend case) that the
968+
# generator only has one method
969+
isa(method.generator, Core.GeneratedFunctionStub) || return false
970+
generator_mt = typeof(method.generator.gen).name.mt
971+
length(generator_mt) == 1 || return false
972+
973+
generator_method = first(MethodList(generator_mt))
974+
nsparams = length(sparams)
975+
isdefined(generator_method, :source) || return false
976+
code = generator_method.source
977+
nslots = ccall(:jl_ast_nslots, Int, (Any,), code)
978+
at = unwrap_unionall(atypes)
979+
(nslots >= 1 + length(sparams) + length(at.parameters)) || return false
980+
981+
for i = 1:nsparams
982+
if isa(sparams[i], TypeVar)
983+
if (ast_slotflag(code, 1 + i) & SLOT_USED) != 0
984+
return false
985+
end
986+
end
987+
end
988+
for i = 1:length(at.parameters)
989+
if !isdispatchelem(at.parameters[i])
990+
if (ast_slotflag(code, 1 + i + nsparams) & SLOT_USED) != 0
991+
return false
992+
end
993+
end
994+
end
995+
return true
996+
end
997+
944998
# give a decent error message if we try to instantiate a staged function on non-leaf types
945-
function func_for_method_checked(m::Method, @nospecialize types)
946-
if isdefined(m, :generator) && !isdispatchtuple(types)
999+
function func_for_method_checked(m::Method, @nospecialize(types), sparams::SimpleVector)
1000+
if isdefined(m, :generator) && !Core.Compiler.may_invoke_generator(m, types, sparams)
9471001
error("cannot call @generated function `", m, "` ",
9481002
"with abstract argument types: ", types)
9491003
end
@@ -978,7 +1032,7 @@ function code_typed(@nospecialize(f), @nospecialize(types=Tuple);
9781032
types = to_tuple_type(types)
9791033
asts = []
9801034
for x in _methods(f, types, -1, world)
981-
meth = func_for_method_checked(x[3], types)
1035+
meth = func_for_method_checked(x[3], types, x[2])
9821036
(code, ty) = Core.Compiler.typeinf_code(meth, x[1], x[2], optimize, params)
9831037
code === nothing && error("inference not successful") # inference disabled?
9841038
debuginfo == :none && remove_linenums!(code)
@@ -997,7 +1051,7 @@ function return_types(@nospecialize(f), @nospecialize(types=Tuple))
9971051
world = ccall(:jl_get_world_counter, UInt, ())
9981052
params = Core.Compiler.Params(world)
9991053
for x in _methods(f, types, -1, world)
1000-
meth = func_for_method_checked(x[3], types)
1054+
meth = func_for_method_checked(x[3], types, x[2])
10011055
ty = Core.Compiler.typeinf_type(meth, x[1], x[2], params)
10021056
ty === nothing && error("inference not successful") # inference disabled?
10031057
push!(rt, ty)

src/dump.c

+21-3
Original file line numberDiff line numberDiff line change
@@ -1212,7 +1212,7 @@ static void write_mod_list(ios_t *s, jl_array_t *a)
12121212
}
12131213

12141214
// "magic" string and version header of .ji file
1215-
static const int JI_FORMAT_VERSION = 7;
1215+
static const int JI_FORMAT_VERSION = 8;
12161216
static const char JI_MAGIC[] = "\373jli\r\n\032\n"; // based on PNG signature
12171217
static const uint16_t BOM = 0xFEFF; // byte-order marker
12181218
static void write_header(ios_t *s)
@@ -2459,6 +2459,13 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
24592459
size_t nsyms = jl_array_len(code->slotnames);
24602460
assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions
24612461
write_int32(s.s, nsyms);
2462+
assert(nsyms == jl_array_len(code->slotflags));
2463+
ios_write(s.s, (char*)jl_array_data(code->slotflags), nsyms);
2464+
2465+
// N.B.: The layout of everything before this point is explicitly referenced
2466+
// by the various jl_ast_ accessors. Make sure to adjust those if you change
2467+
// the data layout.
2468+
24622469
for (i = 0; i < nsyms; i++) {
24632470
jl_sym_t *name = (jl_sym_t*)jl_array_ptr_ref(code->slotnames, i);
24642471
assert(jl_is_symbol(name));
@@ -2468,7 +2475,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
24682475
}
24692476

24702477
size_t nf = jl_datatype_nfields(jl_code_info_type);
2471-
for (i = 0; i < nf - 5; i++) {
2478+
for (i = 0; i < nf - 6; i++) {
24722479
if (i == 1) // skip codelocs
24732480
continue;
24742481
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
@@ -2536,6 +2543,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
25362543
code->pure = !!(flags & (1 << 0));
25372544

25382545
size_t nslots = read_int32(&src);
2546+
code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots);
2547+
ios_read(s.s, (char*)jl_array_data(code->slotflags), nslots);
2548+
25392549
jl_array_t *syms = jl_alloc_vec_any(nslots);
25402550
code->slotnames = syms;
25412551
for (i = 0; i < nslots; i++) {
@@ -2547,7 +2557,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
25472557
}
25482558

25492559
size_t nf = jl_datatype_nfields(jl_code_info_type);
2550-
for (i = 0; i < nf - 5; i++) {
2560+
for (i = 0; i < nf - 6; i++) {
25512561
if (i == 1)
25522562
continue;
25532563
assert(jl_field_isptr(jl_code_info_type, i));
@@ -2620,6 +2630,14 @@ JL_DLLEXPORT ssize_t jl_ast_nslots(jl_array_t *data)
26202630
}
26212631
}
26222632

2633+
JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i)
2634+
{
2635+
assert(i < jl_ast_nslots(data));
2636+
if (jl_is_code_info(data))
2637+
return ((uint8_t*)((jl_code_info_t*)data)->slotflags->data)[i];
2638+
return ((uint8_t*)data->data)[1 + sizeof(int32_t) + i];
2639+
}
2640+
26232641
JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names)
26242642
{
26252643
size_t i, nargs = jl_array_len(names);

src/julia.h

+1
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
15481548
JL_DLLEXPORT uint8_t jl_ast_flag_inferred(jl_array_t *data);
15491549
JL_DLLEXPORT uint8_t jl_ast_flag_inlineable(jl_array_t *data);
15501550
JL_DLLEXPORT uint8_t jl_ast_flag_pure(jl_array_t *data);
1551+
JL_DLLEXPORT uint8_t jl_ast_slotflag(jl_array_t *data, size_t i);
15511552
JL_DLLEXPORT void jl_fill_argnames(jl_array_t *data, jl_array_t *names);
15521553

15531554
JL_DLLEXPORT int jl_is_operator(char *sym);

stdlib/InteractiveUtils/src/codeview.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function _dump_function(@nospecialize(f), @nospecialize(t), native::Bool, wrappe
6767
t = to_tuple_type(t)
6868
tt = signature_type(f, t)
6969
(ti, env) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), tt, meth.sig)::Core.SimpleVector
70-
meth = Base.func_for_method_checked(meth, ti)
70+
meth = Base.func_for_method_checked(meth, ti, env)
7171
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any, UInt), meth, ti, env, world)
7272
# get the code for it
7373
return _dump_function_linfo(linfo, world, native, wrapper, strip_ir_metadata, dump_module, syntax, optimize, debuginfo, params)

test/compiler/inference.jl

+19-1
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,7 @@ function get_linfo(@nospecialize(f), @nospecialize(t))
10931093
tt = Tuple{ft, t.parameters...}
10941094
precompile(tt)
10951095
(ti, env) = ccall(:jl_type_intersection_with_env, Ref{Core.SimpleVector}, (Any, Any), tt, meth.sig)
1096-
meth = Base.func_for_method_checked(meth, tt)
1096+
meth = Base.func_for_method_checked(meth, tt, env)
10971097
return ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
10981098
(Any, Any, Any, UInt), meth, tt, env, world)
10991099
end
@@ -2224,3 +2224,21 @@ _call_rttf_test() = Core.Compiler.return_type(_rttf_test, Tuple{Any})
22242224
f_with_Type_arg(::Type{T}) where {T} = T
22252225
@test Base.return_types(f_with_Type_arg, (Any,)) == Any[Type]
22262226
@test Base.return_types(f_with_Type_arg, (Type{Vector{T}} where T,)) == Any[Type{Vector{T}} where T]
2227+
2228+
# Generated functions that only reference some of their arguments
2229+
@inline function my_ntuple(f::F, ::Val{N}) where {F,N}
2230+
N::Int
2231+
(N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N)))
2232+
if @generated
2233+
quote
2234+
@Base.nexprs $N i -> t_i = f(i)
2235+
@Base.ncall $N tuple t
2236+
end
2237+
else
2238+
Tuple(f(i) for i = 1:N)
2239+
end
2240+
end
2241+
call_ntuple(a, b) = my_ntuple(i->(a+b; i), Val(4))
2242+
@test Base.return_types(call_ntuple, Tuple{Any,Any}) == [NTuple{4, Int}]
2243+
@test length(code_typed(my_ntuple, Tuple{Any, Val{4}})) == 1
2244+
@test_throws ErrorException code_typed(my_ntuple, Tuple{Any, Val})

0 commit comments

Comments
 (0)