
Fixed unfused attn2d scale #2387

Merged · 1 commit · Jan 1, 2025
Conversation

@laclouis5 (Contributor) commented on Jan 1, 2025

This PR fixes the scale parameter of the Attention2d unfused implementation (#2385).

Two aspects were corrected:

  • The scale was changed to q.size(-1) ** -0.5, i.e. the rsqrt of the query embedding dimension (head dim). This matches the default scale used by torch.nn.functional.scaled_dot_product_attention (see the sketch after this list).
  • Some transpositions were applied in the wrong order, so the unfused path produced different outputs than the fused one. The implementation now closely follows the reference computation described in the torch.nn.functional.scaled_dot_product_attention documentation.
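
For illustration, here is a minimal sketch of an unfused scaled dot-product attention using this scale (assumed (B, heads, N, head_dim) layout; not the actual timm Attention2d code):

```python
import torch

# Sketch of unfused scaled dot-product attention. The scale q.size(-1) ** -0.5 is
# the rsqrt of the query embedding dimension, the same default that
# torch.nn.functional.scaled_dot_product_attention uses.
def unfused_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    scale = q.size(-1) ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale  # (B, heads, N, N)
    attn = attn.softmax(dim=-1)
    return attn @ v  # (B, heads, N, head_dim)
```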

Tests were added to check the new implementation against the fused one.
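
A hypothetical check in the spirit of those tests (shapes and tolerances are illustrative, not the actual test code): the manual, unfused computation should match the fused kernel to within floating-point tolerance.

```python
import torch
import torch.nn.functional as F

# Random inputs in (B, heads, N, head_dim) layout.
q, k, v = (torch.randn(2, 4, 64, 32) for _ in range(3))

# Manual (unfused) attention with the default SDPA scale.
scale = q.size(-1) ** -0.5
out_unfused = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1) @ v

# Fused reference.
out_fused = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(out_unfused, out_fused, rtol=1e-4, atol=1e-4)
```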

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@rwightman merged commit 2d734d9 into huggingface:main on Jan 1, 2025
22 checks passed
@laclouis5 deleted the fix-attn2d-scale branch on Jan 1, 2025 at 23:05