Skip to content

Commit 41fcdf2

Browse files
authored
Fixes for Julia nightly (#300)
* Adapt to JuliaLang/julia#56509 * Adapt to JuliaLang/julia#54734 * Use StmtRange explicitly * Adapt to JuliaLang/julia#57230 * Reuse Cthulhu code structure for Compiler cache/finish overrides * Adapt to JuliaLang/julia#57475 * Adapt to JuliaLang/julia#55976 * Adapt to JuliaLang/julia#54734 * Use CC instead of .Compiler * Implement ir.argtypes[1] fix from JuliaLang/julia#54458 * Comment out failing tests To highlight which are broken, should probably be fixed before merging * Treat `getproperty(::Module, ::Symbol)` like GlobalRefs * Uncomment passing tests, explicitly mark others as broken * Evaluate GlobalRef only if binding is defined * Use `rrule` for getproperty(::Module, ::Symbol) * Bump compat bound for StructArrays * Raise compat bound for Cthulhu * Revert `isconst` change now that it is fixed * Adapt to `finishinfer!` signature change --------- Co-authored-by: Cédric Belmant <cedric.belmant@juliahub.com>
1 parent 21747f8 commit 41fcdf2

14 files changed

+108
-52
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ ChainRules = "1.44.6"
2222
ChainRulesCore = "1.20"
2323
Combinatorics = "1"
2424
Compiler = "~0"
25-
Cthulhu = "2.10.1"
25+
Cthulhu = "2.16.3"
2626
OffsetArrays = "1"
2727
PrecompileTools = "1"
2828
StaticArrays = "1"
29-
StructArrays = "0.6"
29+
StructArrays = "0.6, 0.7"
3030
julia = "1.10"
3131

3232
[extras]

src/analysis/forward.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
3434
# discover what they are. frules should be written in such a way that
3535
# whether or not they return `nothing`, only depends on the non-tangent arguments
3636
frule_arginfo = ArgInfo(nothing, frule_argtypes)
37-
frule_si = StmtInfo(true)
37+
frule_si = StmtInfo(true, false)
3838
# turn off frule analysis in the frule to avoid cycling
3939
interp′ = disable_forward(interp)
4040
frule_call = CC.abstract_call_gf_by_type(interp′,

src/codegen/forward_demand.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -352,11 +352,11 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
352352
end
353353
end
354354

355-
method_info = CC.MethodInfo(src)
355+
info = @static VERSION v"1.12.0-DEV.1293" ? CC.SpecInfo(src) : CC.MethodInfo(src)
356356
argtypes = ir.argtypes[1:mi.def.nargs]
357357
world = get_inference_world(interp)
358-
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
359-
rt = CC._ir_abstract_constant_propagation(interp, irsv)
358+
irsv = IRInterpretationState(interp, info, ir, mi, argtypes, world, src.min_world, src.max_world)
359+
rt = CC.ir_abstract_constant_propagation(interp, irsv)
360360

361361
ir = compact!(ir)
362362

src/codegen/reverse.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ function make_opaque_closure(interp, typ, name, meth_nargs::Int, isva, lno, ci,
1414
ocm = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
1515
typ, Union{}, rettype, @__MODULE__, ci, lno.line, lno.file, meth_nargs, isva, ()).source
1616
end
17-
return Expr(:new_opaque_closure, typ, Union{}, Any, ocm, revs...)
1817
else
1918
oc_nargs = Int64(meth_nargs)
20-
Expr(:new_opaque_closure, typ, Union{}, Any,
21-
Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci), revs...)
19+
ocm = Expr(:opaque_closure_method, name, oc_nargs, isva, lno, ci)
2220
end
21+
oc = Expr(:new_opaque_closure, typ, Union{}, Any, true, ocm, revs...)
22+
@static VERSION < v"1.12.0-DEV.691" ? deleteat!(oc.args, 4) : nothing
23+
oc
2324
end
2425

2526
function diffract_ir!(ir, ci, meth, sparams::Core.SimpleVector, nargs::Int, N::Int, interp=nothing, curs=nothing)

src/extra_rules.jl

+6
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
268268
val, Δ->(NoTangent(), NoTangent(), Δ)
269269
end
270270

271+
# XXX: We should instead skip differentiation in the IR.
272+
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getproperty), mod::Module, name::Symbol)
273+
val = getproperty(mod, name)
274+
val, Δ->(NoTangent(), NoTangent(), NoTangent())
275+
end
276+
271277
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
272278

273279
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495

