Adds full gelu without approximation #629
Conversation
https://github.com/FluxML/NNlib.jl/actions/runs/13177183802/job/36779142351?pr=629#step:7:842 is a real test failure. I think Flux's …
Quick look at how different these functions are:

```julia
julia> using SpecialFunctions, NNlib
julia> oftf(x, y) = oftype(float(x), y);
julia> new_gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))));
julia> rel(x) = (new_gelu(x) - gelu(x)) / new_gelu(x);
julia> rel.(-3:0.2f0:1)
21-element Vector{Float32}:
0.101809666
0.06506126
0.038386323
0.02019246
0.008700018
0.0021531228
-0.0010194147
-0.0021005166
-0.0020575877
-0.001547056
-0.0009626968
-0.00049560843
-0.0002001288
-5.4488235f-5
-6.02081f-6
NaN
4.4374747f-6
2.876006f-5
7.55584f-5
0.00013319723
0.00018150361
julia> rel_eps(x) = (new_gelu(x) - gelu(x)) / eps(new_gelu(x));
julia> Int.(rel_eps.(-3:0.2f0:1))
21-element Vector{Int64}:
-885402
-999594
-499514
-213282
-142868
-26298
8849
24719
31223
14336
10250
5637
2210
504
68
0
69
253
1104
1409
   2562
```
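For reference, here is a minimal sketch of the two definitions being compared: the exact, erf-based GELU and the usual tanh approximation. The names and the `oftype` handling are illustrative; NNlib's actual `gelu` at the time may differ in implementation details. Note also that the `NaN` at `x = 0` in `rel` is just 0/0, since both functions are exactly zero there.

```julia
using SpecialFunctions  # provides erf

# Exact GELU: x * Φ(x), with Φ the standard normal CDF written via erf.
gelu_exact(x) = x/2 * (1 + erf(x / sqrt(oftype(float(x), 2))))

# The standard tanh approximation (constant 0.044715 per the usual formula):
gelu_tanh(x) = x/2 * (1 + tanh(sqrt(oftype(float(x), 2/π)) *
                               (x + oftype(float(x), 0.044715) * x^3)))
```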
I modified it such that …
@ToucheSir Can this be merged? If not, what is required to make it mergeable?
It's a little sad that NNlib must depend on SpecialFunctions... maybe it's not so expensive?
Re name bike-shedding, there's some chance we should use neutral names like …
@mcabbott I modified the code as suggested: …
This looks basically fine to me, thanks.
Maybe avoiding SpecialFunctions is a rabbit-hole, sorry.
One question is: how well do these variants work on the GPU? Presumably `ccall((:erff, libopenlibm), Float32, (Float32,), x)` won't work... does SpecialFunctions have code to make `erf.(cu(rand(10)))` work by another path?
Yes, CUDA.jl defines its own overloads at https://github.com/JuliaGPU/CUDA.jl/blob/master/ext/SpecialFunctionsExt.jl. If we want to talk load times, Flux has a direct dep on SpecialFunctions already. If import latency is a pressing concern, we could define a stub …
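A minimal sketch of the GPU path being described (assumes a CUDA-capable machine; `cu` and `erf` come from CUDA.jl and SpecialFunctions respectively):

```julia
using CUDA, SpecialFunctions

x = cu(rand(Float32, 10))
# CUDA.jl's SpecialFunctionsExt provides device overloads for erf, so
# plain broadcasting compiles for the GPU rather than hitting the CPU ccall:
y = erf.(x)
```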
The implementation via OpenLibm_jll was naive, sorry... I added the missing rules for AD locally, which made it compatible with ForwardDiff, Zygote and Enzyme. However, compatibility with other AD and GPU packages would need further modifications. I would prefer an option that includes SpecialFunctions.jl, as this seems much cleaner to me (either as a direct dependency or as an extension; is this a problem for Lux.jl, where SpecialFunctions is not a direct dependency?). What would you prefer? @mcabbott @ToucheSir
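For illustration, the kind of AD rule being referred to might look like the following ChainRulesCore `rrule` for an erf-based GELU. This is a hypothetical sketch, not the rules actually added in this PR; it uses the identity d/dx (x·Φ(x)) = Φ(x) + x·φ(x), with Φ the normal CDF and φ the normal pdf.

```julia
using ChainRulesCore, SpecialFunctions

# Hypothetical erf-based GELU (name is illustrative):
gelu_erf(x) = x/2 * (1 + erf(x / sqrt(oftype(float(x), 2))))

function ChainRulesCore.rrule(::typeof(gelu_erf), x::Number)
    Φ = (1 + erf(x / sqrt(oftype(float(x), 2)))) / 2  # normal CDF
    φ = exp(-x^2 / 2) / sqrt(oftype(float(x), 2π))    # normal pdf
    gelu_erf_pullback(ȳ) = (NoTangent(), ȳ * (Φ + x * φ))
    return x * Φ, gelu_erf_pullback
end
```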
Can you try the stub function + extension approach I suggested above? If that turns out to be a dead end, I'm fine with going back to the original plan and having SpecialFunctions as a direct dep. There's already a good chance any user of NNlib will have it in their environment, so we wouldn't lose much by including it.
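For readers unfamiliar with the pattern, the stub + package-extension approach might look roughly like this (module, file, and function names are illustrative, not NNlib's actual ones):

```julia
# In the main package: declare a function with no methods, so merely
# defining it costs nothing at load time.
function gelu_erf end

# Project.toml gains (UUID shown for illustration):
# [weakdeps]
# SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
# [extensions]
# NNlibSpecialFunctionsExt = "SpecialFunctions"

# ext/NNlibSpecialFunctionsExt.jl, loaded automatically once the user
# also loads SpecialFunctions:
module NNlibSpecialFunctionsExt
using NNlib, SpecialFunctions
NNlib.gelu_erf(x) = x/2 * (1 + erf(x / sqrt(oftype(float(x), 2))))
end
```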
@ToucheSir The extension approach worked out well. I tested …
Looking good, just a couple of touch-ups and let's get this merged.
Thanks, this was great work :)
Adds the full gelu without approximation as `gelu(x)` and moves the tanh approximation used before to `gelu_fast`. See #628 for details.

PR Checklist