muP for Mamba and Mamba-2 #50
Conversation
Here are the so-called "coord checks", showing that the network activations behave the same no matter the width. These plots show the scale of the activations for various widths (d_model), from t=1 (initialization) to t=5 (5 steps of training).

Mamba-2 (standard): [coord check plot]
Mamba-2 (muP): [coord check plot]

We can see that muP achieves its goal of making the activations behave roughly the same regardless of the width of the model. (n_layers=8, batch_size=64, batch_len=256, vocab_size=256, and for Mamba-2, d_head=16.)

These plots were made with a LR of 1e-3, which is quite low considering the optimal LR of 2**(-7) (see next comment). We can see some oscillations starting to appear; however, they don't seem to prevent muTransfer from happening (see next comment). Such oscillations at a high LR were also observed on a Transformer (see this for example).
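For reference, here is a minimal sketch of how a coord check of this kind can be produced. The `make_model` factory and the `model.layers` attribute are assumptions for illustration, not the actual scripts of this PR:

```python
import torch
import torch.nn.functional as F

# Sketch of a coord check: for several widths, train for a few steps and
# record the mean absolute value of each block's output.
# Assumes make_model(d_model=...) builds an LM and that it exposes .layers.
def coord_check(make_model, widths, steps=5, batch_size=64, batch_len=256, vocab_size=256):
    records = []  # tuples (width, step, layer_idx, mean |activation|)
    for width in widths:
        model = make_model(d_model=width)
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        acts = {}
        def make_hook(idx):
            def hook(module, inputs, output):
                acts[idx] = output.detach().abs().mean().item()
            return hook
        for i, block in enumerate(model.layers):
            block.register_forward_hook(make_hook(i))

        for step in range(1, steps + 1):
            tokens = torch.randint(0, vocab_size, (batch_size, batch_len))
            logits = model(tokens)  # (batch_size, batch_len, vocab_size)
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, vocab_size),
                                   tokens[:, 1:].reshape(-1))
            for i, scale in acts.items():
                records.append((width, step, i, scale))
            optim.zero_grad()
            loss.backward()
            optim.step()
    return records
```

Plotting `scale` against `width` for each step then gives one curve per layer, which should stay flat across widths under muP.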
And here are the LR sweeps that show that muTransfer indeed works:

[LR sweep plots: loss vs. LR for various widths, SP and muP]

In both cases, we can clearly see that the optimal LR for the SP shifts (it becomes smaller and smaller as the width grows, as observed in practice), whereas the optimal LR for the muP case stays roughly constant, and the shape of the loss/LR curve looks the same no matter the width. In terms of number of parameters, width=64 is a 172k model while width=2048 is a 105M model (so the LR is stable across a 32x increase in width, roughly a 600x increase in parameter count). Each run consists of 5000 steps. Two things to note:
All the scripts used to create these experiments are available in the repository.
Concerning the muP implementation, it consists of modifying the initialization of the weights, their learning rates, and the scale of the output. Also, one needs to be careful when using muP with weight decay. I'm now going to go into the details of the first three points.

You first pick a base width (the width at which the HPs were tuned). From there, when training a model, you compute a ratio: the width of the model divided by this base width. Now we need to look at all the weights of the network and classify them into 3 categories: input, hidden and output (a rough sketch of this classification is below).
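As a rough illustration of the ratio and of the classification (the names `width_mult` and `classify_param` and the substring matching are mine, not necessarily what the PR uses):

```python
# Illustrative only: the variable and module names in the actual code may differ.
d_model_base = 64                      # width at which the HPs are tuned
d_model = 2048                         # width of the model being trained
width_mult = d_model / d_model_base    # the ratio used to rescale init / LR / output

def classify_param(name, param):
    """Rough classification of parameters into muP categories."""
    if "embedding" in name:
        return "input"                 # token embedding
    if "lm_head" in name:
        return "output"                # final projection to the vocabulary
    if param.ndim >= 2:
        return "hidden"                # matrix-like weights inside the blocks
    return "vector"                    # biases, A_log, D, ... handled separately
```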
muP then tells us to treat each of these categories differently. Roughly, with Adam: keep the standard init and LR for the input weights, shrink the init std of the hidden weights and scale their LR down as the width grows, and scale the output down by the width ratio (a sketch of what this can look like is below). What weights do we have in our Mamba-2 model?
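Here is a sketch of the corresponding scaling with Adam, reusing the hypothetical `classify_param` and `width_mult` from above; the exact factors in the PR may differ, this is just the usual muP-for-Adam recipe:

```python
import math
import torch

def mup_init_(model, base_std, width_mult):
    """Shrink the init std of the hidden weights by sqrt(width_mult)."""
    for name, p in model.named_parameters():
        if classify_param(name, p) == "hidden":
            torch.nn.init.normal_(p, mean=0.0, std=base_std / math.sqrt(width_mult))

def mup_param_groups(model, base_lr, width_mult):
    """Adam param groups: hidden weights get their LR divided by width_mult,
    everything else keeps the base LR."""
    hidden, other = [], []
    for name, p in model.named_parameters():
        (hidden if classify_param(name, p) == "hidden" else other).append(p)
    return [{"params": other, "lr": base_lr},
            {"params": hidden, "lr": base_lr / width_mult}]

# usage (hypothetical model):
# optimizer = torch.optim.Adam(mup_param_groups(model, base_lr=2**-7, width_mult=width_mult))
```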
So, not great. I just leave it as is (and that's great because it allows us to keep this concatenated weight and not waste time decomposing it).
This is a convolution weight, whose input and output channel counts are equal (it is a depthwise convolution, applied channel-wise). The shape of the weight is therefore (channels, 1, kernel_size) (see the sketch below). Concerning its bias, setting it to 0 greatly hurts Mamba-2 performance, so I just kept PyTorch's default init and it works fine.
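For illustration, a depthwise Conv1d with the usual Mamba-2 channel layout (the dimension names and values here are assumptions; the point is simply that its default PyTorch init is left untouched):

```python
import torch.nn as nn

# Usual Mamba-2 shapes (assumed here, the PR's names may differ):
d_inner, d_state, n_groups, d_conv = 1536, 128, 1, 4
conv_dim = d_inner + 2 * n_groups * d_state    # x, B and C are convolved together

# Depthwise Conv1d: groups == channels, so the weight has shape (conv_dim, 1, d_conv).
conv1d = nn.Conv1d(conv_dim, conv_dim, kernel_size=d_conv,
                   groups=conv_dim, padding=d_conv - 1, bias=True)

# muP treatment: none. PyTorch's default init is kept for both the weight and
# the bias (zeroing the bias was found to hurt Mamba-2 performance).
```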
And finally, the pre_logits are scaled down by the width ratio (a sketch of this, together with the weight-free norm, is given below). Similar considerations are made for Mamba-1.

Concerning the norms, as said earlier, I removed the weights from them (following https://arxiv.org/abs/2404.05728). Note that I didn't remove the weight from the RMSNorm just before the output projection (Mamba-2 only); it just works like that ¯\_(ツ)_/¯

Note that many considerations and choices for this muP implementation were made empirically, and maybe there exists a better one! But it works quite well from what I've seen, and it's the only one for now.
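A sketch of these last two ingredients, a weight-free RMSNorm and the scaled-down pre_logits; the class names and `width_mult` are illustrative:

```python
import torch
import torch.nn as nn

class RMSNormNoWeight(nn.Module):
    """RMSNorm without a learnable weight (the norms whose weights were removed)."""
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

class ScaledLMHead(nn.Module):
    """Output projection whose input (the pre-logits) is scaled down by the width ratio."""
    def __init__(self, d_model, vocab_size, width_mult):
        super().__init__()
        self.width_mult = width_mult
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        return self.lm_head(x / self.width_mult)
```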
This PR adds the muP implementation for Mamba and Mamba-2.
muP is a special parametrization of the model that ensures that the network activations behave the same no matter the width (for "width", you can simply read "d_model", i.e. the model dimension).
This makes it so that hyperparameters like the learning rate and the init standard deviation stay the same whether `d_model=64` or `d_model=2048`. This is thus incredibly useful to find the optimal HPs of a (target) large model by tuning them on a small one: this is called muTransfer. So muP enables muTransfer.
See below for the checks that ensure that this muP implementation is correct.
(this is the second PR, made in order to see in one place all the changes required to implement muP with Mamba. The older PR was reverted)