ONNX export failure for models invoking SDPA attention #28610
cc @fxmarty
Thank you for the ping @BowenBao. cc @drisspg, and linking the relevant issues as well: pytorch/pytorch#110681 & pytorch/pytorch#108108.

Solution 3: tracing SDPA without an attention_mask is, I think, not possible due to the data-dependent control flow on `q_len > 1` here. The reason for this control flow is that the SDPA attention mask derived from `is_causal` is top-left aligned. The same issue exists when tracing SDPA with symbolic_trace or with dynamo + fullgraph=True (https://pytorch.slack.com/archives/C033H6DJSJU/p1702029349053049?thread_ts=1694790001.945579&cid=C033H6DJSJU).

Solution 1 is what the error suggests. I don't think it would be easy to implement (would need to have …).

Solution 2 is probably the most doable (we would need to look at pad tokens). Currently we try as much as possible to pass a None attention_mask to SDPA; see transformers/src/transformers/modeling_attn_mask_utils.py, lines 371 to 381 at e201864.
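For readers landing here, a minimal sketch of the kind of branch being described (an illustrative reconstruction, not the exact transformers code):

```python
import torch.nn.functional as F

def sdpa_forward(query, key, value, attention_mask=None):
    # Illustrative reconstruction of the data-dependent branch (not the exact
    # transformers code): whether is_causal is used depends on the runtime
    # value of q_len, which jit.trace / the ONNX exporter cannot capture as a
    # symbolic condition.
    q_len = query.shape[-2]
    if attention_mask is None and q_len > 1:
        # is_causal=True builds a top-left aligned causal mask internally,
        # which is why the branch on q_len exists in the first place.
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
```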
Hi @fxmarty, why is Solution 1 not easy to implement? I was thinking of something like:

```python
if torch.jit.is_tracing():
    attn_mask = old_attention()
else:
    attn_mask = new_sdpa_attention()
```
Thanks for your reply and context @fxmarty. I have a local fix using solution 1 and will put up a PR to unblock the exporter in the short term, while waiting on pytorch/pytorch#108108.
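For illustration, a minimal sketch of what a solution-1 style fallback could look like (a hypothetical helper, not the actual PR; the implementation names are assumptions):

```python
import torch

def choose_attn_implementation(requested: str = "sdpa") -> str:
    # Hypothetical helper in the spirit of solution 1 (not the actual fix):
    # torch.jit.is_tracing() is True inside torch.jit.trace / torch.onnx.export,
    # so fall back to the eager attention implementation, which always
    # materializes the attention mask and has no data-dependent branch.
    if torch.jit.is_tracing():
        return "eager"
    return requested
```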
@fxmarty I might be wrong, but installing from the latest source I still have the same issue for a BART-based model export without attention_mask. Is this something planned to be supported?
@LoicDagnas Yes, it is expected. If you want to trace the model without an attention_mask, see transformers/src/transformers/modeling_attn_mask_utils.py, lines 391 to 393 at ece1b62.
Note: this is due to the data-dependent control flow discussed above.
See for reference pytorch/pytorch#110681 & pytorch/pytorch#108108.
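For anyone hitting this in the meantime, one workaround is to export with an explicit attention_mask, so the SDPA path never has to choose between is_causal and a materialized mask at trace time. A rough sketch, where the checkpoint, output file name, and axis names are illustrative choices rather than values from this thread:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")  # illustrative checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base").eval()
model.config.use_cache = False  # keep the exported graph simple

inputs = tokenizer("An example sentence.", return_tensors="pt")

torch.onnx.export(
    model,
    (inputs["input_ids"], {"attention_mask": inputs["attention_mask"]}),
    "bart.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
    },
)
```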
There has been some discussion about possible resolutions within the ONNX exporter team. I'd like to post an issue here as well to seek advice and preferences. The options discussed so far:

1. Detect torch.jit.is_tracing() and fall back to the eager attn implementation if needed.
2. Create an attention_mask before passing it to SDPA if it is None.
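A minimal sketch of what the second option could look like (a hypothetical helper that assumes no past key/value cache; not the actual transformers implementation):

```python
import torch
import torch.nn.functional as F

def sdpa_with_materialized_mask(query, key, value, attention_mask=None):
    # Hypothetical helper for option 2: if no mask is given, build an explicit
    # causal mask so SDPA always takes the attn_mask path and tracing never
    # sees the data-dependent is_causal branch. Assumes q_len == kv_len
    # (no past key/value cache).
    q_len, kv_len = query.shape[-2], key.shape[-2]
    if attention_mask is None:
        i = torch.arange(q_len, device=query.device).unsqueeze(-1)
        j = torch.arange(kv_len, device=query.device)
        causal = i >= j  # True where the query may attend (top-left aligned)
        attention_mask = torch.zeros(q_len, kv_len, dtype=query.dtype, device=query.device)
        attention_mask = attention_mask.masked_fill(~causal, torch.finfo(query.dtype).min)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
```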