
Adds full gelu without approximation #629

Merged · 8 commits · Feb 28, 2025

Conversation

se-schmitt
Contributor

Adds the full gelu without approximation as gelu(x) and moves the previously used tanh approximation to gelu_fast. See #628 for details.
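For reference, a minimal sketch of the two variants under discussion, using the standard definitions (the PR's actual code and constants may differ):

using SpecialFunctions: erf

# Exact ("full") gelu: x * Φ(x), with Φ the standard-normal CDF.
gelu_exact(x) = x/2 * (1 + erf(x / sqrt(oftype(float(x), 2))))

# Tanh approximation (what NNlib's gelu currently computes):
gelu_tanh_approx(x) = x/2 * (1 + tanh(sqrt(oftype(float(x), 2)/π) * (x + oftype(float(x), 0.044715) * x^3)))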

PR Checklist

  • Tests are added
  • Documentation, if applicable

@ToucheSir linked an issue on Feb 6, 2025 that may be closed by this pull request
@ToucheSir
Member

ToucheSir commented Feb 6, 2025

https://github.com/FluxML/NNlib.jl/actions/runs/13177183802/job/36779142351?pr=629#step:7:842 is a real test failure. I think Flux's Nil + outputsize machinery needs to be adjusted to understand SpecialFunctions.erf. The question is how, so I've opened FluxML/Flux.jl#2588 to track this.
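For context, a hedged sketch of the failure mode: Flux.outputsize evaluates the model on placeholder values of an internal Nil number type, so every scalar function hit in the forward pass needs a Nil method; Base math functions have one, SpecialFunctions.erf does not. The method below only illustrates the kind of fix being discussed (Flux.Nil and Flux.nil are Flux internals, not public API, and this is not the PR's solution):

using Flux, SpecialFunctions

# Illustration only: a pass-through method for Flux's placeholder type,
# so outputsize could trace through an erf-based activation.
SpecialFunctions.erf(::Flux.Nil) = Flux.nil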

@mcabbott
Member

mcabbott commented Feb 7, 2025

Quick look at how different these functions are:

julia> using SpecialFunctions, NNlib

julia> oftf(x, y) = oftype(float(x), y);

julia> new_gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))));

julia> rel(x) = (new_gelu(x) - gelu(x)) / new_gelu(x);

julia> rel.(-3:0.2f0:1)
21-element Vector{Float32}:
   0.101809666
   0.06506126
   0.038386323
   0.02019246
   0.008700018
   0.0021531228
  -0.0010194147
  -0.0021005166
  -0.0020575877
  -0.001547056
  -0.0009626968
  -0.00049560843
  -0.0002001288
  -5.4488235f-5
  -6.02081f-6
 NaN
   4.4374747f-6
   2.876006f-5
   7.55584f-5
   0.00013319723
   0.00018150361
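(The NaN at x = 0 is just 0/0 in the relative error; both implementations return exactly 0 there.)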

julia> rel_eps(x) = (new_gelu(x) - gelu(x)) / eps(new_gelu(x));

julia> Int.(rel_eps.(-3:0.2f0:1))
21-element Vector{Int64}:
 -885402
 -999594
 -499514
 -213282
 -142868
  -26298
    8849
   24719
   31223
   14336
   10250
    5637
    2210
     504
      68
       0
      69
     253
    1104
    1409
    2562

@se-schmitt
Contributor Author

I modified it so that gelu remains the same and added the full gelu as gelu_full, as discussed in #628. This avoids breaking changes and the test failure above; however, gelu_full is still not compatible with Flux's outputsize function.

@se-schmitt
Contributor Author

@ToucheSir Can this be merged? If not, what is required to make it mergeable?

@mcabbott
Member

It's a little sad that NNlib must depend on SpecialFunctions... but maybe it's not so expensive?

julia> @time_imports using SpecialFunctions
      8.7 ms  IrrationalConstants
               ┌ 0.0 ms DocStringExtensions.__init__() 
     46.6 ms  DocStringExtensions 97.36% compilation time
      0.6 ms  LogExpFunctions
               ┌ 2.5 ms OpenLibm_jll.__init__() 
      4.2 ms  OpenLibm_jll
      0.4 ms  JLLWrappers
               ┌ 9.2 ms CompilerSupportLibraries_jll.__init__() 
     11.1 ms  CompilerSupportLibraries_jll
               ┌ 6.0 ms OpenSpecFun_jll.__init__() 93.49% compilation time
      6.5 ms  OpenSpecFun_jll 86.17% compilation time
      3.2 ms  SpecialFunctions

Re name bike-shedding: there's some chance we should use neutral names like gelu_erf and gelu_tanh, with both names available immediately but const gelu = gelu_tanh for now, to be non-breaking. (I do not think either should be called gelu_fast, as the point of tanh_fast is that we sometimes automatically replace tanh with it, but there is no plan to automatically replace one of these with the other.)
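A minimal sketch of that arrangement (illustrative, not the PR's code):

using SpecialFunctions: erf

gelu_tanh(x) = x/2 * (1 + tanh(sqrt(oftype(float(x), 2)/π) * (x + oftype(float(x), 0.044715) * x^3)))
gelu_erf(x) = x/2 * (1 + erf(x / sqrt(oftype(float(x), 2))))

# Non-breaking: the old exported name keeps its current (tanh) behavior.
const gelu = gelu_tanh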

@se-schmitt
Contributor Author

@mcabbott I modified the code as suggested:

  • Instead of the SpecialFunctions.jl package, only OpenLibm_jll.jl is used now, and the erf function is defined via ccall (as in SpecialFunctions.jl); see the sketch after this list.
  • I also renamed the functions to gelu_tanh and gelu_erf, with const gelu = gelu_tanh, and made this explicit in the documentation.
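A hedged sketch of such a ccall-based definition, assuming libopenlibm exports the C functions erf and erff (the PR's actual code may differ):

using OpenLibm_jll: libopenlibm

# Call OpenLibm's C implementations directly, avoiding SpecialFunctions.
_erf(x::Float64) = ccall((:erf, libopenlibm), Float64, (Float64,), x)
_erf(x::Float32) = ccall((:erff, libopenlibm), Float32, (Float32,), x)
_erf(x::Real) = _erf(Float64(x))  # fall back to double precision

gelu_erf(x) = x/2 * (1 + _erf(x / sqrt(oftype(float(x), 2))))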

Member

@mcabbott left a comment

This looks basically fine to me, thanks.

Maybe avoiding SpecialFunctions is a rabbit-hole, sorry.

One question is: How well do these variants work on the GPU? Presumably ccall((:erff, libopenlibm), Float32, (Float32,), x) won't work... does SpecialFunctions have code to make erf.(cu(rand(10))) work by another path?

@ToucheSir
Member

ToucheSir commented Feb 14, 2025

does SpecialFunctions have code to make erf.(cu(rand(10))) work by another path?

Yes, CUDA.jl defines its own overloads at https://github.com/JuliaGPU/CUDA.jl/blob/master/ext/SpecialFunctionsExt.jl.

If we want to talk load times, Flux has a direct dep on SpecialFunctions already. If import latency is a pressing concern, we could define a stub function gelu_erf end in NNlib and the method for that function in a SpecialFunctionsExt.
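A rough sketch of that stub-plus-extension layout (file names and the exact definition are illustrative):

# In NNlib itself: declare the function with no methods, so the name
# exists even when SpecialFunctions is not loaded.
function gelu_erf end

# In ext/NNlibSpecialFunctionsExt.jl, enabled via [weakdeps] and
# [extensions] entries in NNlib's Project.toml:
module NNlibSpecialFunctionsExt

using NNlib, SpecialFunctions

NNlib.gelu_erf(x) = x/2 * (1 + SpecialFunctions.erf(x / sqrt(oftype(float(x), 2))))

end

This way NNlib pays no load-time cost for SpecialFunctions, but the method appears automatically for users who have it loaded.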

@se-schmitt
Contributor Author

The implementation via OpenLibm_jll was naive, sorry... I added the missing AD rules locally (the kind sketched below), which made it compatible with ForwardDiff, Zygote, and Enzyme. However, compatibility with other AD packages and with GPU packages would need further modifications.
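For illustration, the kind of rule meant here, written with ChainRulesCore against the hypothetical _erf from the earlier sketch, using d/dx erf(x) = 2/√π * exp(-x^2). Zygote picks this up via ChainRules; ForwardDiff and Enzyme need their own rules, which is the extra work mentioned above:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(_erf), x::Real)
    y = _erf(x)
    function erf_pullback(ȳ)
        # derivative of erf at x, scaled by the incoming cotangent
        return (NoTangent(), ȳ * 2 / sqrt(oftype(float(x), π)) * exp(-x^2))
    end
    return y, erf_pullback
end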

I would prefer an option that uses SpecialFunctions.jl, as this seems much cleaner to me, either as a direct dependency or as an extension (is an extension a problem for Lux.jl, where SpecialFunctions is not a direct dependency?).

What would you prefer? @mcabbott @ToucheSir

@ToucheSir
Member

Can you try the stub function + extension approach I suggested above? If that turns out to be a dead end, I'm fine with going back to the original plan and having SpecialFunctions as a direct dep. There's already a good chance that any user of NNlib has it in their environment, so we wouldn't lose much by including it.

@se-schmitt
Contributor Author

@ToucheSir The extension approach worked out well. I tested gelu_erf with Flux and Lux and it worked seamlessly.
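For completeness, a minimal usage sketch of the extension path (the layer shown is illustrative):

using NNlib            # provides the gelu_erf stub
using SpecialFunctions # loading this activates the extension

gelu_erf(1.0f0)                 # scalar call now works
# e.g. in a Flux layer: Dense(4 => 2, gelu_erf)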

Member

@ToucheSir left a comment

Looking good, just a couple of touch-ups and let's get this merged.

Member

@ToucheSir left a comment

Thanks, this was great work :)

Successfully merging this pull request may close these issues:

"Full" gelu without approximation