[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), roughly 1.8x~3x faster than SDPA EA with or without MMA Acc F32 on many devices: L20 ~1.9x↑, A30 ~1.8x↑, 3080 ~2.9x↑, 4090 ~2.1x↑.
NOTE: This project is still in its early development stage and currently provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to star this repo to support me ~)
@misc{ffpa-attn-mma@2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/ffpa-attn-mma.git},
note={Open-source software available at https://github.com/DefTruth/ffpa-attn-mma.git},
author={DefTruth and others},
year={2025}
}
- Installation
- Python Testing
- FFPA L1~L3 Design
- FFPA L1: L20 ~1.9x↑
- FFPA L1: A30 ~1.8x↑
- FFPA L1: 3080 ~2.9x↑
- FFPA L1: 4090 ~2.1x↑
We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (1.8x~3x faster than SDPA EA).
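To make the fine-grained d-tiling idea concrete, below is a minimal PyTorch sketch (an illustration of the tiling scheme only, not the actual CUDA MMA kernel; the function name and default block sizes are placeholders): both Q@K^T and P@V are accumulated over 16-wide headdim tiles, so the per-step working tiles stay at (Br, 16) / (Bc, 16) no matter how large d grows.

```python
# Minimal PyTorch sketch of the fine-grained d-tiling idea (illustrative only,
# not the actual CUDA MMA kernel): Q@K^T and P@V are both accumulated over
# 16-wide headdim tiles, so the per-step working tiles are (Br, 16)/(Bc, 16)
# regardless of how large the headdim d grows.
import torch

def ffpa_style_prefill(q, k, v, br=64, bc=64, tile_d=16):
    """q, k, v: (N, d) single-head tensors; online softmax over Bc-wide K/V blocks."""
    N, d = q.shape
    scale = d ** -0.5
    dev = q.device
    o = torch.zeros(N, d, dtype=torch.float32, device=dev)
    for i in range(0, N, br):
        qi = q[i:i + br].float()
        m = torch.full((qi.shape[0], 1), float("-inf"), device=dev)
        l = torch.zeros(qi.shape[0], 1, device=dev)
        acc = torch.zeros(qi.shape[0], d, device=dev)
        for j in range(0, N, bc):
            kj, vj = k[j:j + bc].float(), v[j:j + bc].float()
            # Q@K^T accumulated tile-by-tile along d (constant-size working tiles).
            s = torch.zeros(qi.shape[0], kj.shape[0], device=dev)
            for t in range(0, d, tile_d):
                s += qi[:, t:t + tile_d] @ kj[:, t:t + tile_d].T
            s *= scale
            # Standard FlashAttention-2 style online softmax rescaling.
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            alpha = torch.exp(m - m_new)
            l = alpha * l + p.sum(dim=-1, keepdim=True)
            acc *= alpha
            # P@V also accumulated tile-by-tile along d.
            for t in range(0, d, tile_d):
                acc[:, t:t + tile_d] += p @ vj[:, t:t + tile_d]
            m = m_new
        o[i:i + br] = acc / l
    return o.to(q.dtype)
```

On small shapes this matches torch.nn.functional.scaled_dot_product_attention up to floating-point error; the real kernels realize the same per-16-column accumulation with m16n8k16 MMA instructions on shared QKV SMEM tiles.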
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention's.
- L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
By leveraging this approach, we can achieve better performance for large headdim (D > 256), striking a balance between FlashAttention (which is not designed to support D > 256) and SDPA EA. The approximate SRAM and register complexity analysis for L1~L3 is as follows (d=headdim; C, Br, Bc=constants; Br=Bc):
Complexity | FFPA L1 | FFPA L2 | FFPA L3 | FA-2 |
---|---|---|---|---|
SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1) | O((Bc/16)x4+2C)≈O(1) | ≈O(d/2), d↑ |
HBM | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈O(Nd), O |
Extra HBM | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
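To put rough numbers on the SRAM row (a back-of-the-envelope sketch assuming Br = Bc = 64 and FP16 elements; real kernels add swizzle/padding and multi-stage prefetch buffers on top of these terms):

```python
# Illustrative SRAM footprint per block following the O(.) terms above
# (Br = Bc = 64, FP16 = 2 bytes; actual kernel allocations differ slightly
# due to swizzling/padding and multi-stage prefetch buffers).
def smem_bytes(d, br=64, elem=2):
    ffpa = 2 * br * 16 * elem   # O(2 x Br x 16): constant in d
    fa2  = 3 * br * d * elem    # O(3 x Br x d): grows linearly with d
    return ffpa, fa2

for d in (256, 512, 1024):
    ffpa, fa2 = smem_bytes(d)
    print(f"d={d:4d}  FFPA ~{ffpa // 1024} KiB  vs  FA-2 ~{fa2 // 1024} KiB")
# d= 256  FFPA ~4 KiB  vs  FA-2 ~96 KiB
# d= 512  FFPA ~4 KiB  vs  FA-2 ~192 KiB
# d=1024  FFPA ~4 KiB  vs  FA-2 ~384 KiB
```

The FA-2 style footprint grows past typical per-SM shared memory long before d reaches 1024, while the fine-grained tiles stay at a few KiB, which is what the O(1) column expresses.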
Core Features: I have implemented FFPA L1~L3 using pure MMA PTX instructions, supporting many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages (1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Fully QKV Fine-grained Tiling (GEMM style), Collective Store, etc.
Feature | Feature | Feature | Feature |
---|---|---|---|
✔️Tensor Cores | ✔️Loop over N/D | ✔️Tile Block (Br, Bc) | ✔️MMA (m16n8k16) |
✔️Split Q (FA-2) | ✔️Pack LDST (128 bits) | ✔️SMEM Swizzle/Pad | ✔️Copy Async |
✔️Tile MMA/Warp | ✔️QKV Multi-Stages (1~4) | ✔️Collective Store (Shfl) | ✔️Prefetch QKV g2s |
✔️QKV Fine-grained Tiling | ✔️Shared QKV SMEM | ✔️Mixed MMA Acc | ✔️FFPA L1 Level |
- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- Recommended: PyTorch 2.5.1, CUDA 12.5
- Docker: nvcr.io/nvidia/pytorch:24.10-py3
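A quick way to check these requirements before building (a small sanity-check snippet, nothing FFPA-specific):

```python
# Quick environment sanity check against the requirements above
# (Python >= 3.10, PyTorch >= 2.4.0, CUDA >= 12.4, a Tensor Core capable GPU).
import sys
import torch

print("python :", sys.version.split()[0])    # expect >= 3.10
print("torch  :", torch.__version__)         # expect >= 2.4.0
print("cuda   :", torch.version.cuda)        # expect >= 12.4
if torch.cuda.is_available():
    print("device :", torch.cuda.get_device_name(0),
          "sm", torch.cuda.get_device_capability(0))
else:
    print("No CUDA device visible -- the FFPA kernels need an NVIDIA GPU.")
```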
The FFPA implemented in this repo can be installed as a Python library, namely the ffpa-attn library (optional).
git clone https://github.com/DefTruth/ffpa-attn-mma.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported). (Notes: *=MMA Acc F32, ^=MMA Acc F16, Softmax Acc dtype is always F32, T=TFLOPS.)
- NVIDIA L20 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 63T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 102T | 102T | 103T | 104T | 103T | 95T | 95T | 95T | 95T | 96T | 95T | 94T |
Speedup | 1.82x | 1.62x | 1.78x | 1.79x | 1.87x | 1.7x | 1.76x | 1.73x | 1.76x | 1.75x | 1.76x | 1.68x |
FFPA L1^ | 104T | 103T | 103T | 102T | 104T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.63x | 1.78x | 1.76x | 1.89x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
- NVIDIA L20 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 64T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 105T | 102T | 104T | 103T | 105T | 95T | 95T | 94T | 94T | 94T | 102T | 101T |
Speedup | 1.88x | 1.59x | 1.79x | 1.78x | 1.91x | 1.7x | 1.76x | 1.71x | 1.74x | 1.71x | 1.89x | 1.8x |
FFPA L1^ | 104T | 103T | 103T | 102T | 103T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.61x | 1.78x | 1.76x | 1.87x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
- NVIDIA A30 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 45T | 44T | 44T | 43T | 43T | 38T | 37T | 37T | 37T | 36T | 33T | 32T |
Speedup | 1.8x | 1.76x | 1.83x | 1.79x | 1.79x | 1.58x | 1.61x | 1.68x | 1.68x | 1.64x | 1.5x | 1.78x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
- NVIDIA A30 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 48T | 46T | 46T | 43T | 44T | 38T | 38T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.92x | 1.79x | 1.83x | 1.58x | 1.65x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 39T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.77x | 1.89x |
- NVIDIA RTX 3080 Laptop (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~2.5x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 16T | 11T | 16T | 15T | 15T | 15T | 15T | 14T | 14T | 14T | 14T |
FFPA L1* | 33T | 31T | 30T | 30T | 30T | 27T | 27T | 26T | 26T | 26T | 26T | 25T |
Speedup | 2.54x | 1.94x | 2.73x | 1.88x | 2.0x | 1.8x | 1.8x | 1.73x | 1.86x | 1.86x | 1.86x | 1.79x |
FFPA L1^ | 43T | 41T | 39T | 39T | 39T | 39T | 39T | 36T | 34T | 33T | 31T | 33T |
Speedup | 3.31x | 2.56x | 3.55x | 2.44x | 2.6x | 2.6x | 2.6x | 2.4x | 2.43x | 2.36x | 2.21x | 2.36x |
- NVIDIA RTX 3080 Laptop (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.9x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 15T | 12T | 15T | 14T | 15T | 14T | 14T | 14T | 14T | 14T | 14T |
FFPA L1* | 38T | 36T | 34T | 35T | 34T | 31T | 32T | 31T | 30T | 28T | 27T | 27T |
Speedup | 2.92x | 2.4x | 2.83x | 2.33x | 2.43x | 2.07x | 2.29x | 2.21x | 2.14x | 2.0x | 1.93x | 1.93x |
FFPA L1^ | 44T | 41T | 39T | 39T | 38T | 39T | 39T | 36T | 34T | 32T | 31T | 33T |
Speedup | 3.38x | 2.73x | 3.25x | 2.6x | 2.71x | 2.6x | 2.79x | 2.57x | 2.43x | 2.29x | 2.21x | 2.36x |
- NVIDIA RTX 4090 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 81T | 94T | 85T | 85T | 79T | 81T | 79T | 80T | 79T | 80T | 78T | 78T |
FFPA L1* | 149T | 150T | 150T | 150T | 150T | 140T | 140T | 140T | 139T | 139T | 137T | 134T |
Speedup | 1.84x | 1.6x | 1.76x | 1.76x | 1.9x | 1.73x | 1.77x | 1.75x | 1.76x | 1.74x | 1.76x | 1.72x |
FFPA L1^ | 194T | 194T | 189T | 191T | 197T | 188T | 184T | 180T | 177T | 172T | 171T | 171T |
Speedup | 2.4x | 2.06x | 2.22x | 2.25x | 2.49x | 2.32x | 2.33x | 2.25x | 2.24x | 2.15x | 2.19x | 2.19x |
- NVIDIA RTX 4090 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.1x↑)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 82T | 92T | 85T | 84T | 78T | 81T | 79T | 80T | 78T | 79T | 77T | 78T |
FFPA L1* | 176T | 170T | 171T | 171T | 171T | 161T | 160T | 161T | 160T | 158T | 165T | 164T |
Speedup | 2.15x | 1.85x | 2.01x | 2.04x | 2.19x | 1.99x | 2.03x | 2.01x | 2.05x | 2.0x | 2.14x | 2.1x |
FFPA L1^ | 200T | 191T | 189T | 191T | 188T | 188T | 186T | 179T | 175T | 173T | 172T | 170T |
Speedup | 2.44x | 2.08x | 2.22x | 2.27x | 2.41x | 2.32x | 2.35x | 2.24x | 2.24x | 2.19x | 2.23x | 2.18x |
You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots via a PR are welcome.
- case: B=1, H=48, N=8192, D=320 (FA2 not supported)
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
- case: Generate a benchmark table and speedup bar plots on your device.
cd tests && pip install matplotlib && python3 test.py --gen-bench --show-all --plot
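For reference, the SDPA EA baseline shown in the benchmark tables can be reproduced with plain PyTorch by pinning scaled_dot_product_attention to the EFFICIENT_ATTENTION backend; a minimal sketch for the B=1, H=48, N=8192 case (D=512 is just one pick from the tested range):

```python
# Minimal sketch of the "SDPA EA" baseline: PyTorch's scaled_dot_product_attention
# pinned to the EFFICIENT_ATTENTION backend, using the benchmark shapes
# B=1, H=48, N=8192 and a headdim from the tested range (D=512 here).
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, N, D = 1, 48, 8192, 512  # D > 256: the FlashAttention backend is not applicable
q = torch.randn(B, H, N, D, dtype=torch.half, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 48, 8192, 512])
```

This is the baseline the FFPA kernels are compared against in tests/test.py.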
NOTE: Please check all configurable environment variables in env.py.
GNU General Public License v3.0
How to contribute? Welcome to star⭐️ this repo to support me ~