Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 #27940
Conversation
Ok, looks good, would like @ArthurZucker to take a quick look before merging. Will cherry-pick this for the release.
@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
        # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user.
        # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
        # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
        requested_attn_implementation = None
Should this be "default" instead?
No, the idea here is to check whether the user passed attn_implementation="eager", attn_implementation="sdpa" or attn_implementation="flash_attention_2" explicitly when loading the model from from_pretrained or from_config.
In case attn_implementation is explicitly set, we hard error if a requirement is not met (torch>=2.1.1, or the model does not support SDPA); otherwise we smoothly fall back on eager.
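For illustration, a minimal sketch of the two behaviors described above (the checkpoint name is only an example, not taken from this PR):

```python
from transformers import AutoModelForCausalLM

# Explicit request: if torch < 2.1.1 or the model does not support SDPA,
# loading raises instead of silently falling back.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")

# No explicit request: SDPA is used when the requirements are met,
# otherwise we smoothly fall back on the "eager" implementation.
model = AutoModelForCausalLM.from_pretrained("gpt2")
```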
Thanks
            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
        elif not hard_check_only:
            config = cls._check_and_enable_sdpa(
                config, hard_check_only=False if requested_attn_implementation is None else True
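As a side note (my paraphrase, not the library's code), the False if requested_attn_implementation is None else True expression boils down to a plain boolean, which makes the intent a bit more visible:

```python
# Hard-check only when the user explicitly asked for an implementation;
# otherwise do a soft check that may fall back on "eager".
hard_check_only = requested_attn_implementation is not None
config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
```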
looks better thanks
As per title.
On torch==2.0.1, these tests pass.
On torch==2.1.1, these tests pass (#26572 (comment)).
There was a bug where even though we manually requested attn_implementation="eager", we would still go into the SDPA control flow and hard check that the SDPA requirements are met, which is not what we want.
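A minimal reproduction sketch of the fixed behavior (checkpoint name chosen only for illustration): explicitly requesting "eager" must bypass the SDPA dispatch entirely, so this should load on torch==2.0.1 without hitting the torch>=2.1.1 hard check.

```python
from transformers import AutoModel

# Works even on torch==2.0.1, since "eager" skips the SDPA requirement check.
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
assert model.config._attn_implementation == "eager"
```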