
Fix SDPA dispatch & make SDPA CI compatible with torch<2.1.1 #27940

Merged: 1 commit into huggingface:main on Dec 11, 2023

Conversation

@fxmarty (Contributor) commented on Dec 11, 2023

As per title.

On torch==2.0.1, these tests pass:

RUN_SLOW=1 pytest tests/models/bart -s -vvvvv -k "torchscript"
RUN_SLOW=1 pytest tests/models/llama -s -vvvvv -k "torchscript"
RUN_SLOW=1 pytest tests/models/whisper -s -vvvvv -k "torchscript"
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/bert -s -vvvvv
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/llama -s -vvvvv

On torch==2.1.1, these tests pass (#26572 (comment)):

RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/ -s -vvvvv -k "flash or sdpa"
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/whisper -s -vvvvv
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/llama -s -vvvvv
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/bart -s -vvvvv
RUN_SLOW=1 CUDA_VISIBLE_DEVICES=0 pytest tests/models/bert -s -vvvvv

There was a bug where, even though the user explicitly requested attn_implementation="eager", we would still enter the SDPA control flow and hard-check that the SDPA requirements were met, which is not what we want.
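A minimal sketch of the fixed dispatch logic (simplified and partly hypothetical, not the exact Transformers code; the real change is in the diff further down):

    # Hedged sketch: names and control flow simplified from modeling_utils.py.
    requested = getattr(config, "_attn_implementation_internal", None)  # None unless set explicitly by the user
    if requested == "eager":
        # Explicit "eager": honor it directly; never enter the SDPA checks.
        config._attn_implementation = "eager"
    elif requested == "sdpa":
        # Explicit "sdpa": hard-error if a requirement is missing
        # (torch>=2.1.1 available, model supports SDPA).
        config = cls._check_and_enable_sdpa(config, hard_check_only=True)
    else:
        # Nothing requested: try SDPA, but fall back to "eager" instead of raising.
        config = cls._check_and_enable_sdpa(config, hard_check_only=False)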

@fxmarty requested a review from LysandreJik on Dec 11, 2023 at 09:28
@LysandreJik (Member) left a comment:

Ok, looks good; I would like @ArthurZucker to take a quick look before merging. Will cherry-pick this for the release.

@@ -1244,6 +1244,7 @@ def _autoset_attn_implementation(
         # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitly set by the user.
         # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager").
         # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model)
+        requested_attn_implementation = None
@LysandreJik (Member) commented on the added line:
Should this be "default" instead?

@fxmarty (Contributor, Author) replied on Dec 11, 2023:

No, the idea here is to check whether the user explicitly passed attn_implementation="eager", attn_implementation="sdpa", or attn_implementation="flash_attention_2" when loading the model through from_pretrained or from_config.

When attn_implementation is explicitly set, we raise a hard error if a requirement is missing (e.g. torch>=2.1.1 is not installed, or the model does not support SDPA); otherwise we smoothly fall back on eager.
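For illustration, a hedged usage sketch of the two paths ("some/model" is a placeholder checkpoint, not a real one):

    from transformers import AutoModelForCausalLM

    # Explicit request: hard-errors if the SDPA requirements are not met
    # (torch>=2.1.1 installed and the model supports SDPA).
    model = AutoModelForCausalLM.from_pretrained("some/model", attn_implementation="sdpa")

    # No explicit request: SDPA is tried first, with a silent fallback to "eager".
    model = AutoModelForCausalLM.from_pretrained("some/model")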

@fxmarty requested a review from ArthurZucker on Dec 11, 2023 at 09:33
@ArthurZucker (Collaborator) left a comment:

Thanks

            config = cls._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
        elif not hard_check_only:
            config = cls._check_and_enable_sdpa(
                config, hard_check_only=False if requested_attn_implementation is None else True
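As an aside, the ternary on the last added line is a verbose spelling of a plain boolean; an equivalent form (behavior unchanged) would be:

    config = cls._check_and_enable_sdpa(
        config, hard_check_only=requested_attn_implementation is not None
    )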
@ArthurZucker (Collaborator) commented:
looks better thanks

@fxmarty merged commit 9f18cc6 into huggingface:main on Dec 11, 2023; 3 checks passed.
iantbutler01 pushed a commit to BismuthCloud/transformers that referenced this pull request Dec 16, 2023