
Layer Normalization #4

Open
Aktsvigun opened this issue Jul 8, 2021 · 3 comments

Comments

@Aktsvigun

Hi,
thanks for a great implementation!

I wanted to clarify one thing that differs from the code proposed in the paper itself. In your code, you pre-normalize the inputs, so they are passed through LayerNorm before the Fourier transform. The code presented in the paper is:

class FNetEncoderBlock(nn.Module):
    fourier_layer: FourierTransformLayer
    ff_layer: FeedForwardLayer

    @nn.compact
    def __call__(self, x, deterministic):
        mixing_output = self.fourier_layer(x)
        x = nn.LayerNorm(1e-12, name="mixing_layer_norm")(x + mixing_output)
        feed_forward_output = self.ff_layer(x, deterministic)
        return nn.LayerNorm(
            1e-12, name="output_layer_norm")(x + feed_forward_output)

which, in my view, applies normalization in the opposite order (post-norm rather than pre-norm).
Am I mistaken, or is this indeed a bug?
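To make the difference concrete, here is a minimal numpy sketch of the two orderings being compared. This is not the repo's code or the paper's Flax code; the function names are hypothetical, and `np.fft.fft2(...).real` stands in for the FNet Fourier mixing step.

```python
import numpy as np

def layer_norm(x, eps=1e-12):
    # Normalize over the last (hidden) dimension, as LayerNorm does.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def fourier_mix(x):
    # FNet-style token mixing: 2D FFT, keeping only the real part.
    return np.fft.fft2(x).real

def block_post_norm(x, ff):
    # Paper's listing A.5 ordering: mix first, then LayerNorm(residual).
    x = layer_norm(x + fourier_mix(x))
    return layer_norm(x + ff(x))

def block_pre_norm(x, ff):
    # Pre-norm variant: LayerNorm first, then mix, then add the residual.
    x = x + fourier_mix(layer_norm(x))
    return x + ff(layer_norm(x))
```

The two orderings are not equivalent: post-norm leaves the block output normalized, while pre-norm leaves an un-normalized residual stream, so swapping them silently changes the network.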

@Aktsvigun
Author

I see the code got damaged when copied from the PDF. Here is the image (listing A.5 in the paper):
[Screenshot 2021-07-08 at 12:32:13]

@Aktsvigun
Author

A similar question concerns dropout in the feed-forward layer. You apply it twice, while in the paper it is applied only once, at the end:
[Screenshot 2021-07-08 at 12:36:30]
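For illustration, here is a minimal numpy sketch of the two dropout placements under discussion. This is not the repo's code; the names are hypothetical, ReLU stands in for the paper's GELU, and the dropout helper is a simple inverted-dropout sketch.

```python
import numpy as np

def dropout(x, rate, rng):
    # Inverted dropout: zero out units with probability `rate`, rescale the rest.
    mask = rng.random(x.shape) >= rate
    return x * mask / (1.0 - rate)

def feed_forward_single_dropout(x, w1, w2, rate, rng):
    # Paper's version: dense -> activation -> dense -> one dropout at the end.
    h = np.maximum(x @ w1, 0.0)  # ReLU as a stand-in for GELU
    h = h @ w2
    return dropout(h, rate, rng)

def feed_forward_double_dropout(x, w1, w2, rate, rng):
    # Variant questioned above: an extra dropout after the activation as well.
    h = dropout(np.maximum(x @ w1, 0.0), rate, rng)
    h = h @ w2
    return dropout(h, rate, rng)
```

With `rate=0.0` both variants reduce to the same plain feed-forward computation; with a nonzero rate the double-dropout variant zeroes more activations per step, so the two are not interchangeable.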

@erksch

erksch commented Jul 25, 2021

@Aktsvigun you can check out our repo https://github.com/erksch/fnet-pytorch. We reimplemented the architecture precisely enough that we can even use the official checkpoints (converted from Jax to PyTorch).
