From c984208ddbe958ff5d8fe83378adef8f099ab898 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 17 Sep 2023 16:14:58 -0700
Subject: [PATCH] Set block size to 64 x 64 for kvcache to avoid nvcc segfaults

---
 csrc/flash_attn/src/flash_fwd_launch_template.h | 16 +++++-----------
 flash_attn/__init__.py                          |  2 +-
 training/Dockerfile                             |  4 ++--
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 17124538a..ef713a871 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -115,18 +115,12 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    auto dprops = at::cuda::getCurrentDeviceProperties();
-    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     constexpr int kBlockM = 64;  // Fixed for all head dimensions
-    if (!is_sm8x) {  // A100, H100
-        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-        // and for headdim 192 with block size 64 x 128.
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
-        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
-        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
-    }
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py
index dfe392748..fa82f2f15 100644
--- a/flash_attn/__init__.py
+++ b/flash_attn/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.2.3"
+__version__ = "2.2.3.post1"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
diff --git a/training/Dockerfile b/training/Dockerfile
index 6b7e24779..b6fa06390 100644
--- a/training/Dockerfile
+++ b/training/Dockerfile
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
 RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 
 # Install FlashAttention
-RUN pip install flash-attn==2.2.3
+RUN pip install flash-attn==2.2.3.post1
 
 # Install CUDA extensions for cross-entropy, fused dense, layer norm
 RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.2.3 \
+    && cd flash-attention && git checkout v2.2.3.post1 \
     && cd csrc/fused_softmax && pip install . && cd ../../ \
     && cd csrc/rotary && pip install . && cd ../../ \
     && cd csrc/layer_norm && pip install . && cd ../../ \