Speed up compilation by splitting into separate .cu files
tridao committed Nov 26, 2022
1 parent b784ed7 commit d95ee1a
Showing 13 changed files with 251 additions and 318 deletions.
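
The pattern behind the commit, as a minimal sketch with hypothetical toy_* names (not code from this commit): the heavy templated launcher lives in a header, each head dimension gets its own translation unit that instantiates it, and thin dispatch functions declared in a shared header pick the specialization at runtime. Separate .cu files mean separate nvcc jobs, so the instantiations compile in parallel instead of serially inside one giant unit.

// launch_template.h -- the expensive templated launcher; each .cu includes it.
#pragma once
#include <cuda_runtime.h>

template<int kHeadDim>
__global__ void toy_kernel(float *x) { x[threadIdx.x] *= kHeadDim; }

template<int kHeadDim>
void run_toy(float *x, cudaStream_t stream) {
    toy_kernel<kHeadDim><<<1, 32, 0, stream>>>(x);
}

// api.h -- plain declarations; cheap to include from the runtime dispatcher.
void run_toy_hdim32(float *x, cudaStream_t stream);
void run_toy_hdim64(float *x, cudaStream_t stream);

// toy_hdim32.cu -- one translation unit per head dimension.
// #include "launch_template.h"
void run_toy_hdim32(float *x, cudaStream_t stream) { run_toy<32>(x, stream); }

// toy_hdim64.cu -- compiled in parallel with toy_hdim32.cu.
// #include "launch_template.h"
void run_toy_hdim64(float *x, cudaStream_t stream) { run_toy<64>(x, stream); }

The run_fmha_fwd / run_fmha_bwd dispatchers added below play the role of the runtime selection step.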
23 changes: 21 additions & 2 deletions csrc/flash_attn/fmha_api.cpp
@@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
params.dsoftmax_sum = dsoftmax_sum_d;
}

+void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
+    if (launch_params.params.d <= 32) {
+        run_fmha_fwd_hdim32(launch_params);
+    } else if (launch_params.params.d <= 64) {
+        run_fmha_fwd_hdim64(launch_params);
+    } else if (launch_params.params.d <= 128) {
+        run_fmha_fwd_hdim128(launch_params);
+    }
+}

std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -307,13 +317,22 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}

-run_fmha_fp16_sm80(launch_params);
+run_fmha_fwd(launch_params);

std::vector<at::Tensor> result = {softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}

+void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    if (params.d <= 32) {
+        run_fmha_bwd_hdim32(params, stream, configure);
+    } else if (params.d <= 64) {
+        run_fmha_bwd_hdim64(params, stream, configure);
+    } else if (params.d <= 128) {
+        run_fmha_bwd_hdim128(params, stream, configure);
+    }
+}

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
@@ -341,7 +360,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
TORCH_CHECK(is_sm8x || is_sm75);
-auto launch = &run_fmha_dgrad_fp16_sm80;
+auto launch = &run_fmha_bwd;

bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
8 changes: 6 additions & 2 deletions csrc/flash_attn/src/fmha.h
@@ -195,9 +195,13 @@ struct Launch_params{

////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);

-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);

void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

13 changes: 13 additions & 0 deletions csrc/flash_attn/src/fmha_bwd_hdim128.cu
@@ -0,0 +1,13 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
    });
}
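
A note on FP16_SWITCH, which every new .cu file uses: its definition lives in fp16_switch.h and is not part of this diff. A plausible sketch of its shape, under the assumption that it follows the usual lambda-based BOOL_SWITCH pattern (names and types here are illustrative, not the actual definition):

// Hedged sketch of an FP16_SWITCH-style macro (the real one lives in
// csrc/flash_attn/src/fp16_switch.h). The body arrives as a lambda via
// __VA_ARGS__, so commas inside template argument lists survive, and the
// lambda form sidesteps the MSVC preprocessor issue the comments mention.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#define FP16_SWITCH_SKETCH(IS_BF16, ...)       \
    do {                                       \
        if (IS_BF16) {                         \
            using elem_type = __nv_bfloat16;   \
            __VA_ARGS__();                     \
        } else {                               \
            using elem_type = __half;          \
            __VA_ARGS__();                     \
        }                                      \
    } while (0)

Because the lambda text is pasted inside each branch, elem_type is in scope at the lambda's definition site, so one body compiles for both fp16 and bf16.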
18 changes: 18 additions & 0 deletions csrc/flash_attn/src/fmha_bwd_hdim32.cu
@@ -0,0 +1,18 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        if (params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        }
    });
}
31 changes: 31 additions & 0 deletions csrc/flash_attn/src/fmha_bwd_hdim64.cu
@@ -0,0 +1,31 @@
// Copyright (c) 2022, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "fmha_bwd_launch_template.h"

void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    // work around for MSVC issue
    FP16_SWITCH(params.is_bf16, [&] {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (params.seqlen_k == 128) {
            using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
        } else if (params.seqlen_k >= 256) {
            if (dprops->major == 8 && dprops->minor == 0) {
                // Don't share smem for K & V, and don't keep V in registers
                // This speeds things up by 2-3% by avoiding register spills, but it
                // uses more shared memory, which is fine on A100 but not other GPUs.
                // For other GPUs, we keep V in registers.
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 8 && dprops->minor > 0) {
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            } else if (dprops->major == 7 && dprops->minor == 5) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
            }
        }
    });
}
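
A hedged aside on the A100 gating above: the 0x100u traits variant spends extra shared memory to avoid register spills, and only sm80 has a per-block shared-memory budget large enough to host it. The check below is illustrative only; the capacity figures are NVIDIA's published per-architecture limits, not values from this diff.

// Hedged sketch: the larger-smem kernel only fits where the per-block
// opt-in shared-memory limit allows it. A100 (sm80) permits up to ~163 KB
// per block, while sm86 caps out near 99 KB, hence the gating on
// major == 8 && minor == 0 in run_fmha_bwd_hdim64 above.
#include <ATen/cuda/CUDAContext.h>

bool fits_in_optin_smem(size_t smem_bytes_needed) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    return dprops->sharedMemPerBlockOptin >= smem_bytes_needed;
}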
csrc/flash_attn/src/fmha_bwd_launch_template.h
@@ -1,5 +1,6 @@
-/* Copyright (c) 2022, Tri Dao.
- */
+// Copyright (c) 2022, Tri Dao.

+#pragma once

#include "static_switch.h"
#include "fp16_switch.h"
@@ -9,7 +10,7 @@
// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
// dq_tmp and having to copy dq_tmp to dq.
-int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
+inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
int blocksize, bool is_causal) {
float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
float eff_1 = n_waves_1 / ceil(n_waves_1);
@@ -29,22 +30,22 @@ int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int
}

template<typename Kernel_traits>
-__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
fmha::compute_dot_do_o<Kernel_traits>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
+__global__ void fmha_bwd_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
}

template<typename Kernel_traits>
-void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -63,20 +64,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
auto kernel = params.is_causal
-? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
-: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
+? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
+: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
if (params.seqlen_k == blocksize_c) {
kernel = params.is_causal
-? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
-: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
+? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
+: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = params.is_causal
-? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
-: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
+? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
+: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
}
auto kernel_seqparallel = params.is_causal
-? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
-: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
+? &fmha_bwd_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
+: &fmha_bwd_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -104,50 +105,11 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
} else {
dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
-fmha_dgrad_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
+fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
dim3 grid(params.b, params.h, num_splits);
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}

-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
-    // work around for MSVC issue
-    FP16_SWITCH(params.is_bf16, [&] {
-        auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (params.d <= 32) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            }
-        } else if (params.d <= 64) {
-            if (params.seqlen_k == 128) {
-                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if (params.seqlen_k >= 256) {
-                if (dprops->major == 8 && dprops->minor == 0) {
-                    // Don't share smem for K & V, and don't keep V in registers
-                    // This speeds things up by 2-3% by avoiding register spills, but it
-                    // uses more shared memory, which is fine on A100 but not other GPUs.
-                    // For other GPUs, we keep V in registers.
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                } else if (dprops->major == 8 && dprops->minor > 0) {
-                    using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                } else if (dprops->major == 7 && dprops->minor == 5) {
-                    using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-                }
-            }
-        } else if (params.d <= 128) {
-            using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-        }
-    });
-}
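
For reference, the occupancy tradeoff that num_splits_heuristic_bwd encodes (its body is only partially visible in the hunks above) reduces to a wave-efficiency calculation: how full the last wave of CTAs is. A standalone sketch of that core quantity, mirroring the n_waves_1 / eff_1 lines shown earlier; the full heuristic with its is_causal handling is not reproduced here.

#include <cmath>

// Fraction of the last "wave" of CTAs that is actually occupied: 1.0 means
// every wave fills all SMs; small values mean the tail wave wastes hardware.
inline float wave_efficiency(int num_blocks, int num_SMs, int ctas_per_sm) {
    float n_waves = float(num_blocks) / (num_SMs * ctas_per_sm);
    return n_waves / std::ceil(n_waves);
}

// With num_splits = 1 the backward grid has batch * nheads blocks; splitting
// across seqlen_k multiplies the block count (better occupancy) at the cost
// of zeroing dq_tmp and copying it back to dq, which is the overhead the
// comment above the heuristic describes.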