-
Notifications
You must be signed in to change notification settings - Fork 27.8k
New issue
Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? # to your account
Fix is_causal being a tensor #35791
Fix is_causal being a tensor #35791
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Find it a bit weird because I can't reproduce the tensor shape being eval as tensor
>>> torch.ones(*(1,2),1).shape[0] > 1
False
Yeah I was surprised at first and thought it might be a behavior change in torch but it's actually related to tracing: import torch
class model(torch.nn.Module):
def __init__(self):
super(model, self).__init__()
self.fc = torch.nn.Linear(1, 1)
def forward(self, x):
print("shape > 1: ", x.shape[0] > 1)
return self.fc(x)
model = model()
x = torch.randn(1, 1)
model(x) # shape > 1: False
torch.jit.trace(model, x, check_trace=False) # shape > 1: tensor(False)
torch.onnx.export(model, x, "model.onnx") # shape > 1: tensor(False) This should've impacted the gpt2 onnx tests but I guess it wasn't caught because it's only tested as a pure decoder. |
Got it, does the device synch not cause any issue for the exported model afterwards? |
@ArthurZucker didn't see any issues in optimum onnx exported models, but I added a guard for jit tracing just in case. |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay thanks
* fix is_causal being a tensor * convert in sdpa attention only when jit tracing
What does this PR do?
We encountered this issue when gpt2 is part of a vision encoder decoder (add_cross_attention=True)
Usually
is_causal
is evaluated inis_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
by the fact thatattention_mask is None
is False, skipping the rest, but when attention mask is not None,query_states.shape[-2] > 1
is evaluated and returns a tensor, resulting in sdpa failing because it's expecting a bool.Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.