add Bigbird ONNX config #16427
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thank you for this very clean implementation @vumichien 🚀! I have a small question about the changes to the modeling file, but this looks good to me 😄

Could you please fix the merge conflicts with the `serialization.mdx` file and check that the slow tests pass by running:

```bash
RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "bigbird"
```
```diff
@@ -2994,7 +2994,7 @@ def forward(
         # setting lengths logits to `-inf`
         logits_mask = self.prepare_question_mask(question_lengths, seqlen)
         if token_type_ids is None:
-            token_type_ids = (~logits_mask).long()
+            token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask
```
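For context, the two formulations produce identical `token_type_ids`; only the dtype of the mask changes. A minimal sketch (with an illustrative stand-in mask, not the actual model code) showing the equivalence:

```python
import torch

# Illustrative boolean mask standing in for the original logits_mask.
bool_mask = torch.tensor([[True, True, False, False]])
# The PR's integer formulation carries the same 0/1 values.
int_mask = bool_mask.long()

# Original expression vs. the PR's replacement.
old = (~bool_mask).long()
new = torch.ones(int_mask.size(), dtype=int) - int_mask
assert torch.equal(old, new)
```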
Are these changes needed to enable inference with ONNX Runtime? Just trying to understand the need to change the modeling code
@lewtun Yes, you are right. I have changed the type of `logits_mask` here from boolean to integer to enable inference with ONNX Runtime. The error looks like this if we use the original modeling code:

```
RuntimeError: 0INTERNAL ASSERT FAILED at "../torch/csrc/jit/ir/alias_analysis.cpp":611, please report a bug to PyTorch. We don't have an op for aten::fill_ but it isn't a special case. Argument types: Tensor, bool,

Candidates:
    aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> (Tensor(a!))
    aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> (Tensor(a!))
```
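The failure surfaces during the ONNX export itself. A sketch of the export command via the `transformers.onnx` package (the checkpoint name and output directory here are illustrative, not taken from this PR):

```bash
# Export a BigBird question-answering checkpoint to ONNX.
python -m transformers.onnx --model=google/bigbird-roberta-base \
    --feature=question-answering onnx/
```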
Thanks for the clarification! This looks fine to me :)
```diff
@@ -3057,5 +3057,5 @@ def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):
         # q_lengths -> (bz, 1)
         mask = torch.arange(0, maxlen).to(q_lengths.device)
         mask.unsqueeze_(0)  # -> (1, maxlen)
-        mask = mask < q_lengths
+        mask = torch.where(mask < q_lengths, 1, 0)
```
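A standalone sketch of the patched helper's behavior, with illustrative shapes and values:

```python
import torch

def prepare_question_mask(q_lengths: torch.Tensor, maxlen: int):
    # Same logic as the patched helper above: position indices below the
    # question length become 1, everything else 0.
    mask = torch.arange(0, maxlen).to(q_lengths.device)
    mask.unsqueeze_(0)  # -> (1, maxlen)
    mask = torch.where(mask < q_lengths, 1, 0)
    return mask

q_lengths = torch.tensor([[2], [4]])  # (bz, 1)
print(prepare_question_mask(q_lengths, maxlen=6))
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0]])
```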
Also here, I had to change the mask type from boolean to integer.
@lewtun Thank you for reviewing my code. I have fixed the merge conflicts and checked that all the tests pass by running the command above.
This all looks good to me @vumichien - thanks for iterating! Gently pinging @LysandreJik or @sgugger for final approval
Thanks for your PR!
What does this PR do?
Add Bigbird OnnxConfig to make this model available for conversion.
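For readers unfamiliar with the mechanism, an ONNX config mainly declares the model's input names and their dynamic axes. A sketch of what a BigBird config typically looks like, modeled on other encoder configs in `transformers.onnx` (the exact merged implementation may differ):

```python
from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class BigBirdOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Declare dynamic axes so the exported model accepts any
        # batch size and sequence length at inference time.
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )
```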
Who can review?
@lewtun @LysandreJik