From 073afd5931d6672ff4899429f83a881ff8182fe2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 6 Dec 2024 15:31:04 -0800 Subject: [PATCH] [CI] Use torch 2.6.0.dev20241001, reduce torch #include --- .github/workflows/publish.yml | 34 ++++--- csrc/flash_attn/flash_api.cpp | 97 ++++++++++--------- csrc/flash_attn/src/flash.h | 8 +- .../src/flash_bwd_launch_template.h | 7 +- .../src/flash_fwd_launch_template.h | 16 +-- csrc/flash_attn/src/hardware_info.h | 41 ++++++++ flash_attn/__init__.py | 2 +- setup.py | 6 +- 8 files changed, 127 insertions(+), 84 deletions(-) create mode 100644 csrc/flash_attn/src/hardware_info.h diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index dcd295a7a..e6ab79e86 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241010'] - cuda-version: ['11.8.0', '12.4.1'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] + cuda-version: ['11.8.0', '12.3.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -68,10 +68,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -80,6 +80,7 @@ jobs: echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - name: Free up disk space if: ${{ runner.os == 'Linux' }} @@ -98,7 +99,7 @@ jobs: - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} - uses: Jimver/cuda-toolkit@v0.2.18 + uses: Jimver/cuda-toolkit@v0.2.19 id: cuda-toolkit with: cuda: ${{ matrix.cuda-version }} @@ -106,18 +107,16 @@ jobs: # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} method: 'network' - # We need the cuda libraries (e.g. 
cuSparse, cuSolver) for compiling PyTorch extensions, - # not just nvcc - # sub-packages: '["nvcc"]' + sub-packages: '["nvcc"]' - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} run: | pip install --upgrade pip - # If we don't install before installing Pytorch, we get error for torch 2.0.1 - # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) - pip install lit # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools + pip install setuptools==68.0.0 + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix @@ -128,7 +127,12 @@ jobs: print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + pip install jinja2 + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi @@ -191,9 +195,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 41d963cd8..ba674beb7 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -5,11 +5,13 @@ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. #include #include -#include #include +#include +#include #include +#include "hardware_info.h" #include "flash.h" #include "static_switch.h" @@ -294,7 +296,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n std::tuple set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, const float p_dropout, - const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) { + const int num_splits, const int num_sm, struct c10::TensorOptions opts) { // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 
128 : 64); @@ -309,7 +311,7 @@ std::tuple set_params_splitkv(Flash_fwd_params ¶ms, if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128); + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); } if (params.num_splits > 1) { softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); @@ -357,10 +359,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult const bool return_softmax, c10::optional gen_) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + // bool is_sm75 = cc_major == 7 && cc_minor == 5; + bool is_sm8x = cc_major == 8 && cc_minor >= 0; + bool is_sm90 = cc_major == 9 && cc_minor == 0; TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); // We will support Turing in the near future // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); @@ -435,9 +440,6 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto opts = q.options(); auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -475,7 +477,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts); + head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); // number of times random will be generated per thread, to offset philox counter in thc random // state @@ -536,10 +538,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const bool return_softmax, c10::optional gen_) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + // bool is_sm75 = cc_major == 7 && cc_minor == 5; + bool is_sm8x = cc_major == 8 && cc_minor >= 0; + bool is_sm90 = cc_major == 9 && cc_minor == 0; TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); // We will support Turing in the near future // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); @@ -654,9 +659,6 @@ 
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto opts = q.options(); auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; @@ -711,7 +713,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, dprops, opts); + p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); } if (leftpad_k_.has_value()) { @@ -798,11 +800,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl TORCH_CHECK(false, "This flash attention build does not support backward."); #endif if (is_causal) { window_size_right = 0; } - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + // bool is_sm75 = cc_major == 7 && cc_minor == 5; + bool is_sm8x = cc_major == 8 && cc_minor >= 0; + bool is_sm80 = cc_major == 8 && cc_minor == 0; + bool is_sm90 = cc_major == 9 && cc_minor == 0; TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); // We will support Turing in the near future // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); @@ -895,9 +901,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl // TODO: change later, for now set to true for simplicity bool loop = true; - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); at::Tensor dq_accum; @@ -906,7 +909,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl if (!deterministic) { dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } else { - const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); @@ -1018,13 +1021,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); #endif - if (is_causal) { window_size_right = 0; } - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor 
>= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + // bool is_sm75 = cc_major == 7 && cc_minor == 5; + bool is_sm8x = cc_major == 8 && cc_minor >= 0; + bool is_sm80 = cc_major == 8 && cc_minor == 0; + bool is_sm90 = cc_major == 9 && cc_minor == 0; TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); // We will support Turing in the near future // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); @@ -1122,9 +1128,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // TODO: change later, for now set to true for simplicity bool loop = true; - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto opts = q.options(); auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); at::Tensor dq_accum; @@ -1141,7 +1144,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (!deterministic) { dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } else { - const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads); dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } } @@ -1251,10 +1254,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he int num_splits ) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + // bool is_sm75 = cc_major == 7 && cc_minor == 5; + bool is_sm8x = cc_major == 8 && cc_minor >= 0; + bool is_sm90 = cc_major == 9 && cc_minor == 0; TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); // We will support Turing in the near future // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); @@ -1358,9 +1364,6 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.device()}; - auto opts = q.options(); auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -1471,12 +1474,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); } - + // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, 
batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts); + head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts); if (paged_KV) { params.block_table = block_table.data_ptr(); diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 6f597fbee..8838f59b6 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -7,13 +7,7 @@ #include #include -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include // For at::cuda::philox::unpack +#include // For at::cuda::philox::unpack constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 727d87e93..3b79a01c5 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -4,9 +4,10 @@ #pragma once -#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" +#include "hardware_info.h" #include "flash.h" #include "flash_bwd_preprocess_kernel.h" #include "flash_bwd_kernel.h" @@ -72,8 +73,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; int gridDimx = num_n_block; if (params.deterministic) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h); + int num_sm = get_num_sm(get_current_device()); + gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h); } dim3 grid_n(gridDimx, params.b, params.h); diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 900cf4671..b04667c55 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -3,10 +3,10 @@ ******************************************************************************/ #pragma once - -#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK #include "static_switch.h" +#include "hardware_info.h" #include "flash.h" #include "flash_fwd_kernel.h" @@ -198,8 +198,8 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { @@ -222,8 +222,8 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if constexpr(!Is_dropout) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -257,8 +257,8 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, 
cudaStream_t stream) { template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; - auto dprops = at::cuda::getCurrentDeviceProperties(); - bool is_sm8x = dprops->major == 8 && dprops->minor > 0; + auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); + bool is_sm8x = cc_major == 8 && cc_minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // For A100, H100, 128 x 32 is the fastest. // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), diff --git a/csrc/flash_attn/src/hardware_info.h b/csrc/flash_attn/src/hardware_info.h new file mode 100644 index 000000000..b218a29b3 --- /dev/null +++ b/csrc/flash_attn/src/hardware_info.h @@ -0,0 +1,41 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#if !defined(__CUDACC_RTC__) +#include "cuda_runtime.h" +#endif + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + + +inline int get_current_device() { + int device; + CHECK_CUDA(cudaGetDevice(&device)); + return device; +} + +inline std::tuple get_compute_capability(int device) { + int capability_major, capability_minor; + CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device)); + return {capability_major, capability_minor}; +} + +inline int get_num_sm(int device) { + int multiprocessor_count; + CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); + return multiprocessor_count; +} diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 436120802..34fdfef70 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.1.post1" +__version__ = "2.7.1.post2" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/setup.py b/setup.py index c20636b48..c2126134a 100644 --- a/setup.py +++ b/setup.py @@ -441,9 +441,9 @@ def get_wheel_url(): # We're using the CUDA version used to build torch, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.4 + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 # to save CI time. Minor versions should be compatible. - torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.4") + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" cuda_version = f"{torch_cuda_version.major}" @@ -542,7 +542,7 @@ def __init__(self, *args, **kwargs) -> None: else { "bdist_wheel": CachedWheelsCommand, }, - python_requires=">=3.8", + python_requires=">=3.9", install_requires=[ "torch", "einops",
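
Usage sketch (not part of the patch): a minimal standalone program showing how the helpers added in csrc/flash_attn/src/hardware_info.h replace the previous at::cuda::getCurrentDeviceProperties() queries, using only the CUDA runtime API and no torch headers. The file name hw_query_demo.cpp, the include path, and the batch/head sizes are illustrative assumptions, not part of the commit. Build with something like: nvcc -std=c++17 -I csrc/flash_attn/src hw_query_demo.cpp -o hw_query_demo

// hw_query_demo.cpp -- standalone sketch, assumes hardware_info.h from this
// patch is available on the include path.
#include <cstdio>
#include "hardware_info.h"

int main() {
    int device = get_current_device();
    auto [cc_major, cc_minor] = get_compute_capability(device);
    int num_sm = get_num_sm(device);

    // The same architecture checks the patch performs in mha_fwd and friends.
    bool is_sm8x = cc_major == 8 && cc_minor >= 0;
    bool is_sm90 = cc_major == 9 && cc_minor == 0;

    // Deterministic-backward split count, as computed in mha_bwd
    // (batch/head sizes below are made-up example values).
    int batch_size = 4, num_heads = 16;
    int nsplits = (num_sm + batch_size * num_heads - 1) / (batch_size * num_heads);

    printf("device %d: sm%d%d, %d SMs, is_sm8x=%d, is_sm90=%d, nsplits=%d\n",
           device, cc_major, cc_minor, num_sm, is_sm8x, is_sm90, nsplits);
    return 0;
}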