[PIR][Pass] Add fused_flash_attn_pass #63220

Merged: 7 commits into PaddlePaddle:develop on Apr 8, 2024

Conversation

@yuanlehome (Contributor) commented on Apr 3, 2024

PR Category

Inference

PR Types

New features

Description

Pcard-71500

Supports fusing the following five patterns. Each source subgraph is shown with the fused flash_attn form it is rewritten to.

Pattern 1: scale applied before the QK matmul, casts around softmax.

         Q          K           V
         |          |           |
     transpose  transpose   transpose
         |          |           |
       scale    transpose       |
         |          |           |
         -- matmul--            |
              |                 |
    mask --- add                |
              |                 |
            cast                |
              |                 |
           softmax              |
              |                 |
             cast               |
              |                 |
              ------matmul------
                      |
                     out

fused into:

         Q   K   V   None   mask
         |   |   |     |      |
         ------flash_attn------
                   |
                  out

Pattern 2: scale applied before the QK matmul, no casts.

         Q          K           V
         |          |           |
     transpose  transpose   transpose
         |          |           |
       scale    transpose       |
         |          |           |
         -- matmul--            |
              |                 |
    mask --- add                |
              |                 |
              |                 |
           softmax              |
              |                 |
              |                 |
              ------matmul------
                      |
                     out

fused into:

         Q   K   V   None   mask
         |   |   |     |      |
         ------flash_attn------
                   |
                  out

Pattern 3: scale applied after the QK matmul, casts around softmax.

         Q          K           V
         |          |           |
     transpose  transpose   transpose
         |          |           |
         |      transpose       |
         |          |           |
         -- matmul--            |
              |                 |
            scale               |
              |                 |
    mask --- add                |
              |                 |
            cast                |
              |                 |
           softmax              |
              |                 |
             cast               |
              |                 |
              ------matmul------
                      |
                     out

fused into:

         Q   K   V   None   mask
         |   |   |     |      |
         ------flash_attn------
                   |
                  out

Pattern 4: scale applied after the QK matmul, no casts.

         Q          K           V
         |          |           |
     transpose  transpose   transpose
         |          |           |
         |    transpose         |
         |          |           |
         -- matmul--            |
              |                 |
            scale               |
              |                 |
    mask --- add                |
              |                 |
              |                 |
           softmax              |
              |                 |
              |                 |
              ------matmul------
                      |
                     out

fused into:

         Q   K   V   None   mask
         |   |   |     |      |
         ------flash_attn------
                   |
                  out
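
Patterns 1–4 above differ only in where the scale sits (before or after the QK matmul) and in whether casts wrap the softmax; numerically they all compute standard scaled-dot-product attention. Below is a minimal runnable sketch of what the matched subgraph computes, not the pass implementation itself; the toy shapes, the scale value, and the use of `paddle.nn.functional.scaled_dot_product_attention` as a stand-in for the fused flash_attn op are illustrative assumptions.

```python
import paddle
import paddle.nn.functional as F

# Assumed toy shapes: [batch, seq_len, num_heads, head_dim]
q = paddle.randn([2, 128, 8, 64], dtype='float32')
k = paddle.randn([2, 128, 8, 64], dtype='float32')
v = paddle.randn([2, 128, 8, 64], dtype='float32')
mask = paddle.zeros([2, 1, 128, 128], dtype='float32')  # additive attention mask

# Unfused subgraph matched by the pass (pattern 1 flavour: scale before the QK matmul).
qt = paddle.transpose(q, perm=[0, 2, 1, 3])          # [b, h, s, d]
kt = paddle.transpose(k, perm=[0, 2, 1, 3])
vt = paddle.transpose(v, perm=[0, 2, 1, 3])
kt = paddle.transpose(kt, perm=[0, 1, 3, 2])         # second transpose: [b, h, d, s]
scores = paddle.matmul(paddle.scale(qt, scale=64 ** -0.5), kt)
scores = scores + mask                               # mask --- add
probs = F.softmax(scores, axis=-1)                   # casts around softmax omitted
out_unfused = paddle.matmul(probs, vt)               # [b, h, s, d]

# After the pass, the subgraph above is replaced by one fused flash_attn op.
# Rough user-facing equivalent (flash kernels need a GPU and half precision):
if paddle.is_compiled_with_cuda():
    out_fused = F.scaled_dot_product_attention(
        q.astype('float16'), k.astype('float16'), v.astype('float16'),
        attn_mask=mask.astype('float16'))

print(out_unfused.shape)  # [2, 8, 128, 64]
```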

Pattern 5: Q, K and V sliced from a packed tensor, scale applied after the QK matmul.

               transpose
                     |
          -----------+----------
          |          |           |
        slice       slice      slice
          |          |           |
          Q          K           V
          |          |           |
          |       transpose      |
          |          |           |
          -- matmul--            |
               |                 |
             scale               |
               |                 |
     mask --- add                |
               |                 |
            softmax              |
               |                 |
               ------matmul------
                       |
                   transpose
                       |
                      out

fused into:

             transpose
                |
          ------+------
          |     |     |
        slice slice slice
          |     |     |
          Q     K     V              mask
          |     |     |               |
    transpose transpose transpose       |
          |     |     |               |
          -------flash_attn------------
                    |
                   out
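
Pattern 5 differs only in its prologue and epilogue: Q, K and V are sliced out of one packed tensor after a transpose, and the attention output is transposed back at the end. A small sketch of that data flow, under an assumed packed layout (the exact layout matched by the pass may differ):

```python
import paddle
import paddle.nn.functional as F

# Assumed packed layout from a fused QKV projection: [batch, seq_len, 3, num_heads, head_dim]
packed = paddle.randn([2, 128, 3, 8, 64], dtype='float32')
mask = paddle.zeros([2, 1, 128, 128], dtype='float32')

# transpose + three slices, as on the source side of pattern 5
packed_t = paddle.transpose(packed, perm=[2, 0, 3, 1, 4])   # [3, b, h, s, d]
q, k, v = packed_t[0], packed_t[1], packed_t[2]             # each [b, h, s, d]

# attention body of pattern 5: scale after the QK matmul, no casts
scores = paddle.scale(paddle.matmul(q, k, transpose_y=True), scale=64 ** -0.5)
probs = F.softmax(scores + mask, axis=-1)
out = paddle.transpose(paddle.matmul(probs, v), perm=[0, 2, 1, 3])  # back to [b, s, h, d]

# The fused form keeps the transpose + slice prologue, re-transposes Q/K/V to
# [b, s, h, d], and replaces everything from the QK matmul onward with one flash_attn op.
print(out.shape)  # [2, 128, 8, 64]
```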

paddle-bot (bot) commented on Apr 3, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@yuanlehome merged commit f668c0d into PaddlePaddle:develop on Apr 8, 2024
30 checks passed
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Apr 9, 2024
* step 1

* step 2

* step 3

* fix

* step 4

* step 5

* fix codestyle
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Apr 10, 2024
* step 1

* step 2

* step 3

* fix

* step 4

* step 5

* fix codestyle