ONNX export failure for models invoking SDPA attention #28610

Closed
BowenBao opened this issue Jan 19, 2024 · 6 comments · Fixed by #27931

@BowenBao
Contributor

ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument attn_implementation="eager" or pass an attention_mask input when tracing the model.

There has been some discussion within the ONNX exporter team about possible resolutions. I'd like to post an issue here as well to seek advice and preferences.

  1. Check torch.jit.is_tracing() and fall back to the eager attn implementation if needed.
  2. Create an attention_mask before passing it to SDPA if it is None (a rough sketch of what this could look like follows this list).
  3. Support SDPA tracing w/o attention_mask (not sure how feasible this is).
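
As an illustration of option 2 (a hedged sketch, not the Transformers implementation; causal_attn_mask is a made-up helper): when no attention_mask is supplied, one could materialize an explicit additive causal mask and hand it to SDPA, so the kernel never needs the is_causal branch:

import torch
import torch.nn.functional as F

def causal_attn_mask(q_len, kv_len, dtype=torch.float32):
    # Disallowed positions get a large negative value, allowed ones get 0.
    mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype)
    # Query i may attend to keys j <= i + (kv_len - q_len), i.e. the mask is
    # bottom-right aligned, which is what decoding with a KV cache expects.
    return torch.triu(mask, diagonal=kv_len - q_len + 1)[None, None]

# Dummy shapes: 1 batch, 8 heads, 3 new queries attending over 5 keys/values.
q = torch.randn(1, 8, 3, 64)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_attn_mask(3, 5))
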
@amyeroberts
Collaborator

cc @fxmarty

@fxmarty
Contributor

fxmarty commented Jan 22, 2024

Thank you for the ping, and thank you @BowenBao. cc @drisspg, and linking relevant issues as well: pytorch/pytorch#110681 & pytorch/pytorch#108108

Solution 3 (SDPA tracing without attention_mask) is, I think, not possible due to the data-dependent control flow here:

is_causal=self.is_causal and attention_mask is None and q_len > 1,
specifically the q_len > 1 condition. The reason for this control flow is that the attention mask SDPA generates from is_causal is top-left aligned.
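
To illustrate what top-left vs. bottom-right alignment means (a small assumed example, 2 new query tokens over a KV cache of 5 positions, not code from this thread):

import torch

q_len, kv_len = 2, 5

# Top-left aligned (what is_causal=True produces): query i sees keys j <= i.
top_left = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool))

# Bottom-right aligned (what decoding with a KV cache needs):
# query i sees keys j <= i + (kv_len - q_len).
bottom_right = torch.tril(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=kv_len - q_len)

print(top_left.int())      # [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0]]
print(bottom_right.int())  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]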

The same issue exists when tracing SDPA with symbolic_trace or with dynamo + fullgraph=True (https://pytorch.slack.com/archives/C033H6DJSJU/p1702029349053049?thread_ts=1694790001.945579&cid=C033H6DJSJU).

Solution 1 is what the error suggests. I don't think it would be easy to implement (it would need torch.jit.is_tracing() control flow that does magic on the model).

Solution 2 is probably the most doable (we would need to look at pad tokens). Currently we try as much as possible to pass attn_mask=None, since SDPA is able to dispatch to the memory-efficient attention and flash attention paths only in that case. We already avoid setting the attention_mask to None when we are tracing:

elif not is_tracing and torch.all(attention_mask == 1):
    if query_length == 1:
        # For query_length == 1, causal attention and bi-directional attention are the same.
        attention_mask = None
    elif key_value_length == query_length:
        attention_mask = None
    else:
        # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot
        # generally ignore the attention mask, as SDPA causal mask generation may be wrong.
        # We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead,
        # hence not setting it to None here.
        # Reference: https://github.com/pytorch/pytorch/issues/108108
        pass
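
As a minimal illustration (not Transformers code) of why the is_tracing guard above is needed: torch.jit.trace evaluates a data-dependent branch once, on the example input, and freezes the taken path into the graph:

import torch

def branchy(x):
    # Toy function with a data-dependent Python branch.
    if bool(torch.all(x == 1)):
        return x + 1  # branch taken while tracing with an all-ones example input
    return x - 1

# Tracing emits a TracerWarning and records only the `+ 1` branch...
traced = torch.jit.trace(branchy, torch.ones(3))
# ...so the traced graph is silently wrong for any other input.
print(traced(torch.zeros(3)))  # tensor([1., 1., 1.]) instead of tensor([-1., -1., -1.])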

@thiagocrepaldi

Hi @fxmarty, why is Solution 1 not easy to implement? I was thinking of something like:

if torch.jit.is_tracing():
    # fall back to the pre-SDPA attention path while tracing
    attn_mask = old_attention()
else:
    attn_mask = new_sdpa_attention()

@BowenBao
Copy link
Contributor Author

BowenBao commented Feb 1, 2024

Thanks for your reply and the context, @fxmarty.

I have a local fix using solution 1 and will put up a PR to unblock the exporter in the short term, while waiting on pytorch/pytorch#108108.
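
For reference, a rough sketch of what a solution-1 style guard could look like inside an SDPA attention forward (an assumption for illustration, not the actual PR): only when torch.jit.trace is running and no mask was provided, materialize an explicit boolean causal mask so the traced graph has no data-dependent control flow:

import torch
import torch.nn.functional as F

def sdpa_forward(q, k, v, attention_mask=None, is_causal=True):
    q_len, kv_len = q.shape[-2], k.shape[-2]
    if torch.jit.is_tracing() and is_causal and attention_mask is None:
        # Explicit bottom-right aligned boolean mask (True = may attend).
        attention_mask = torch.tril(
            torch.ones(q_len, kv_len, dtype=torch.bool, device=q.device),
            diagonal=kv_len - q_len,
        )
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attention_mask,
        is_causal=is_causal and attention_mask is None and q_len > 1,
    )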

@LoicDagnas

@fxmarty I might be wrong but, installing from the latest source, I still have the same issue when exporting a BART-based model without an attention_mask. Is it something that is planned to be supported?

@fxmarty
Contributor

fxmarty commented Feb 26, 2024

@LoicDagnas Yes, it is expected. If you want to trace the model without an attention_mask input, you should load your model with the argument attn_implementation="eager" passed to from_pretrained, as suggested in the error that should be raised:

raise ValueError(
    'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
)

Note: this is due to the following control flow:

is_causal=self.is_causal and attention_mask is None and tgt_len > 1,

See for reference pytorch/pytorch#110681 & pytorch/pytorch#108108
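
For example, the suggested workaround could look like this (a sketch; facebook/bart-base is just a stand-in checkpoint):

import torch
from transformers import AutoModel, AutoTokenizer

# Load with the eager attention implementation so the model can be traced without
# an attention_mask input; torchscript=True makes the outputs trace-friendly tuples.
model = AutoModel.from_pretrained(
    "facebook/bart-base", attn_implementation="eager", torchscript=True
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
input_ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]

# With attn_implementation="sdpa" this call would raise the ValueError quoted above.
traced = torch.jit.trace(model, (input_ids,), strict=False)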
