Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
tridao committed Sep 17, 2023
1 parent 8c8b4d3 commit c984208
Showing 3 changed files with 8 additions and 14 deletions.
16 changes: 5 additions & 11 deletions csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -115,18 +115,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    auto dprops = at::cuda::getCurrentDeviceProperties();
-    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     constexpr int kBlockM = 64;  // Fixed for all head dimensions
-    if (!is_sm8x) {  // A100, H100
-        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-        // and for headdim 192 with block size 64 x 128.
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    }
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
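For reference, the new dispatch drops the SM-specific branches and picks kBlockN from the head dimension alone, so the split-KV (kvcache) kernel now uses a 64 x 64 tile for head dimensions above 128. The standalone sketch below is not part of the commit; the head-dimension values are just representative examples. It mirrors the new selection rule and checks the resulting tile shapes at compile time:

// Sketch only: reproduces the kBlockN rule from run_mha_fwd_splitkv_dispatch
// after this commit, with compile-time checks of the resulting M x N tiles.
template <int Headdim>
struct SplitKVBlockShape {
    static constexpr int kBlockM = 64;  // fixed for all head dimensions
    static constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
};

static_assert(SplitKVBlockShape<64>::kBlockN == 256,  "headdim 64  -> 64 x 256");
static_assert(SplitKVBlockShape<96>::kBlockN == 128,  "headdim 96  -> 64 x 128");
static_assert(SplitKVBlockShape<128>::kBlockN == 128, "headdim 128 -> 64 x 128");
static_assert(SplitKVBlockShape<160>::kBlockN == 64,  "headdim 160 -> 64 x 64");
static_assert(SplitKVBlockShape<192>::kBlockN == 64,  "headdim 192 -> 64 x 64");
static_assert(SplitKVBlockShape<256>::kBlockN == 64,  "headdim 256 -> 64 x 64");

int main() { return 0; }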
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.3"
+__version__ = "2.2.3.post1"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
4 changes: 2 additions & 2 deletions training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.2.3
+RUN pip install flash-attn==2.2.3.post1
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.3 \
+    && cd flash-attention && git checkout v2.2.3.post1 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/layer_norm && pip install . && cd ../../ \
