diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 559c558cf..6602a6c9a 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }

+void run_fmha_fwd(Launch_params &launch_params) {
+    if (launch_params.params.d <= 32) {
+        run_fmha_fwd_hdim32(launch_params);
+    } else if (launch_params.params.d <= 64) {
+        run_fmha_fwd_hdim64(launch_params);
+    } else if (launch_params.params.d <= 128) {
+        run_fmha_fwd_hdim128(launch_params);
+    }
+}
+
 std::vector
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -307,13 +317,22 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }

-    run_fmha_fp16_sm80(launch_params);
+    run_fmha_fwd(launch_params);

     std::vector result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }

+void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    if (params.d <= 32) {
+        run_fmha_bwd_hdim32(params, stream, configure);
+    } else if (params.d <= 64) {
+        run_fmha_bwd_hdim64(params, stream, configure);
+    } else if (params.d <= 128) {
+        run_fmha_bwd_hdim128(params, stream, configure);
+    }
+}
 
 std::vector
 mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
@@ -341,7 +360,7 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     TORCH_CHECK(is_sm8x || is_sm75);
-    auto launch = &run_fmha_dgrad_fp16_sm80;
+    auto launch = &run_fmha_bwd;

     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h
index a389f659b..88788a9bb 100644
--- a/csrc/flash_attn/src/fmha.h
+++ b/csrc/flash_attn/src/fmha.h
@@ -195,9 +195,13 @@ struct Launch_params{

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fp16_sm80(Launch_params &launch_params);
+void run_fmha_fwd_hdim32(Launch_params &launch_params);
+void run_fmha_fwd_hdim64(Launch_params &launch_params);
+void run_fmha_fwd_hdim128(Launch_params &launch_params);

-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);

 void run_fmha_block_fp16_sm80(Launch_params &launch_params, const bool configure);

diff --git a/csrc/flash_attn/src/fmha_bwd_hdim128.cu b/csrc/flash_attn/src/fmha_bwd_hdim128.cu
new file mode 100644
index 000000000..98c6d1051
--- /dev/null
+++ b/csrc/flash_attn/src/fmha_bwd_hdim128.cu
@@ -0,0 +1,13 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+ +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { + // work around for MSVC issue + FP16_SWITCH(params.is_bf16, [&] { + using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + }); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim32.cu b/csrc/flash_attn/src/fmha_bwd_hdim32.cu new file mode 100644 index 000000000..136fbcfd9 --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_hdim32.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimentions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { + // work around for MSVC issue + FP16_SWITCH(params.is_bf16, [&] { + if (params.seqlen_k == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } else if (params.seqlen_k >= 256) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } + }); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_bwd_hdim64.cu b/csrc/flash_attn/src/fmha_bwd_hdim64.cu new file mode 100644 index 000000000..fd1ce9f38 --- /dev/null +++ b/csrc/flash_attn/src/fmha_bwd_hdim64.cu @@ -0,0 +1,31 @@ +// Copyright (c) 2022, Tri Dao. + +// Splitting the different head dimentions to different files to speed up compilation. + +#include "fmha_bwd_launch_template.h" + +void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) { + // work around for MSVC issue + FP16_SWITCH(params.is_bf16, [&] { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (params.seqlen_k == 128) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } else if (params.seqlen_k >= 256) { + if (dprops->major == 8 && dprops->minor == 0) { + // Don't share smem for K & V, and don't keep V in registers + // This speeds things up by 2-3% by avoiding register spills, but it + // uses more shared memory, which is fine on A100 but not other GPUs. + // For other GPUs, we keep V in registers. + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } else if (dprops->major == 8 && dprops->minor > 0) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } else if (dprops->major == 7 && dprops->minor == 5) { + using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>; + run_fmha_bwd_loop(params, stream, configure); + } + } + }); +} \ No newline at end of file diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_bwd_launch_template.h similarity index 56% rename from csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu rename to csrc/flash_attn/src/fmha_bwd_launch_template.h index 085cc3fd3..ffdde3606 100644 --- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu +++ b/csrc/flash_attn/src/fmha_bwd_launch_template.h @@ -1,5 +1,6 @@ -/* Copyright (c) 2022, Tri Dao. - */ +// Copyright (c) 2022, Tri Dao. 
+
+#pragma once

 #include "static_switch.h"
 #include "fp16_switch.h"
@@ -9,7 +10,7 @@
 // Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
 // Parallelizing will have better occupancy, but has some overhead due to having to zero out
 // dq_tmp and having to copy dq_tmp to dq.
-int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
+inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                              int blocksize, bool is_causal) {
     float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
     float eff_1 = n_waves_1 / ceil(n_waves_1);
@@ -29,22 +30,22 @@ int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int
 }

 template
