diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index ece23db9..8ca653a9 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -99,8 +99,8 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { s[j] += math::shfl_xor_sync(s[j], offset); } - s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; s[j] = apply_logits_post_hook(s[j], logits_soft_cap); + s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4; if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset); }