diff --git a/base/boot.jl b/base/boot.jl index a9f33562ee481..99cda41c24800 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -521,7 +521,8 @@ end # invoke and wrap the results of @generated function (g::GeneratedFunctionStub)(@nospecialize args...) - body = g.gen(args...) + #body = g.gen(args...) + body = Core._apply_pure(g.gen, (args...,)) if body isa CodeInfo return body end diff --git a/src/julia.h b/src/julia.h index 3c10c398ab1d2..cfc252f2c4ac9 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1202,6 +1202,7 @@ JL_DLLEXPORT jl_svec_t *jl_alloc_svec(size_t n); JL_DLLEXPORT jl_svec_t *jl_alloc_svec_uninit(size_t n); JL_DLLEXPORT jl_svec_t *jl_svec_copy(jl_svec_t *a); JL_DLLEXPORT jl_svec_t *jl_svec_fill(size_t n, jl_value_t *x); +// Construct the DataType for `Tuple{v, v, v...}` JL_DLLEXPORT jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *v); JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT; diff --git a/src/method.c b/src/method.c index c9f00617f2cd1..820a702af92bf 100644 --- a/src/method.c +++ b/src/method.c @@ -376,6 +376,67 @@ STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator return code; } +// Sets func.edges = MethodInstance[MethodInstance for generator(::Any, ::Any...)] +// This is so that @generated functions will be re-generated if any functions called from +// the generator body are invalidated. +static void set_codeinfo_forward_edge_to_generator(jl_method_t *def, + jl_code_info_t *func, + jl_value_t *generator, + jl_svec_t *sparam_vals, + jl_tupletype_t *ttdt) { + // Get generator function body (not GeneratedFunctionStub) to attach an edge to it + jl_value_t* gen_func = jl_get_field(generator, "gen"); + + // Get jl_method_t for generator body + // The generator body takes three sets of arguments: (typeof(func), params..., T...) + // For example, for `foo(x::Array{T}) where T`: + // Tuple{getfield(Main, Symbol("##s4#3")), typeof(Main.foo), Int64, Array{Int64}} + // BUT --- THIS IS THE WEIRD PART: + // Apparently the backedge needs to be from `##s4#3(::Any, ::Any)` instead of + // the actual types of the aruments! Who knows why!! Weirdness. + // (EDIT: I _think_ this is because the generatorbody is marked nospecialize?) + // Manually construct that weird signature, here: + size_t n_sparams = jl_svec_len(sparam_vals); + // Get def->nargs, not number of args in ttdt, to correctly count varargs. + size_t numargs = n_sparams + def->nargs; + // Construct Tuple{typeof(gen_func), ::Any, ::Any} for correct number of args. + jl_value_t* weird_types_tuple = jl_tupletype_fill(numargs, (jl_value_t*)jl_any_type); + jl_tupletype_t* typesig = (jl_tupletype_t*)jl_argtype_with_function(gen_func, weird_types_tuple); + + // TODO: I still don't know what the right way to specialize the method is. + // I've tried `jl_gf_invoke_lookup` (but that segfaults during bootstrap), + // `jl_get_specialization1` and `jl_specializations_get_linfo.` The last two seem + // to behave identically, so I don't know which is better. Certainly + // jl_get_specialization1 is simpler. + // UPDATE: Actually jl_specializations_get_linfo seems to also cause an error during + // boostrap, complaining about `UndefRefError()`. So maybe jl_get_specialization1. + + // next look for or create a specialization of this definition. + size_t min_valid = 0; + size_t max_valid = ~(size_t)0; + jl_method_instance_t *edge = jl_get_specialization1((jl_tupletype_t*)typesig, -1, + &min_valid, &max_valid, + 1 /* store new specialization */); + + if (edge != NULL && (jl_value_t*)edge != jl_nothing) { + // Now create the edges array and set the edge! + if (func->edges == jl_nothing) { + // TODO: How to construct this array type properly + jl_value_t* array_mi_type = jl_apply_type2((jl_value_t*)jl_array_type, + (jl_value_t*)jl_method_instance_type, jl_box_long(1)); + + func->edges = (jl_value_t*)jl_alloc_array_1d(array_mi_type, 0); + } + + //jl_method_instance_add_backedge(edge, linfo); + jl_array_ptr_1d_push((jl_array_t*)func->edges, (jl_value_t*)edge); + } + else { + jl_printf(JL_STDERR, "WARNING: no edge for generated function body "); + jl_static_show(JL_STDERR, (jl_value_t*)def); jl_printf(JL_STDERR, "\n"); + } +} + // return a newly allocated CodeInfo for the function signature // effectively described by the tuple (specTypes, env, Method) inside linfo JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) @@ -420,6 +481,11 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo) jl_resolve_globals_in_ir(stmts, def->module, linfo->sparam_vals, 1); } + // Set forward edge from the staged func `func` to the generator, so that the + // generator will be rerun if any dependent functions called from the generator body + // are invalidated. This keeps generated functions up-to-date, like other functions. + set_codeinfo_forward_edge_to_generator(def, func, generator, linfo->sparam_vals, ttdt); + ptls->in_pure_callback = last_in; jl_lineno = last_lineno; ptls->world_age = last_age; diff --git a/test/method.jl b/test/method.jl new file mode 100644 index 0000000000000..1a4010c761669 --- /dev/null +++ b/test/method.jl @@ -0,0 +1,83 @@ +using Test + +@testset "Normal generated functions" begin + @generated g1(x) = :(x) + @test g1(1) == 1 + @test g1("hi") == "hi" +end + +# Invalidating Generated Functions +# TODO: For some reason this doesn't work inside a testset right now (See below) +@generated foo() = bar() +bar() = 2 +# We can still call foo() even though bar() was defined after! Hooray! :) +@test foo() == 2 +# Now we can change bar(), and foo() is also updated! Woohoo! :D +bar() = 3 +@test foo() == 3 + + +@testset "invalidating generated functions in a testset" begin + @generated foo() = bar() + bar() = 2 + + # TODO: It seems like this doesn't work because of the @testset. Is that expected? + # Would this work for regular functions? I think it's broken... + @test foo() == 2 + bar() = 3 + @test_broken foo() == 3 +end + + +# Functions that take arguments +@generated f(x) = f2(x) + f2(x) +f2(x::Type) = sizeof(x) +@test f(1) == 16 +f2(x::Type) = sizeof(x)รท2 +@test f(1) == 8 + + +# Method at bottom of call-stack accepts ::Type, not ::Any +# The simple case -- bar(::Any): +@generated foo(x) = bar(x) +bar(x) = 2 +@test foo(1) == 2 +bar(x) = 3 +@test foo(1) == 3 +# This also works, with t(::Type{Int}) +@generated f_type(x) = t(x) +t(::Type{Int}) = 2 +@test f_type(1) == 2 +t(::Type{Int}) = 3 +@test f_type(1) == 3 +# Yet for some reason this does not work: +# Somehow having t1(T) call typemax prevents forming a backedge from t1 to the generator. +@generated f_type2(x) = t1(x) +t1(T) = typemax(T) +@test f_type2(1) == typemax(Int) +t1(T) = 3 +@test_broken f_type2(1) == 3 + + +# Functions with type params +@generated f(x::T) where T<:Number = width(T) +width(::Type{Int}) where T = 5 +@test f(10) == 5 +width(::Type{Int}) where T = 100 +@test f(10) == 100 + +# It also works for newly defined types +struct MyNum <: Number x::Int end +width(::Type{MyNum}) where T = MyNum(5) +@test f(MyNum(10)) == MyNum(5) +width(::Type{MyNum}) where T = MyNum(100) +@test f(MyNum(10)) == MyNum(100) + + +# Functions with varargs +@generated f(a::T, b, c...) where T<:Number = bar(T) + bar(a) + bar(b) + sum(bar(v) for v in c) +bar(x) = 2 +@test f(2,3,4) == 8 +bar(x) = 3 +@test f(2,3,4) == 12 +@test f(2,3,4,5) == 15