is Flux.huber_loss type-unstable ? #2459

Open
filchristou opened this issue Jun 17, 2024 · 1 comment

@filchristou

It looks like Flux.huber_loss is type-unstable when differentiated with Zygote?

using Flux, Zygote
import Statistics: mean

function internfunc_nobroad(m, x, y)
    modelvals = m(x)
    Flux.mse(modelvals, y)
end

function internfunc_nobroad_huberloss(m, x, y)
    modelvals = m(x)
    Flux.huber_loss(modelvals, y)
end

function wrapfunc(model, xdata, ydata, func)
    grad = let xdata=xdata, ydata=ydata
        Zygote.gradient(m -> func(m, xdata, ydata), model)
    end
    return grad
end

fc = Flux.Chain(Flux.Dense(5=>3, Flux.relu), Flux.Dense(3=>3, Flux.relu), Flux.Dense(3=>1))

fobs_ar = fill(5f0, 5, 10)
labels_ar = fill(2f0, 1, 10)
julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad)

[image: screenshot of the @code_warntype output for internfunc_nobroad]

julia> @code_warntype wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)

[image: screenshot of the @code_warntype output for internfunc_nobroad_huberloss]
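The @code_warntype results above were posted as images. A text-based check can be easier to share and compare. The following is only a sketch: is_inferred, stable and unstable are hypothetical names, not part of Flux or Zygote. The helper asks the compiler for the inferred return type via Base.return_types and reports whether it is concrete.

```julia
# Hypothetical helper: true when the compiler infers a single concrete
# return type for f called with arguments of these types.
function is_inferred(f, args...)
    rts = Base.return_types(f, Tuple{map(typeof, args)...})
    return length(rts) == 1 && isconcretetype(only(rts))
end

# Toy functions, used only to demonstrate the helper.
stable(x) = x + 1              # Float64 in, Float64 out
unstable(x) = x > 0 ? x : 0    # Float64 or Int, depending on the value

is_inferred(stable, 1.0)       # concrete return type
is_inferred(unstable, 1.0)     # Union{Float64, Int64} is not concrete
```

Under the same assumptions, is_inferred(wrapfunc, fc, fobs_ar, labels_ar, internfunc_nobroad) and the huber_loss variant could be compared directly. The result only reflects the top-level return type, so @code_warntype is still needed to locate where an instability arises.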

@mcabbott
Member

I don't know why this is unstable; the ways of Zygote are mysterious sometimes.

The loss broadcasts this function (_huber_metric), which contains some odd things. abs_error .< δ is strange, since both are scalars. ignore_derivatives is also strange, since Zygote shouldn't reach this code: the broadcast is differentiated with ForwardDiff, as you can confirm with @show. But commenting out that line doesn't fix anything.

julia> @eval Flux.Losses @inline function _huber_metric(abs_error, δ)
           #TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
           temp = false # Zygote.ignore_derivatives(abs_error .<  δ)
           x = ofeltype(abs_error, 0.5)
           @show δ
           ((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
       end
_huber_metric (generic function with 7 methods)

julia> wrapfunc(fc, fobs_ar, labels_ar, internfunc_nobroad_huberloss)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
δ = Dual{Nothing}(1.0,0.0,1.0)
((layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], bias = Float32[0.0, 0.0, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0], bias = Float32[1.0000001], σ = nothing)),),)
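For reference, the expression in _huber_metric selects between a quadratic and a linear branch by multiplying with temp and (1 - temp). A plain-Julia restatement with an ordinary scalar comparison may make that easier to read. This is a sketch under assumptions, not Flux's implementation: huber_scalar and huber_mean are hypothetical names, both arrays are assumed to share an element type, and the check below covers inference of the forward pass only, not the Zygote gradient.

```julia
using Test

# Hypothetical scalar Huber metric: quadratic below δ, linear above it.
function huber_scalar(abs_error, δ)
    half = oftype(abs_error, 0.5)
    return abs_error < δ ? half * abs_error^2 : δ * (abs_error - half * δ)
end

# Hypothetical mean-aggregated loss. δ is converted to the element type
# so that both branches return the same type.
function huber_mean(ŷ, y; δ = 1)
    δT = convert(eltype(ŷ), δ)
    return sum(huber_scalar.(abs.(ŷ .- y), δT)) / length(y)
end

ŷ = Float32[0, 3]
y = Float32[0.5, 0]

# Errors are 0.5 (quadratic branch) and 3 (linear branch).
loss = @inferred huber_mean(ŷ, y)   # throws if the return type is not inferred
```

If this forward pass infers cleanly while the gradient does not, that would point at the differentiation of the broadcast rather than at the arithmetic itself, which is consistent with the observation above but does not explain it.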
