From 595cf602e73688d2f96f8cf1aad7cb2fce689d41 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sun, 10 Nov 2024 16:24:18 -0800
Subject: [PATCH] perf: fix prefill kernel performance degradation (step 1) (#602)

The prefill attention kernel performance has degraded significantly in
recent releases (since v0.1.2), especially on A100 when `causal=True`.
This is mainly because we added new attention variants (which increase
register usage and thus incur register spilling) and moved some
parameters from compile time to runtime. This PR alleviates the issue
by caching some of the variables related to the GQA group size.

In the next PR, we will support another mode, `kv_head_major`, in
addition to `qo_head_major`, to further accelerate GQA prefill with
query size >= 64.

cc @AKKamath
---
 include/flashinfer/attention/prefill.cuh | 38 ++++++++++++++++--------
 1 file changed, 26 insertions(+), 12 deletions(-)

diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 1d720473..c9d760ae 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -623,18 +623,25 @@ __device__ __forceinline__ void logits_transform(const typename AttentionVariant
                                                  const uint_fastdiv group_size,
                                                  DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
   const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
+  uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
+#pragma unroll
+  for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
+    }
+  }
+
 #pragma unroll
   for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
 #pragma unroll
     for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
 #pragma unroll
       for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
-        uint32_t q, r;
-        group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
-                          r);
-        const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
-                                           8 * (reg_id / 4) + reg_id % 2;
-        const uint32_t qo_head_idx = kv_head_idx * group_size + r;
+        const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
+                                                                 2 * (lane_idx % 4) +
+                                                                 8 * (reg_id / 4) + reg_id % 2;
+        const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
         s_frag[fq][fkv][reg_id] = variant.LogitsTransform(
             params, s_frag[fq][fkv][reg_id], batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx);
       }
@@ -652,18 +659,25 @@ __device__ __forceinline__ void logits_mask(const typename AttentionVariant::Par
                                             const uint_fastdiv group_size,
                                             DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) {
   const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z;
+  uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
+#pragma unroll
+  for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
+#pragma unroll
+    for (uint32_t j = 0; j < 2; ++j) {
+      group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]);
+    }
+  }
+
 #pragma unroll
   for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
 #pragma unroll
     for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) {
 #pragma unroll
       for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
-        uint32_t q, r;
-        group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * ((reg_id % 4) / 2), q,
-                          r);
-        const uint32_t q_idx = q, kv_idx = kv_idx_base + fkv * 16 + 2 * (lane_idx % 4) +
-                                           8 * (reg_id / 4) + reg_id % 2;
-        const uint32_t qo_head_idx = kv_head_idx * group_size + r;
+        const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 +
+                                                                 2 * (lane_idx % 4) +
+                                                                 8 * (reg_id / 4) + reg_id % 2;
+        const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
         const bool mask =
             (!(MASK_MODE == MaskMode::kCausal
                    ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end))
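
Note: below is a minimal, host-side C++ sketch of the optimization this patch applies: the divmod by the GQA group size is hoisted out of the innermost reg_id loop and its results are cached, instead of being recomputed for every register. It is illustrative only and not part of the patch; the values of group_size, lane_idx, kv_head_idx, and qo_packed_idx_base are assumed, and plain '/' and '%' stand in for uint_fastdiv::divmod.

// Illustrative sketch (not part of the patch): hoist the group-size divmod out
// of the innermost loop and cache its results, mirroring the change inside the
// kernel. Assumed values; '/' and '%' stand in for uint_fastdiv::divmod.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t NUM_FRAGS_Q = 2;     // number of query fragments (assumed)
  const uint32_t group_size = 4;          // GQA group size (assumed)
  const uint32_t qo_packed_idx_base = 0;  // packed query/head base index (assumed)
  const uint32_t lane_idx = 5;            // lane within the warp (assumed)
  const uint32_t kv_head_idx = 0;         // KV head handled by this block (assumed)

  // Cache quotient (query index) and remainder (offset within the GQA group)
  // once per (fragment, row) pair instead of recomputing them 8x per fragment.
  uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2];
  for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
    for (uint32_t j = 0; j < 2; ++j) {
      const uint32_t packed = qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j;
      q[fq][j] = packed / group_size;
      r[fq][j] = packed % group_size;
    }
  }

  // The innermost loop now only reads the cached values.
  for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) {
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
      const uint32_t q_idx = q[fq][(reg_id % 4) / 2];
      const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2];
      std::printf("fq=%u reg_id=%u q_idx=%u qo_head_idx=%u\n", (unsigned)fq, (unsigned)reg_id,
                  (unsigned)q_idx, (unsigned)qo_head_idx);
    }
  }
  return 0;
}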