-__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
     fmha::compute_dot_do_o(params);
 }

 template
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_1xN(params);
 }

 template
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_seqparallel(params);
 }

 template
-void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -63,20 +64,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
     // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
     BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
         auto kernel = params.is_causal
-            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
-            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel;
+            ? &fmha_bwd_dq_dk_dv_loop_kernel
+            : &fmha_bwd_dq_dk_dv_loop_kernel;
         if (params.seqlen_k == blocksize_c) {
             kernel = params.is_causal
-                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
-                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel;
+                ? &fmha_bwd_dq_dk_dv_loop_kernel
+                : &fmha_bwd_dq_dk_dv_loop_kernel;
         } else if (params.seqlen_k == blocksize_c * 2) {
             kernel = params.is_causal
-                ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel
-                : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel;
+                ? &fmha_bwd_dq_dk_dv_loop_kernel
+                : &fmha_bwd_dq_dk_dv_loop_kernel;
         }
         auto kernel_seqparallel = params.is_causal
-            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel
-            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel;
+            ? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel
+            : &fmha_bwd_q_dk_dv_loop_seqparallel_kernel;
         if( smem_size_dq_dk_dv >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -104,7 +105,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
             kernel<<>>(params);
         } else {
             dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
-            fmha_dgrad_dot_do_o_kernel<<>>(params);
+            fmha_bwd_dot_do_o_kernel<<>>(params);
             int num_splits = params.seqlen_k / blocksize_c;  // seqlen_k is divisible by blocksize_c
             dim3 grid(params.b, params.h, num_splits);
             kernel_seqparallel<<>>(params);
@@ -112,42 +113,3 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
-
-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
-        auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (params.d <= 32) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-            }
-        } else if (params.d <= 64) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                if (dprops->major == 8 && dprops->minor == 0) {
-                    // Don't share smem for K & V, and don't keep V in registers
-                    // This speeds things up by 2-3% by avoiding register spills, but it
-                    // uses more shared memory, which is fine on A100 but not other GPUs.
-                    // For other GPUs, we keep V in registers.
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-                } else if (dprops->major == 8 && dprops->minor > 0) {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-                } else if (dprops->major == 7 && dprops->minor == 5) {
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-                }
-            }
-        } else if (params.d <= 128) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_(params, stream, configure);
-        }
-    });
-}
\ No newline at end of file
diff --git a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
deleted file mode 100644
index ede2a899a..000000000
--- a/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
+++ /dev/null
@@ -1,153 +0,0 @@
-/******************************************************************************
- * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *     * Redistributions of source code must retain the above copyright
- *       notice, this list of conditions and the following disclaimer.
- *     * Redistributions in binary form must reproduce the above copyright
- *       notice, this list of conditions and the following disclaimer in the
- *       documentation and/or other materials provided with the distribution.
- *     * Neither the name of the NVIDIA CORPORATION nor the
- *       names of its contributors may be used to endorse or promote products
- *       derived from this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
- * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
- * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
- * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
- * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- ******************************************************************************/
-
-#include
-#include
-
-#include "static_switch.h"
-#include "fp16_switch.h"
-#include "fmha.h"
-#include "fmha_fprop_kernel_1xN.h"
-
-// Find the number of splits that maximizes the occupancy. For example, if we have
-// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
-// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
-// splits as that would incur more HBM reads/writes.
-// So we find the best efficiency, then find the smallest number of splits that gets 95%
-// of the best efficiency.
-int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
-    float max_efficiency = 0.f;
-    std::vector efficiency;
-    efficiency.reserve(max_splits);
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
-        float eff = n_waves / ceil(n_waves);
-        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
-        if (eff > max_efficiency) { max_efficiency = eff; }
-        efficiency.push_back(eff);
-    }
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
-            // printf("num_splits chosen = %d\n", num_splits);
-            return num_splits;
-        }
-    }
-    return 1;
-}
-
-template
-__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
-    fmha::device_1xN_loop(params);
-}
-
-template
-void run_fmha_fp16_sm80_loop_(Launch_params &launch_params) {
-    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
-    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
-
-    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
-    // Don't need smem_size_softmax_lse if we're not looping
-    const int smem_size = fmha::get_dynamic_smem_size()
-        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
-
-    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
-    // https://github.com/kokkos/kokkos-kernels/issues/349
-    // https://github.com/HazyResearch/flash-attention/issues/21
-    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
-        auto kernel = launch_params.params.is_causal
-            ? (launch_params.return_softmax
-               ? &fmha_fprop_fp16_sm80_loop_kernel
-               : &fmha_fprop_fp16_sm80_loop_kernel)
-            : (launch_params.return_softmax
-               ? &fmha_fprop_fp16_sm80_loop_kernel
-               : &fmha_fprop_fp16_sm80_loop_kernel);
-        if( smem_size >= 48 * 1024 ) {
-            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
-                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-        }
-        // Automatically set num_splits to maximize occupancy
-        if (launch_params.params.num_splits <= 0) {
-            int ctas_per_sm;
-            cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
-            auto dprops = at::cuda::getCurrentDeviceProperties();
-            // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
-            constexpr int M = Kernel_traits::Cta_tile_p::M;
-            launch_params.params.num_splits = num_splits_heuristic_fwd(
-                launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
-                ctas_per_sm,
-                /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
-            );
-        }
-        // printf("smem_size = %d\n", smem_size);
-        dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
-        kernel<<>>(
-            launch_params.params);
-        FMHA_CHECK_CUDA(cudaPeekAtLastError());
-    });
-}
-
-void run_fmha_fp16_sm80(Launch_params &launch_params) {
-    FP16_SWITCH(launch_params.params.is_bf16, [&] {
-        auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (launch_params.params.d <= 32) {
-            if (launch_params.params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_(launch_params);
-            } else if (launch_params.params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_(launch_params);
-            }
-        } else if (launch_params.params.d <= 64) {
-            if (launch_params.params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_(launch_params);
-            } else if (launch_params.params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_(launch_params);
-            }
-        } else if (launch_params.params.d <= 128) {
-            // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
-            // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
-            // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
-            // some speedup (6-10%) for large batch size, but slows things down for smal batch size.
-            // Now that we have better parallelism (over seqlen_q), block size 128 is faster for small
-            // batch size and only slightly slower (~3%) on large batch size.
-            // For causal=True, block size 128 seems always faster (for small & large batch size).
-            // So we're just gonna use block size 128 for simplicity.
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_(launch_params);
-        }
-        // if (launch_params.params.d == 64) {
-        //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
-        //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_(launch_params);
-        // }
-    });
-}
\ No newline at end of file
diff --git a/csrc/flash_attn/src/fmha_fwd_hdim128.cu b/csrc/flash_attn/src/fmha_fwd_hdim128.cu
new file mode 100644
index 000000000..fd927a364
--- /dev/null
+++ b/csrc/flash_attn/src/fmha_fwd_hdim128.cu
@@ -0,0 +1,12 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "fmha_fwd_launch_template.h"
+
+void run_fmha_fwd_hdim128(Launch_params &launch_params) {
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
+        run_fmha_fwd_loop(launch_params);
+    });
+}
\ No newline at end of file
diff --git a/csrc/flash_attn/src/fmha_fwd_hdim32.cu b/csrc/flash_attn/src/fmha_fwd_hdim32.cu
new file mode 100644
index 000000000..7c159ac7a
--- /dev/null
+++ b/csrc/flash_attn/src/fmha_fwd_hdim32.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "fmha_fwd_launch_template.h"
+
+void run_fmha_fwd_hdim32(Launch_params &launch_params) {
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+        if (launch_params.params.seqlen_k == 128) {
+            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fwd_loop(launch_params);
+        } else if (launch_params.params.seqlen_k >= 256) {
+            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fwd_loop(launch_params);
+        }
+    });
+}
\ No newline at end of file
diff --git a/csrc/flash_attn/src/fmha_fwd_hdim64.cu b/csrc/flash_attn/src/fmha_fwd_hdim64.cu
new file mode 100644
index 000000000..10e202e8b
--- /dev/null
+++ b/csrc/flash_attn/src/fmha_fwd_hdim64.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "fmha_fwd_launch_template.h"
+
+void run_fmha_fwd_hdim64(Launch_params &launch_params) {
+    FP16_SWITCH(launch_params.params.is_bf16, [&] {
+        if (launch_params.params.seqlen_k == 128) {
+            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fwd_loop(launch_params);
+        } else if (launch_params.params.seqlen_k >= 256) {
+            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
+            run_fmha_fwd_loop(launch_params);
+        }
+    });
+}
diff --git a/csrc/flash_attn/src/fmha_fwd_launch_template.h b/csrc/flash_attn/src/fmha_fwd_launch_template.h
new file mode 100644
index 000000000..2876d3a38
--- /dev/null
+++ b/csrc/flash_attn/src/fmha_fwd_launch_template.h
@@ -0,0 +1,92 @@
+// Copyright (c) 2022, Tri Dao.
+
+#pragma once
+
+#include
+
+#include
+#include
+
+#include "static_switch.h"
+#include "fp16_switch.h"
+#include "fmha.h"
+#include "fmha_fprop_kernel_1xN.h"
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 95%
+// of the best efficiency.
+// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error.
+inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
+    float max_efficiency = 0.f;
+    std::vector efficiency;
+    efficiency.reserve(max_splits);
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
+        float eff = n_waves / ceil(n_waves);
+        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+        if (eff > max_efficiency) { max_efficiency = eff; }
+        efficiency.push_back(eff);
+    }
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
+template
+__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) {
+    fmha::device_1xN_loop(params);
+}
+
+template
+void run_fmha_fwd_loop(Launch_params &launch_params) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
+
+    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
+    // Don't need smem_size_softmax_lse if we're not looping
+    const int smem_size = fmha::get_dynamic_smem_size()
+        + (loop_steps > 1 ? smem_size_softmax_lse : 0);
+
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
+    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
+        auto kernel = launch_params.params.is_causal
+            ? (launch_params.return_softmax
+               ? &fmha_fwd_loop_kernel
+               : &fmha_fwd_loop_kernel)
+            : (launch_params.return_softmax
+               ? &fmha_fwd_loop_kernel
+               : &fmha_fwd_loop_kernel);
+        if( smem_size >= 48 * 1024 ) {
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        }
+        // Automatically set num_splits to maximize occupancy
+        if (launch_params.params.num_splits <= 0) {
+            int ctas_per_sm;
+            cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
+            constexpr int M = Kernel_traits::Cta_tile_p::M;
+            launch_params.params.num_splits = num_splits_heuristic_fwd(
+                launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
+                ctas_per_sm,
+                /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
+            );
+        }
+        // printf("smem_size = %d\n", smem_size);
+        dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
+        kernel<<>>(
+            launch_params.params);
+        FMHA_CHECK_CUDA(cudaPeekAtLastError());
+    });
+}
diff --git a/csrc/flash_attn/src/fmha_kernel.h b/csrc/flash_attn/src/fmha_kernel.h
index 153d6d7e4..62879769a 100644
--- a/csrc/flash_attn/src/fmha_kernel.h
+++ b/csrc/flash_attn/src/fmha_kernel.h
@@ -75,107 +75,4 @@ struct BlockInfoPadded {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template
-struct Noloop_traits{
-    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
-    enum{ STEP = Cta_tile::M };
-    enum{ SEQLEN = Cta_tile::N };
-
-    template
-    inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
-        : bidc_(bidc) {
-        const int seqlen = binfo.actual_seqlen;
-        const int steps = (seqlen + STEP - 1) / STEP;
-        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
-
-        const int step_begin = bidc_ * steps_per_chunk;
-        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
-        const int actual_steps = max(0, step_end - step_begin);
-        loop_offset_ = step_begin;
-        num_steps_ = actual_steps;
-
-    }
-
-    template
-    inline __device__ void move_all(Tiles & ... tiles) const {
-        using expand_type = int[];
-        for( int s = 0; s < loop_offset_; s++ ) {
-            expand_type{ (tiles.move(), 0)... };
-        }
-    }
-
-    inline __device__ int get_idx_dk() const {
-        //return bidc_;
-        return bidc_ * 2 + 0;
-    }
-
-    inline __device__ int get_idx_dv() const {
-        //return CHUNKS + bidc_;
-        return bidc_ * 2 + 1;
-    }
-
-    inline __device__ int offset_loop_count(const int l) {
-        // convert loop counter to position in the outer sequence
-        return (loop_offset_ + l) * STEP;
-    }
-
-    const uint32_t bidc_;
-    int loop_offset_;
-    int num_steps_;
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-std::tuple work_dist(const int total_ctas, const int heads_total) {
-
-    constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
-
-    const int num_full_heads = heads_total / total_ctas;
-    const int heads_last_wave = heads_total % total_ctas;
-
-    int num_main_groups = 0;
-    int main_steps = 0;
-    int rest_steps = 0;
-    if( heads_last_wave > 0 ) {
-        // Number of CTA groups that process within heads.
-        num_main_groups = total_ctas / heads_last_wave;
-        // Remaining CTAs that process between heads.
-        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
-        if(rest_ctas == 0) {
-            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
-            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
-            num_main_groups = STEPS_PER_HEAD / main_steps;  // Here: main_step > 0
-            rest_steps = STEPS_PER_HEAD % main_steps;
-
-        } else {
-            // Ideal number of steps if we could load-balance as evenly as possible.
-            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
-            // Iterations that a "rest" CTA has to do at most.
-            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
-            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
-            main_steps = steps_ideal;
-            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-            for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
-                rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-                const int max_rest_total_steps = rest_steps * max_rest_iters;
-                if( max_rest_total_steps < main_steps )
-                    break;
-            }
-            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
-        }
-    }
-
-    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
-    using Mma_tile_p = fmha::Hmma_tile;
-
-    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
-    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
-    const int elts_per_thread = max_steps * elts_per_thread_per_step;
-
-    return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 }  // namespace fmha
diff --git a/setup.py b/setup.py
index c27c041e6..551680486 100644
--- a/setup.py
+++ b/setup.py
@@ -119,8 +119,12 @@ def append_nvcc_threads(nvcc_extra_args):
         name="flash_attn_cuda",
         sources=[
             "csrc/flash_attn/fmha_api.cpp",
-            "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
-            "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim32.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim64.cu",
+            "csrc/flash_attn/src/fmha_fwd_hdim128.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim32.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim64.cu",
+            "csrc/flash_attn/src/fmha_bwd_hdim128.cu",
             "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
             "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
         ],
@@ -152,7 +156,7 @@ def append_nvcc_threads(nvcc_extra_args):

 setup(
     name="flash_attn",
-    version="0.2.1",
+    version="0.2.2",
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
     ),