Skip to content

Commit

Permalink
[CI] Change torch #include to make it work with torch 2.1 Philox
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 7, 2024
1 parent 073afd5 commit e782d28
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion csrc/flash_attn/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include <torch/nn/functional.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

#include <cutlass/numeric_types.h>

Expand Down
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <cuda.h>
#include <vector>

#include <ATen/cuda/PhiloxUtils.cuh> // For at::cuda::philox::unpack
#include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
Expand Down
2 changes: 2 additions & 0 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#pragma once

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.7.1.post2"
__version__ = "2.7.1.post3"

from flash_attn.flash_attn_interface import (
flash_attn_func,
Expand Down

0 comments on commit e782d28

Please # to comment.