Skip to content

Commit

Permalink
perf: fix prefill kernel performance degradation (step 1) (#602)
Browse files Browse the repository at this point in the history
The prefill attention kernel performance has degraded significantly in
recent releases (since v0.1.2), especially on A100 when `causal=True`,
this is mainly because we add new attention variants (which increases
register usage thus incurs register spilling) and move some parameters
from compile-time to runtime.

This PR alleviates the issue by caching some of the variables regarding
GQA group size.
In the next PR, we will support another mode `kv_head_major` in addition
to `qo_head_major`, to further accelerate GQA prefill with query size >=
64.

cc @AKKamath
  • Loading branch information
yzh119 authored Nov 11, 2024
1 parent 3dd9405 commit 595cf60
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -623,18 +623,25 @@ __device__ __forceinline__ void logits_transform(const typename AttentionVariant
const uint_fastdiv group_size,
DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
#pragma unroll
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
}
}

#pragma unroll
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
#pragma unroll
for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
r);
const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
2 * (lane_idx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
s_frag[fq][fkv][reg_id] = variant.LogitsTransform(
params, s_frag[fq][fkv][reg_id], batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx);
}
Expand All @@ -652,18 +659,25 @@ __device__ __forceinline__ void logits_mask(const typename AttentionVariant::Par
const uint_fastdiv group_size,
DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
#pragma unroll
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
}
}

#pragma unroll
for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
#pragma unroll
for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
uint32_t q, r;
group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
r);
const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const uint32_t qo_head_idx = kv_head_idx * group_size + r;
const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
2 * (lane_idx % 4) +
8 * (reg_id / 4) + reg_id % 2;
const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
const bool mask =
(!(MASK_MODE == MaskMode::kCausal
? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end))
Expand Down

0 comments on commit 595cf60

Please # to comment.