From f5f7a2a23249fd0be5b30fd8fb3957ac3bb527ca Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 2 Jul 2024 20:32:19 -0700 Subject: [PATCH] hotfix: fix the decode kernel with logits cap (#350) logits soft cap should be applied before masking. Thanks @LiuXiaoxuanPKU for spotting this bug. --- include/flashinfer/attention/decode.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index ece23db9..8ca653a9 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -99,8 +99,8 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { s[j] += math::shfl_xor_sync(s[j], offset); } - s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; s[j] = apply_logits_post_hook(s[j], logits_soft_cap); + s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset); }