feat: add mpnet model family #405
Conversation
lib/bumblebee/text/mpnet.ex (outdated)
@@ -0,0 +1,458 @@
defmodule Bumblebee.Text.MPNet do
Note: this was copied from the Bert implementation, with a few adjustments to the options, but that was about it.
test/bumblebee/text/mpnet_test.exs (outdated)
test ":for_masked_language_modeling" do
  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MPNetForMaskedLM"})
test/bumblebee/text/mpnet_test.exs (outdated)
assert_all_close(
  outputs.hidden_state[[.., 1..3, 1..3]],
  Nx.tensor([
    [[-0.2331, 1.7817, 1.1736], [-1.1001, 1.3922, -0.3391], [0.0408, 0.8677, -0.0779]]
  ])
)
We want to compare against the reference values from hf/transformers:
from transformers import MPNetModel
import torch
model = MPNetModel.from_pretrained("hf-internal-testing/tiny-random-MPNetModel")
inputs = {
"input_ids": torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
}
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
print(outputs.last_hidden_state[:, 1:4, 1:4])
#=> torch.Size([1, 10, 64])
#=> tensor([[[ 0.0033, -0.2547, 0.4954],
#=> [-1.5348, -1.5433, 0.4846],
#=> [ 0.7795, -0.3995, -0.9499]]], grad_fn=<SliceBackward0>)
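The slice printed above is what the Elixir assertion should reproduce. As a sketch, the comparison amounts to an elementwise closeness check against these literals; the tolerance value here is an assumption for illustration, not Bumblebee's documented default:

```python
import torch

# Reference slice printed above, entered as a literal.
expected = torch.tensor([[[ 0.0033, -0.2547,  0.4954],
                          [-1.5348, -1.5433,  0.4846],
                          [ 0.7795, -0.3995, -0.9499]]])

# A correct port should reproduce these values to within a small
# absolute tolerance (the values above are printed at 4 decimal places,
# so atol=1e-4 is a plausible, assumed choice).
def close_enough(actual, expected, atol=1e-4):
    return torch.allclose(actual, expected, atol=atol)

print(close_enough(expected, expected))
#=> True
```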
I believe there are a few differences between MPNet and BERT, so we need to align the implementation accordingly. In particular, at a quick look some layer names differ, for example key -> k, value -> v, query -> q, so we need to update the layer mapping as well :)
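To make the naming difference concrete, here is a rough sketch (not Bumblebee's actual converter) of a rename table from the BERT-style parameter names to MPNet's shorter ones; the full key paths are illustrative assumptions based on the comment above, not the exact checkpoint layout:

```python
# Hypothetical rename table: hf/transformers MPNet uses short names
# (q, k, v) where BERT uses query, key, value. The surrounding key
# paths are illustrative, not the exact checkpoint layout.
RENAMES = {
    "attention.self.query": "attn.q",
    "attention.self.key": "attn.k",
    "attention.self.value": "attn.v",
}

def map_param_name(bert_name: str) -> str:
    """Rewrite a BERT-style parameter name into the MPNet-style one."""
    for old, new in RENAMES.items():
        if old in bert_name:
            return bert_name.replace(old, new)
    return bert_name

print(map_param_name("encoder.layer.0.attention.self.query.weight"))
#=> encoder.layer.0.attn.q.weight
```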
Thanks a ton! Regarding your last comment, are you looking here?
Yes! And the Bert implementation is https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py, which may be helpful for spotting the differences.
Ok, made another round of improvements! Thanks for the direction.
lib/bumblebee/text/mpnet.ex (outdated)
@@ -0,0 +1,458 @@
defmodule Bumblebee.Text.MPNet do
Nitpick: I think we want the name to be MpNet to align with our naming conventions. Basically, in acronyms we capitalize only the first letter, as in BERT -> Bert, RoBERTa -> Roberta. And we capitalize each word, as in ResNet, ConvNext. We do this because the reference names are often arbitrarily capitalized, and it's not ergonomic for library users to have to know the exact capitalization.
Paper: https://arxiv.org/pdf/2004.09297
Hugging Face model cards:
https://huggingface.co/microsoft/mpnet-base
https://huggingface.co/sentence-transformers/multi-qa-mpnet-base-cos-v1