
feat: Add support for flash attention converter #2560


Merged

Conversation

gs-olive
Collaborator

Description

  • Add new subgraph-matching variants to align with the flash attention paradigm in SD + SDXL models
  • Add support for `scale` kwarg specification in both attention variants (see the sketch below)
  • Add testing for the flash attention ATen operator

Fixes #2427 (Support aten._scaled_dot_product_flash_attention)
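
As a minimal, hedged sketch (not code from this PR's diff), the following illustrates the kind of graph the new converter targets: calling `torch.nn.functional.scaled_dot_product_attention` with an explicit `scale` kwarg (assuming PyTorch ≥ 2.1, where `scale` and `torch.export` are available) and exporting to see which ATen SDPA variant appears. The module, shapes, and scale value are illustrative only.

```python
# Illustrative sketch only -- not part of this PR's diff.
# Assumes PyTorch >= 2.1 (scale kwarg and torch.export available).
import torch
import torch.nn.functional as F


class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # Explicit `scale` kwarg, one of the attention variants the
        # description refers to (0.125 is an arbitrary example value).
        return F.scaled_dot_product_attention(q, k, v, scale=0.125)


q = k = v = torch.randn(2, 8, 128, 64)
exported = torch.export.export(Attn(), (q, k, v))
# Inspect which ATen SDPA variant the call decomposed to, e.g.
# aten._scaled_dot_product_flash_attention on supported CUDA inputs.
print(exported.graph_module.graph)
```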

Type of change

  • New converter

Checklist:

  • [x] My code follows the style guidelines of this project (you can use the linters)
  • [x] I have performed a self-review of my own code
  • [x] I have commented my code, particularly in hard-to-understand areas and hacks
  • [x] I have made corresponding changes to the documentation
  • [x] I have added tests to verify my fix or my feature
  • [x] New and existing unit tests pass locally with my changes
  • [x] I have added the relevant labels to my PR so that relevant reviewers are notified

gs-olive requested review from zewenli98 and apbose on December 27, 2023 18:35
gs-olive self-assigned this on December 27, 2023
github-actions bot added labels on December 27, 2023: component: api [Python], component: conversion, component: converters, component: dynamo, component: lowering, component: tests
github-actions bot requested a review from narendasan on December 27, 2023 18:36
zewenli98 (Collaborator) left a comment

Looks good to me!

gs-olive merged commit de49d62 into pytorch:main on January 9, 2024
gs-olive deleted the scaled_dot_product_attention_converter branch on January 9, 2024 23:22