perf: use persistent kernel for merging attention states (#459)
As observed by @MasterJH5574, there are cases where our
`VariableLengthMergeStatesKernel` launches a large number of CTAs (>=10k) while
most of them only perform a small number of merges. This PR fixes the issue by
using a persistent kernel: we launch only as many CTAs as the GPU can keep
resident and let each CTA loop over its share of the work.

There is still a load-imbalance issue, which I plan to resolve inside the
scheduler; I'll leave that for later PRs.
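For context, here is a minimal, generic sketch of the persistent-kernel launch pattern (hypothetical names such as `PersistentWorkLoop`/`LaunchPersistent`, not the FlashInfer code itself): size the grid by SM count times occupancy and let each CTA grid-stride over the flattened work items.

```cuda
// Minimal sketch of the persistent-kernel pattern; the names below are
// illustrative only and are not FlashInfer APIs.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void PersistentWorkLoop(const float* in, float* out, uint32_t num_items) {
  // Each resident CTA strides over the flattened work-item index space,
  // instead of mapping exactly one CTA to one work item.
  for (uint32_t i = blockIdx.x; i < num_items; i += gridDim.x) {
    // All threads of the CTA would cooperate on item i; this sketch just copies it.
    if (threadIdx.x == 0) out[i] = in[i];
  }
}

void LaunchPersistent(const float* in, float* out, uint32_t num_items, cudaStream_t stream) {
  int dev_id = 0, num_sms = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev_id);
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id);
  constexpr int num_threads = 128;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, PersistentWorkLoop, num_threads,
                                                /*dynamicSMemSize=*/0);
  // Launch only as many CTAs as the GPU can keep resident, capped by the work size.
  int num_ctas = std::min(num_sms * blocks_per_sm, static_cast<int>(num_items));
  PersistentWorkLoop<<<num_ctas, num_threads, 0, stream>>>(in, out, num_items);
}
```

In the actual change, the same idea is applied to `VariableLengthMergeStates`: the grid is sized with `cudaOccupancyMaxActiveBlocksPerMultiprocessor`, and `PersistentVariableLengthMergeStatesKernel` loops over the `seq_len * num_heads` work items.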
yzh119 authored Aug 21, 2024
1 parent 048560d commit be6bf5b
Showing 1 changed file with 85 additions and 66 deletions.
151 changes: 85 additions & 66 deletions include/flashinfer/attention/cascade.cuh
@@ -301,85 +301,94 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
  */
 template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
           typename DTypeOut, typename IdType>
-__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S,
-                                                IdType* indptr, DTypeOut* __restrict__ v_merged,
-                                                float* __restrict__ s_merged, uint32_t num_heads) {
+__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V,
+                                                          float* __restrict__ S, IdType* indptr,
+                                                          DTypeOut* __restrict__ v_merged,
+                                                          float* __restrict__ s_merged,
+                                                          uint32_t seq_len, uint32_t num_heads) {
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  uint32_t pos = blockIdx.x;
-  uint32_t head_idx = blockIdx.y;
-  state_t<vec_size> st;
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+  uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
   constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
   constexpr uint32_t head_dim = vec_size * bdx;

   extern __shared__ uint8_t smem[];
   DTypeIn* v_smem = (DTypeIn*)smem;
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
-  const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
-
-  if (num_index_sets == 0) {
-    vec_t<DTypeOut, vec_size> v;
-    v.fill(DTypeOut(0));
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = -5e4;
-    }
-    return;
-  }
+
+#pragma unroll 1
+  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
+    uint32_t pos = i / num_heads;
+    uint32_t head_idx = i % num_heads;
+    state_t<vec_size> st;
+    const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
+
+    if (num_index_sets == 0) {
+      vec_t<DTypeOut, vec_size> v;
+      v.fill(DTypeOut(0));
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = -5e4;
+      }
+      continue;
+    }

-  if (num_index_sets == 1) {
-    vec_t<DTypeOut, vec_size> v;
-    v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
-    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-    if (s_merged != nullptr) {
-      s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
-    }
-    return;
-  }
+    if (num_index_sets == 1) {
+      vec_t<DTypeOut, vec_size> v;
+      v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+      }
+      continue;
+    }

 #pragma unroll
-  for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
-        V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
-        (iter * bdy + ty) < num_index_sets);
-    cp_async::commit_group();
-  }
+    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+          V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
+          (iter * bdy + ty) < num_index_sets);
+      cp_async::commit_group();
+    }
 #pragma unroll 4
-  for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
-    if (iter % bdx == 0) {
-      s_smem[ty * bdx + tx] =
-          iter * bdy + (ty * bdx + tx) < num_index_sets
-              ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
-              : 0.f;
-      __syncthreads();
-    }
-    cp_async::wait_group<num_smem_stages - 1>();
-    __syncthreads();
-    vec_t<float, vec_size> v;
-    v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
-    if (iter * bdy + ty < num_index_sets) {
-      float s = s_smem[(iter % bdx) * bdy + ty];
-      st.merge(v, s, 1);
-    }
-    __syncthreads();
-    cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
-        v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
-        V +
-            ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
-                head_dim +
-            tx * vec_size,
-        (iter + num_smem_stages) * bdy + ty < num_index_sets);
-    cp_async::commit_group();
-  }
-  cp_async::wait_group<0>();
-  __syncthreads();
-
-  st.normalize();
-  threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
-  st.normalize();
-
-  st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
-  if (s_merged != nullptr) {
-    s_merged[pos * num_heads + head_idx] = st.get_lse();
+    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+      if (iter % bdx == 0) {
+        s_smem[ty * bdx + tx] =
+            iter * bdy + (ty * bdx + tx) < num_index_sets
+                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
+                : 0.f;
+        __syncthreads();
+      }
+      cp_async::wait_group<num_smem_stages - 1>();
+      __syncthreads();
+      vec_t<float, vec_size> v;
+      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
+      if (iter * bdy + ty < num_index_sets) {
+        float s = s_smem[(iter % bdx) * bdy + ty];
+        st.merge(v, s, 1);
+      }
+      __syncthreads();
+      cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
+          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
+          V +
+              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
+                  head_dim +
+              tx * vec_size,
+          (iter + num_smem_stages) * bdy + ty < num_index_sets);
+      cp_async::commit_group();
+    }
+    cp_async::wait_group<0>();
+    __syncthreads();
+
+    st.normalize();
+    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+    st.normalize();
+
+    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = st.get_lse();
+    }
   }
 }

@@ -502,19 +511,29 @@ template <typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
                                       float* s_merged, uint32_t seq_len, uint32_t num_heads,
                                       uint32_t head_dim, cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
   DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
     constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
     constexpr uint32_t bdx = HEAD_DIM / vec_size;
     constexpr uint32_t num_threads = 128;
     constexpr uint32_t bdy = num_threads / bdx;
-    dim3 nblks(seq_len, num_heads);
-    dim3 nthrs(bdx, bdy);
     constexpr uint32_t num_smem_stages = 4;
-    auto kernel = VariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
-                                                  DTypeOut, IdType>;
-    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads};
     uint32_t smem_size =
         num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float);
+    auto kernel = PersistentVariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages,
+                                                            DTypeIn, DTypeOut, IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+
+    dim3 nblks(num_sms * num_blocks_per_sm);
+    dim3 nthrs(bdx, bdy);
+    void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
     FLASHINFER_CUDA_CALL(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));