
🤖[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈 L20 ~1.9x↑🎉, 📈 A30 ~1.8x↑🎉, 📈 3080 ~2.9x↑🎉, 📈 4090 ~2.1x↑🎉.

💡NOTE: This project is still in its early development stage; it currently provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻 star this repo to support me ~)

ยฉ๏ธCitations๐ŸŽ‰๐ŸŽ‰

```bibtex
@misc{ffpa-attn-mma@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/DefTruth/ffpa-attn-mma.git},
  note={Open-source software available at https://github.com/DefTruth/ffpa-attn-mma.git},
  author={DefTruth etc},
  year={2025}
}
```

📖 Contents

  • 📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level
  • 📖 Prerequisites
  • 📖 Installation
  • 📖 FFPA L1 (Level 1): Benchmark
  • 📖 Python Testing
  • ©️License
  • 🎉Contribute

📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡

We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance than SDPA EA, with or without MMA Acc F32 (1.8x~3x 🎉 faster).

We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as in FlashAttention (a minimal NumPy sketch of the tiling idea follows the list below). 👇

  • 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
  • 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
  • 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
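
To make the fine-grained tiling concrete, here is a minimal NumPy sketch of the L1-style loop structure: within each (Br x Bc) block, both Q@K^T and P@V iterate over headdim in chunks of 16, so only (Br x 16) and (Bc x 16) slices need to be resident at any time, which is the analogue of the O(2xBrx16) SRAM tiles. This is a numerical illustration only, not the CUDA kernel; the single-head layout, Br = Bc = 64, the chunk width of 16, and the FlashAttention-2 style online-softmax bookkeeping are assumptions made for the sketch.

```python
# Minimal NumPy sketch (not the CUDA kernel): FFPA-style fine-grained tiling
# over headdim for both Q@K^T and P@V, combined with FlashAttention-2 style
# online softmax. Only (Br, 16) / (Bc, 16) slices are "live" per inner step,
# mirroring the O(2*Br*16) SRAM claim; chunk=16 mirrors the k-dim of m16n8k16.
import numpy as np

def ffpa_style_attention(Q, K, V, Br=64, Bc=64, chunk=16):
    N, d = Q.shape
    O = np.zeros((N, d), dtype=np.float32)
    scale = 1.0 / np.sqrt(d)
    for qs in range(0, N, Br):
        q = Q[qs:qs + Br]
        m = np.full(q.shape[0], -np.inf)                   # running row max
        l = np.zeros(q.shape[0])                           # running row sum
        acc = np.zeros((q.shape[0], d), dtype=np.float32)  # unnormalized output
        for ks in range(0, N, Bc):
            k, v = K[ks:ks + Bc], V[ks:ks + Bc]
            # Q@K^T, fine-grained over headdim: only (Br,16)+(Bc,16) slices at a time.
            S = np.zeros((q.shape[0], k.shape[0]), dtype=np.float32)
            for c in range(0, d, chunk):
                S += q[:, c:c + chunk] @ k[:, c:c + chunk].T
            S *= scale
            # Online softmax rescaling (FlashAttention-2 style).
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            alpha = np.exp(m - m_new)
            l = alpha * l + P.sum(axis=1)
            acc *= alpha[:, None]
            # P@V, fine-grained over headdim: only a (Bc,16) slice of V at a time.
            for c in range(0, d, chunk):
                acc[:, c:c + chunk] += P @ v[:, c:c + chunk]
            m = m_new
        O[qs:qs + Br] = acc / l[:, None]
    return O

# Quick numerical check against a naive reference at a large headdim (D=320).
rng = np.random.default_rng(0)
N, d = 256, 320
Q, K, V = (rng.standard_normal((N, d)).astype(np.float32) for _ in range(3))
S = (Q @ K.T) / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.abs(ffpa_style_attention(Q, K, V) - ref).max())  # small fp32 rounding error
```

In the real kernels this chunked accumulation is performed with m16n8k16 MMA instructions on SMEM tiles, while L2/L3 additionally trade Q@K^T recomputation or O scaling via HBM offloading for the lower register footprint listed above.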

By leveraging this approach, we can achieve better performance for large headdim (D > 256) than SDPA EA, covering a range that FlashAttention is not designed to support. Approximate SRAM and register complexity analysis for L1~L3 is as follows (d = headdim; C, Br, Bc are constants; Br = Bc): 👇

|📚Complexity|📚FFPA L1|📚FFPA L2|📚FFPA L3|📚FA-2|
|:---|:---:|:---:|:---:|:---:|
|SRAM|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|O(2xBrx16)≈O(1)|≈O(3xBrxd), d↑|
|Register|≈O(d/4), d↑|O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)|≈O(d/2), d↑|
|HBM|≈FA2≈O(Nd), O|≈FA2≈O(Nd), O|≈FA2≈O(Nd), O|≈O(Nd), O|
|Extra HBM|≈FA2≈O(N), m,l|≈FA2≈O(N), m,l|≈FA2≈O(N), m,l|≈O(N), m,l|
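
To put the SRAM row into numbers, here is a back-of-the-envelope comparison at d = 1024; Br = Bc = 64 and fp16 storage (2 bytes per element) are illustrative assumptions, not necessarily the tile sizes the kernels actually pick:

```python
# Rough per-block SMEM budget at d = 1024 (illustrative: Br = Bc = 64, fp16).
Br, Bc, d, elem_bytes = 64, 64, 1024, 2
fa2_like = (Br * d + 2 * Bc * d) * elem_bytes   # Q + K + V tiles, ~O(3*Br*d)
ffpa_l1  = (2 * Br * 16) * elem_bytes           # two (Br x 16) slices, ~O(1) in d
print(f"FA-2 style: {fa2_like // 1024} KB vs FFPA L1: {ffpa_l1 // 1024} KB")  # 384 KB vs 4 KB
```

Under these assumptions the FA-2 style tiles (~384 KB) would far exceed the shared memory available per SM on current GPUs, while the fine-grained slices (~4 KB) do not grow with d, which is what lets headdim scale beyond 256 without extra HBM usage.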

📚👇Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages (1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Fully QKV Fine-grained Tiling (GEMM style), Collective Store, etc.

|📚Feature|📚Feature|📚Feature|📚Feature|
|:---:|:---:|:---:|:---:|
|✔️Tensor Cores|✔️Loop over N/D|✔️Tile Block(Br, Bc)|✔️MMA(m16n8k16)|
|✔️Split Q(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM Swizzle/Pad|✔️Copy Async|
|✔️Tile MMA/Warp|✔️QKV Multi-Stages(1~4)|✔️Collective Store(Shfl)|✔️Prefetch QKV g2s|
|✔️QKV Fine-grained Tiling|✔️Shared QKV SMEM|✔️Mixed MMA Acc|✔️FFPA L1 Level|

📖 Prerequisites

  • Python >= 3.10
  • PyTorch >= 2.4.0, CUDA >= 12.4
  • Recommended: PyTorch 2.5.1, CUDA 12.5
  • Docker: nvcr.io/nvidia/pytorch:24.10-py3

📖 Installation

The FFPA kernels implemented in this repo can be installed as a Python library, namely the ffpa-attn library (optional).

```bash
git clone https://github.com/DefTruth/ffpa-attn-mma.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
```

📖 FFPA L1 (Level 1): Benchmark 🎉🎉

L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320~1024 (FA2 not supported 👀). (Notes: *=MMA Acc F32, ^=MMA Acc F16, Softmax Acc dtype is always F32, T=TFLOPS; a minimal SDPA EA timing sketch follows the tables. 👇Benchmark)

  • 📚 NVIDIA L20 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|63T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|102T|102T|103T|104T|103T|95T|95T|95T|95T|96T|95T|94T|
|Speedup|1.82x|1.62x|1.78x|1.79x|1.87x|1.7x|1.76x|1.73x|1.76x|1.75x|1.76x|1.68x|
|FFPA L1^|104T|103T|103T|102T|104T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.63x|1.78x|1.76x|1.89x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

  • 📚 NVIDIA L20 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|105T|102T|104T|103T|105T|95T|95T|94T|94T|94T|102T|101T|
|Speedup|1.88x|1.59x|1.79x|1.78x|1.91x|1.7x|1.76x|1.71x|1.74x|1.71x|1.89x|1.8x|
|FFPA L1^|104T|103T|103T|102T|103T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.61x|1.78x|1.76x|1.87x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

  • 📚 NVIDIA A30 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|45T|44T|44T|43T|43T|38T|37T|37T|37T|36T|33T|32T|
|Speedup|1.8x|1.76x|1.83x|1.79x|1.79x|1.58x|1.61x|1.68x|1.68x|1.64x|1.5x|1.78x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.82x|1.89x|

  • 📚 NVIDIA A30 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|48T|46T|46T|43T|44T|38T|38T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.92x|1.79x|1.83x|1.58x|1.65x|1.73x|1.68x|1.64x|1.82x|1.89x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|39T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.77x|1.89x|

  • 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~2.5x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|16T|11T|16T|15T|15T|15T|15T|14T|14T|14T|14T|
|FFPA L1*|33T|31T|30T|30T|30T|27T|27T|26T|26T|26T|26T|25T|
|Speedup|2.54x|1.94x|2.73x|1.88x|2.0x|1.8x|1.8x|1.73x|1.86x|1.86x|1.86x|1.79x|
|FFPA L1^|43T|41T|39T|39T|39T|39T|39T|36T|34T|33T|31T|33T|
|Speedup|3.31x|2.56x|3.55x|2.44x|2.6x|2.6x|2.6x|2.4x|2.43x|2.36x|2.21x|2.36x|

  • 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.9x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|15T|12T|15T|14T|15T|14T|14T|14T|14T|14T|14T|
|FFPA L1*|38T|36T|34T|35T|34T|31T|32T|31T|30T|28T|27T|27T|
|Speedup|2.92x|2.4x|2.83x|2.33x|2.43x|2.07x|2.29x|2.21x|2.14x|2.0x|1.93x|1.93x|
|FFPA L1^|44T|41T|39T|39T|38T|39T|39T|36T|34T|32T|31T|33T|
|Speedup|3.38x|2.73x|3.25x|2.6x|2.71x|2.6x|2.79x|2.57x|2.43x|2.29x|2.21x|2.36x|

  • 📚 NVIDIA RTX 4090 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|81T|94T|85T|85T|79T|81T|79T|80T|79T|80T|78T|78T|
|FFPA L1*|149T|150T|150T|150T|150T|140T|140T|140T|139T|139T|137T|134T|
|Speedup|1.84x|1.6x|1.76x|1.76x|1.9x|1.73x|1.77x|1.75x|1.76x|1.74x|1.76x|1.72x|
|FFPA L1^|194T|194T|189T|191T|197T|188T|184T|180T|177T|172T|171T|171T|
|Speedup|2.4x|2.06x|2.22x|2.25x|2.49x|2.32x|2.33x|2.25x|2.24x|2.15x|2.19x|2.19x|

  • 📚 NVIDIA RTX 4090 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.1x↑🎉)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|82T|92T|85T|84T|78T|81T|79T|80T|78T|79T|77T|78T|
|FFPA L1*|176T|170T|171T|171T|171T|161T|160T|161T|160T|158T|165T|164T|
|Speedup|2.15x|1.85x|2.01x|2.04x|2.19x|1.99x|2.03x|2.01x|2.05x|2.0x|2.14x|2.1x|
|FFPA L1^|200T|191T|189T|191T|188T|188T|186T|179T|175T|173T|172T|170T|
|Speedup|2.44x|2.08x|2.22x|2.27x|2.41x|2.32x|2.35x|2.24x|2.24x|2.19x|2.23x|2.18x|
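
For reference, the sketch below shows one way the SDPA EA baseline above can be timed with PyTorch's memory-efficient SDPA backend at a benchmark shape (B=1, H=48, N=8192, D=320, fp16). It only illustrates the baseline side of the comparison; tests/test.py in this repo is the actual benchmark harness, and the warmup/iteration counts and FLOPs formula here are assumptions of the sketch.

```python
# Sketch only: time PyTorch SDPA restricted to the memory-efficient (EA) backend
# at one benchmark shape. See tests/test.py for the repo's real benchmark code.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # PyTorch >= 2.3

B, H, N, D = 1, 48, 8192, 320
q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.half) for _ in range(3))

def bench(fn, warmup=5, iters=20):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    ms = bench(lambda: F.scaled_dot_product_attention(q, k, v))

approx_flops = 4 * B * H * N * N * D  # ~2 matmuls of 2*N*N*D FLOPs each (no causal mask)
print(f"SDPA EA: {ms:.3f} ms/iter, ~{approx_flops / (ms * 1e-3) / 1e12:.1f} TFLOPS")
```

The FFPA side of the comparison comes from the kernels installed above; see the test options in the next section for running both and generating the tables on your own device.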

📖 Python Testing

👇You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots via a PR are welcome 🎉🎉.

  • 📚 case: B=1, H=48, N=8192, D=320 (FA2 not supported)

```bash
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
```

  • 📚 case: Generate a benchmark table and speedup bar plots on your device.

```bash
cd tests && pip install matplotlib && python3 test.py --gen-bench --show-all --plot
```

💡NOTE: Please check all configurable environment variables in env.py.

ยฉ๏ธLicense

GNU General Public License v3.0

🎉Contribute

How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~