src/stage1/compiler_utils.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Utilities that should probably go into CC
2-
using .Compiler: IRCode, CFG, BasicBlock, BBIdxIter
2+
using .CC: IRCode, CFG, BasicBlock, BBIdxIter
33

44
function Base.push!(cfg::CFG, bb::BasicBlock)
55
@assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start
@@ -30,10 +30,6 @@ if VERSION < v"1.12.0-DEV.1268"
3030

3131
Base.copy(ir::IRCode) = CC.copy(ir)
3232

33-
CC.BasicBlock(x::UnitRange) =
34-
BasicBlock(StmtRange(first(x), last(x)))
35-
CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) =
36-
BasicBlock(StmtRange(first(x), last(x)), preds, succs)
3733
Base.length(c::CC.NewNodeStream) = CC.length(c)
3834
Base.setindex!(i::Instruction, args...) = CC.setindex!(i, args...)
3935
Base.size(x::CC.UnitRange) = CC.size(x)

src/stage1/generated.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ struct ∂⃖recurse{N}; end
66

77
include("recurse.jl")
88

9-
function generate_lambda_ex(world::UInt, source::LineNumberNode,
9+
# source is a Method starting from https://github.com/JuliaLang/julia/pull/57230
10+
function generate_lambda_ex(world::UInt, source::Union{Method,LineNumberNode},
1011
args::Core.SimpleVector, sparams::Core.SimpleVector, body::Expr)
1112
stub = Core.GeneratedFunctionStub(identity, args, sparams)
1213
return stub(world, source, body)
@@ -16,7 +17,7 @@ struct NonTransformableError
1617
args
1718
end
1819

19-
function perform_optic_transform(world::UInt, source::LineNumberNode,
20+
function perform_optic_transform(world::UInt, source::Union{Method,LineNumberNode},
2021
@nospecialize(ff::Type{∂⃖recurse{N}}), @nospecialize(args)) where {N}
2122
@assert N >= 1
2223

src/stage1/recurse.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,10 @@ function split_critical_edges!(ir)
183183
bb = ir.stmts[i][:inst].args[1]
184184
ir.stmts[i][:inst] = nothing
185185
bbnew = bb + ninserted
186-
insert!(cfg.blocks, bbnew, BasicBlock(i:i))
186+
insert!(cfg.blocks, bbnew, BasicBlock(StmtRange(i:i)))
187187
bb_rename_offset[bb] += 1
188188
bblock = cfg.blocks[bbnew+1]
189-
cfg.blocks[bbnew+1] = BasicBlock((i+1):last(bblock.stmts),
189+
cfg.blocks[bbnew+1] = BasicBlock(StmtRange((i+1):last(bblock.stmts)),
190190
bblock.preds, bblock.succs)
191191
i += 1
192192
while i <= last(bblock.stmts)

src/stage1/recurse_fwd.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
222222
return ci
223223
end
224224

225-
function perform_fwd_transform(world::UInt, source::LineNumberNode,
225+
function perform_fwd_transform(world::UInt, source::Union{Method,LineNumberNode},
226226
@nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E}
227227
if all(x->x <: ZeroBundle, args)
228228
return generate_lambda_ex(world, source,

src/stage2/forward.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ end
2121
# unlikely to be the actual interface. For now, it is used for testing.
2222
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = false)
2323
interp = ADInterpreter(; forward=true, backward=false)
24-
match = Base._which(tt)
25-
frame = CC.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
26-
mi = frame.linfo
24+
mi = @ccall jl_method_lookup_by_tt(tt::Any, Base.tls_world_age()::Csize_t, #= method table =# nothing::Any)::Ref{MethodInstance}
25+
ci = CC.typeinf_ext_toplevel(interp, mi, CC.SOURCE_MODE_ABI)
2726

2827
src = CC.copy(interp.unopt[0][mi].src)
29-
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)
28+
ir = CC.copy((@atomic :monotonic ci.inferred).ir::IRCode)
3029

3130
# Find all Return Nodes
3231
vals = Pair{SSAValue, Int}[]
@@ -83,6 +82,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1; eras_mode = fa
8382
end
8483

8584
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!, eras_mode)
85+
ir.argtypes[1] = Tuple{}
8686

8787
return OpaqueClosure(ir)
8888
end

src/stage2/interpreter.jl

+68-23
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,76 @@ end
273273
# TODO: `get_remarks` should get a cursor?
274274
Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing)
275275

276-
function CC.finish(sv::InferenceState, interp::ADInterpreter)
277-
res = @invoke CC.finish(sv::InferenceState, interp::AbstractInterpreter)
278-
key = (@static VERSION v"1.12.0-DEV.317" ? CC.is_constproped(sv) : CC.any(sv.result.overridden_by_const)) ? sv.result : sv.linfo
279-
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(sv)
276+
@static if VERSION v"1.13.0-DEV.126"
277+
function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter, cycleid::Int)
278+
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int)
279+
key = CC.is_constproped(state) ? state.result : state.linfo
280+
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state)
281+
return res
282+
end
283+
else
284+
function diffractor_finish(@specialize(finishfunc), state::InferenceState, interp::ADInterpreter)
285+
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter)
286+
key = (@static VERSION v"1.12.0-DEV.317" ? CC.is_constproped(state) : CC.any(state.result.overridden_by_const)) ? state.result : state.linfo
287+
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(state)
280288
return res
281289
end
290+
end
291+
292+
@static if VERSION v"1.12.0-DEV.1823"
293+
@static if VERSION v"1.13.0-DEV.126" || VERSION v"1.12.0-alpha1"
294+
CC.finishinfer!(state::InferenceState, interp::ADInterpreter, cycleid::Int) = diffractor_finish(CC.finishinfer!, state, interp, cycleid)
295+
else
296+
CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp)
297+
end
298+
@static if VERSION v"1.12.0-DEV.1988"
299+
function CC.finish!(interp::ADInterpreter, caller::InferenceState, validation_world::UInt)
300+
Cthulhu.set_cthulhu_source!(caller.result)
301+
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
302+
end
303+
else
304+
function CC.finish!(interp::ADInterpreter, caller::InferenceState)
305+
Cthulhu.set_cthulhu_source!(caller.result)
306+
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState)
307+
end
308+
end
309+
310+
elseif VERSION v"1.12.0-DEV.734"
311+
CC.finishinfer!(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finishinfer!, state, interp)
312+
function CC.finish!(interp::ADInterpreter, caller::InferenceState;
313+
can_discard_trees::Bool=false)
314+
Cthulhu.set_cthulhu_source!(caller.result)
315+
return @invoke CC.finish!(interp::AbstractInterpreter, caller::InferenceState;
316+
can_discard_trees)
317+
end
318+
319+
elseif VERSION v"1.11.0-DEV.737"
320+
CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp)
321+
function CC.finish!(interp::ADInterpreter, caller::InferenceState)
322+
result = caller.result
323+
opt = result.src
324+
Cthulhu.set_cthulhu_source!(result)
325+
if opt isa CC.OptimizationState
326+
CC.ir_to_codeinf!(opt)
327+
end
328+
return nothing
329+
end
330+
function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange,
331+
result::InferenceResult)
332+
return result.src
333+
end
334+
335+
else # VERSION < v"1.11.0-DEV.737"
336+
CC.finish(state::InferenceState, interp::ADInterpreter) = diffractor_finish(CC.finish, state, interp)
337+
function CC.transform_result_for_cache(::ADInterpreter, ::MethodInstance, ::WorldRange,
338+
result::InferenceResult)
339+
return create_cthulhu_source(result.src, result.ipo_effects)
340+
end
341+
function CC.finish!(::ADInterpreter, caller::InferenceResult)
342+
Cthulhu.set_cthulhu_source(interp, caller)
343+
end
344+
345+
end # @static if
282346

