[CI] Use torch 2.6.0.dev20241001, reduce torch #include
tridao committed Dec 6, 2024
1 parent cf0f4c3 commit 073afd5
Showing 8 changed files with 127 additions and 84 deletions.
34 changes: 19 additions & 15 deletions .github/workflows/publish.yml
@@ -44,8 +44,8 @@ jobs:
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241010']
cuda-version: ['11.8.0', '12.4.1']
torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001']
cuda-version: ['11.8.0', '12.3.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
@@ -68,10 +68,10 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

@@ -80,6 +80,7 @@ jobs:
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
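
For example, the matrix entry cuda-version 12.3.2, torch-version 2.6.0.dev20241001, python-version 3.12 yields MATRIX_CUDA_VERSION=123, MATRIX_TORCH_VERSION=2.6, WHEEL_CUDA_VERSION=12, and (from the newly added line) MATRIX_PYTHON_VERSION=312.
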
@@ -98,26 +99,24 @@ jobs:

- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.18
uses: Jimver/cuda-toolkit@v0.2.19
id: cuda-toolkit
with:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
method: 'network'
# We need the cuda libraries (e.g. cuSparse, cuSolver) for compiling PyTorch extensions,
# not just nvcc
# sub-packages: '["nvcc"]'
sub-packages: '["nvcc"]'

- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install --upgrade pip
# If we don't install before installing Pytorch, we get error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
# For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
pip install setuptools
pip install setuptools==68.0.0
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
@@ -128,7 +127,12 @@ jobs:
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# Can't use --no-deps because we need cudnn etc.
# Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
pip install jinja2
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
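
As an illustration, with python-version 3.12 this installs cp312 wheels; if TORCH_CUDA_VERSION resolves to 124 for this entry (the minv/maxv table is elided above, so 124 is an assumption), the hard-coded URL expands to https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20241001%2Bcu124-cp312-cp312-linux_x86_64.whl, alongside the matching pytorch_triton 3.1.0+cf34004b8a wheel installed just before it.
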
@@ -191,9 +195,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'

97 changes: 50 additions & 47 deletions csrc/flash_attn/flash_api.cpp
@@ -5,11 +5,13 @@
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>

#include <cutlass/numeric_types.h>

#include "hardware_info.h"
#include "flash.h"
#include "static_switch.h"

@@ -294,7 +296,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
const int num_splits, const int num_sm, struct c10::TensorOptions opts) {

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -309,7 +311,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params,
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
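
For a rough sense of the numbers now flowing into the heuristic (all values below are illustrative assumptions, not taken from this diff):

// Illustrative inputs, assuming head_size = 128, seqlen_k = 4096, and a GPU
// with 108 SMs (e.g. A100-class):
//   block_n       = 128                          // head_size <= 128
//   num_n_blocks  = (4096 + 128 - 1) / 128 = 32
//   num_sm * 2    = 216                          // 128 threads per block
// num_splits_heuristic then picks params.num_splits to keep roughly that many
// blocks busy, never exceeding the 128 passed as its last argument.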
@@ -357,10 +359,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
const bool return_softmax,
c10::optional<at::Generator> gen_) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm75 = cc_major == 7 && cc_minor == 5;
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -435,9 +440,6 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -475,7 +477,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
@@ -536,10 +538,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const bool return_softmax,
c10::optional<at::Generator> gen_) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm75 = cc_major == 7 && cc_minor == 5;
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -654,9 +659,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
at::Tensor p;
@@ -711,7 +713,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
std::tie(softmax_lse_accum, out_accum) =
set_params_splitkv(params, batch_size, num_heads, head_size,
max_seqlen_k, max_seqlen_q, head_size_rounded,
p_dropout, /*num_splits*/ 0, dprops, opts);
p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
}

if (leftpad_k_.has_value()) {
@@ -798,11 +800,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm75 = cc_major == 7 && cc_minor == 5;
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm80 = cc_major == 8 && cc_minor == 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -895,9 +901,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
// TODO: change later, for now set to true for simplicity
bool loop = true;

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
@@ -906,7 +909,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
if (!deterministic) {
dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
Expand Down Expand Up @@ -1018,13 +1021,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm75 = cc_major == 7 && cc_minor == 5;
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm80 = cc_major == 8 && cc_minor == 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -1122,9 +1128,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// TODO: change later, for now set to true for simplicity
bool loop = true;

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
@@ -1141,7 +1144,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
if (!deterministic) {
dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
}
@@ -1251,10 +1254,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
int num_splits
) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
// bool is_sm75 = cc_major == 7 && cc_minor == 5;
bool is_sm8x = cc_major == 8 && cc_minor >= 0;
bool is_sm90 = cc_major == 9 && cc_minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
@@ -1358,9 +1364,6 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -1471,12 +1474,12 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
}

// Keep references to these tensors to extend their lifetime
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
8 changes: 1 addition & 7 deletions csrc/flash_attn/src/flash.h
@@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
#include <ATen/cuda/PhiloxUtils.cuh> // For at::cuda::philox::unpack

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
7 changes: 4 additions & 3 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -4,9 +4,10 @@

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
@@ -72,8 +73,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
int gridDimx = num_n_block;
if (params.deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
int num_sm = get_num_sm(get_current_device());
gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);

Diffs for the remaining four changed files are not shown.
