[enzyme] broken MultiHeadAttention gradient #2567

Open
CarloLucibello opened this issue Dec 31, 2024 · 5 comments
@CarloLucibello
Member

CarloLucibello commented Dec 31, 2024

using Flux, Enzyme, Statistics, Random

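function enzyme_withgradient(f, xs...)  # returns (primal value, tuple of gradients)
    args = []
    for x in xs
        if x isa Number
            push!(args, Enzyme.Active(x))  # scalar: gradient comes back in the autodiff result
        else
            push!(args, Enzyme.Duplicated(x, Enzyme.make_zero(x)))  # array/struct: zeroed shadow accumulates the gradient
        end
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)  # ret[1]: gradients of Active args, ret[2]: primal
    g = ntuple(i -> xs[i] isa Number ? ret[1][i] : args[i].dval, length(xs))
    return ret[2], g
end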

loss(model, x) = mean(model(x)[1])
model = MultiHeadAttention(16)
x = randn(Float32, 16, 5, 2)
enzyme_withgradient(loss, model, x)

Output:

ERROR: MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
The function `function_attributes` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  function_attributes(::LLVM.Function)
   @ LLVM ~/.julia/packages/LLVM/wMjUU/src/core/function.jl:127

Stacktrace:
  [1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:402
  [2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{…}, Any}}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:210
  [3] check_ir
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler/validation.jl:179 [inlined]
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3413
  [5] codegen
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:3338 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387
  [7] _thunk
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5387 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5439 [inlined]
  [9] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5550
 [10] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5735
 [11] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:485
 [12] enzyme_withgradient(::Function, ::MultiHeadAttention{Dense{…}, Dropout{…}, Dense{…}}, ::Vararg{Any})
    @ Main ./REPL[14]:11
 [13] top-level scope
    @ ~/.julia/dev/Flux/prova.jl:16
Some type information was truncated. Use `show(err)` to see complete types.

cc @wsmoses

@wsmoses
Contributor

wsmoses commented Dec 31, 2024

This should probably resolve the issue above: EnzymeAD/Enzyme.jl#2239

@CarloLucibello
Member Author

Fixed, but it now emits some warnings:

julia> enzyme_withgradient(loss, model, x)
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
(-0.06945832f0, (MultiHeadAttention(16; nheads=8), Float32[0.0098144375 0.019144177  0.005306968 0.009218991; -0.004492016 -0.007321746  0.00065406406 -0.0012907407;  ; 0.00060874515 -0.00632565  -0.0024870127 -0.0071006473; -0.010946414 -0.0075595975  0.006028706 -0.0074309492;;; 0.0026188223 0.010693712  0.02191737 0.0128253205; -0.002983723 0.0007982431  -0.0038285365 0.0014048074;  ; -0.002875765 -0.004498572  -0.0061464626 -0.0040150927; -0.004544966 -0.002571526  0.001993513 0.005268162]))

@wsmoses
Contributor

wsmoses commented Jan 1, 2025

I think they can likely be ignored (that warning is overly conservative and is printed any time the code spawns a task).
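
As background on what "a spawn" refers to here, a minimal sketch in plain Julia with no Enzyme involved: `Threads.@spawn` creates and schedules a `Task` (which, as far as I understand, is where the runtime's `jl_new_task` comes in), and `fetch` blocks on that task through `wait`. The computation below is purely illustrative and is not taken from Flux or NNlib.

```julia
# Base Julia only; the summed range is an arbitrary illustration.
t = Threads.@spawn sum(abs2, 1:10)   # creates and schedules a Task
result = fetch(t)                    # waits for the task to finish, then returns its value
println(result)                      # 385
```

Any code path that does this internally would be expected to trigger the warning, whether or not the resulting gradient is affected.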

@CarloLucibello
Member Author

@wsmoses this is still failing on Julia 1.10 (it works on 1.11):

julia> enzyme_withgradient(loss, model, x)
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/2CW9L/src/utils.jl:59
ERROR: Enzyme compilation failed due to an internal error.
 Please open an issue with the code to reproduce and full error log on github.com/EnzymeAD/Enzyme.jl
 To toggle more information for debugging (needed for bug reports), set Enzyme.Compiler.VERBOSE_ERRORS[] = true (default false)

Illegal replace ficticious phi for:   %_replacementE = phi {} addrspace(10)* , !dbg !319 of   %107 = call fastcc nonnull {} addrspace(10)* @julia_wait_11104() #438, !dbg !362

Stacktrace:
 [1] #wait#645
   @ ./condition.jl:130

Stacktrace:
  [1] julia_error(msg::String, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/errors.jl:384
  [2] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing}, data2::Ptr{LLVM.API.LLVMOpaqueValue}, B::Ptr{LLVM.API.LLVMOpaqueBuilder})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/errors.jl:210
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/R6sE8/src/api.jl:268
  [4] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:1706
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:4550
  [6] codegen
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:3353 [inlined]
  [7] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410
  [8] _thunk
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5410 [inlined]
  [9] cached_compilation
    @ ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5462 [inlined]
 [10] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5573
 [11] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/R6sE8/src/compiler.jl:5758
 [12] autodiff(::ReverseMode{true, true, FFIABI, false, false}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{MultiHeadAttention{Dense{…}, Dropout{…}, Dense{…}}}, ::Duplicated{Array{Float32, 3}})
    @ Enzyme ~/.julia/packages/Enzyme/R6sE8/src/Enzyme.jl:485
 [13] enzyme_withgradient(::Function, ::MultiHeadAttention{Dense{typeof(identity), Matrix{Float32}, Bool}, Dropout{Float64, Colon, TaskLocalRNG}, Dense{typeof(identity), Matrix{Float32}, Bool}}, ::Vararg{Any})
    @ Main ./REPL[3]:11
 [14] top-level scope
    @ REPL[7]:1
Some type information was truncated. Use `show(err)` to see complete types.
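
The error message above asks for a full log in bug reports. A minimal sketch of collecting it, assuming the flag is set in the same session before the failing call is re-run; the `versioninfo()` call is an addition here, included because the failure depends on the Julia version.

```julia
using Enzyme
using InteractiveUtils  # provides versioninfo() outside the REPL

# Flag named in the error message; it defaults to false.
Enzyme.Compiler.VERBOSE_ERRORS[] = true

# Record the Julia version and platform alongside the log.
versioninfo()

# Then re-run the failing call, e.g. enzyme_withgradient(loss, model, x),
# and copy the complete output into the report.
```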

@wsmoses
Contributor

wsmoses commented Jan 6, 2025

Ah yeah, I just fixed that 1.11 intrinsic.

This one is weirder.

Can you try to reduce this to a smaller MWE and open an issue in Enzyme.jl?