283347
const StmtFlag = @static VERSION v"1.11.0-DEV.377" ? UInt32 : UInt8
284348
function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo),
@@ -303,10 +367,6 @@ function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.C
303367
end
304368

305369
@static if VERSION v"1.12.0-DEV.45"
306-
function CC.transform_result_for_cache(interp::ADInterpreter,
307-
::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool)
308-
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
309-
end
310370
function CC.src_inlining_policy(interp::ADInterpreter,
311371
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag)
312372
ret = diffractor_inlining_policy(src, info, stmt_flag)
@@ -316,10 +376,6 @@ function CC.src_inlining_policy(interp::ADInterpreter,
316376
src::Any, info::CC.CallInfo, stmt_flag::StmtFlag)
317377
end
318378
else
319-
function CC.transform_result_for_cache(interp::ADInterpreter,
320-
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
321-
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
322-
end
323379
function CC.inlining_policy(interp::ADInterpreter,
324380
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag,
325381
mi::MethodInstance, argtypes::Vector{Any})
@@ -351,17 +407,6 @@ function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
351407
end
352408
=#
353409

354-
function _finish!(caller::InferenceResult)
355-
effects = caller.ipo_effects
356-
caller.src = Cthulhu.create_cthulhu_source(caller.src, effects)
357-
end
358-
359-
@static if VERSION v"1.11.0-DEV.737"
360-
CC.finish!(::ADInterpreter, caller::InferenceState) = _finish!(caller.result)
361-
else
362-
CC.finish!(::ADInterpreter, caller::InferenceResult) = _finish!(caller)
363-
end
364-
365410
@static if VERSION v"1.11.0-DEV.1278"
366411
function CC.bail_out_const_call(interp::ADInterpreter, result::CC.MethodCallResult,
367412
si::StmtInfo, sv::CC.AbsIntState)

test/forward_diff_no_inf.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,15 @@ module forward_diff_no_inf
3131
ir[SSAValue(i)][:flag] |= CC.IR_FLAG_REFINED
3232
end
3333

34-
method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing)
34+
info = @static if VERSION v"1.12.0-DEV.1293"
35+
CC.SpecInfo(#=nargs=#length(ir.argtypes), #=isva=#false, #=propagate_inbounds=#true, nothing)
36+
else
37+
CC.MethodInfo(#=propagate_inbounds=#true, nothing)
38+
end
3539
min_world = world = (interp).world
3640
max_world = Diffractor.get_world_counter()
37-
irsv = CC.IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world)
38-
(rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv)
41+
irsv = CC.IRInterpretationState(interp, info, ir, mi, ir.argtypes, world, min_world, max_world)
42+
(rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv)
3943
return rt
4044
end
4145

@@ -79,6 +83,7 @@ module forward_diff_no_inf
7983
ir = first(only(Base.code_ircode(foo_148, Tuple{Float64})))
8084
Diffractor.forward_diff_no_inf!(ir, [SSAValue(1) => 1]; transform! = identity_transform!)
8185
ir2 = CC.compact!(ir)
86+
ir2.argtypes[1] = Tuple{}
8287
f = Core.OpaqueClosure(ir2; do_compile=false)
8388
@test f(1.0) == Bar148(1.0) # This would error if we were not handling constructors (%new) right
8489
end
@@ -96,6 +101,7 @@ module forward_diff_no_inf
96101
stmt = ir2.stmts[stmt_idx]
97102
@test stmt[:inst].name == :_coeff
98103
@test stmt[:type] == Float64
104+
ir2.argtypes[1] = Tuple{}
99105
f = Core.OpaqueClosure(ir2; do_compile=false)
100106
@test f(3.5) == 28.0
101107
end
@@ -124,6 +130,7 @@ module forward_diff_no_inf
124130
Diffractor.forward_diff_no_inf!(ir, diff_ssa .=> 1; transform! = identity_transform!)
125131
ir2 = CC.compact!(ir)
126132
CC.verify_ir(ir2) # This would error if we were not handling nonconst phi nodes correctly (after https://github.com/JuliaLang/julia/pull/50158)
133+
ir2.argtypes[1] = Tuple{}
127134
f = Core.OpaqueClosure(ir2; do_compile=false)
128135
@test f(3.5) == 3.5 # this will segfault if we are not handling phi nodes correctly
129136
end
@@ -154,4 +161,3 @@ module forward_diff_no_inf
154161
end
155162
end
156163
end # module
157-

test/gradcheck.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ end
9595

9696
@testset "sum, prod" begin
9797
@test gradcheck(x -> sum(abs2, x), randn(4, 3, 2))
98-
@test gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10))
98+
# Fails in `diffract_ir!` on $(Expr(:isdefined, :($(Expr(:static_parameter, 1)))))
99+
@test_broken gradcheck(x -> sum(x[i] for i in 1:length(x)), randn(10))
99100
@test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231
100101
@test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10))
101102
@test gradcheck(X -> sum(x -> x^2, X), randn(10))

test/reverse.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ let var"'" = Diffractor.PrimeDerivativeBack
7070
# Integration tests
7171
@test @inferred(sin'(1.0)) == cos(1.0)
7272
@test @inferred(sin''(1.0)) == -sin(1.0)
73-
@test @inferred(sin'''(1.0)) == -cos(1.0)
7473
# FIXME: These error with:
7574
# Control flow support not fully implemented yet for higher-order reverse mode (TODO)
75+
@test_broken @inferred(sin'''(1.0)) == -cos(1.0)
7676
@test_broken @inferred(sin''''(1.0)) == sin(1.0)
7777
@test_broken @inferred(sin'''''(1.0)) == cos(1.0)
7878
@test_broken @inferred(sin''''''(1.0)) == -sin(1.0)

0 commit comments

Comments
 (0)