hotfix: fix the decode kernel with logits cap (#350)
The logits soft cap should be applied before masking.

Thanks @LiuXiaoxuanPKU for spotting this bug.
yzh119 authored Jul 3, 2024
1 parent dc2c76f commit f5f7a2a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion include/flashinfer/attention/decode.cuh
@@ -99,8 +99,8 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
     for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
       s[j] += math::shfl_xor_sync(s[j], offset);
     }
-    s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
     s[j] = apply_logits_post_hook<logits_post_hook>(s[j], logits_soft_cap);
+    s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
     if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
       s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
     }
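
Why the order matters can be sketched on the host in plain C++. This is an illustration only, not the FlashInfer implementation: it assumes the soft cap has the common form cap * tanh(x / cap), and the names soft_cap, mask_then_cap, cap_then_mask and kMaskValue are hypothetical. Only the -5e4 sentinel is taken from the diff.

```cpp
#include <cassert>
#include <cmath>

// Sentinel written to out-of-bounds positions (value taken from the diff).
constexpr float kMaskValue = -5e4f;

// ASSUMPTION: a tanh-style soft cap, which bounds its output to (-cap, cap).
// The real apply_logits_post_hook may be defined differently.
inline float soft_cap(float x, float cap) {
  assert(cap > 0.0f);
  return cap * std::tanh(x / cap);
}

// Order before the fix: mask first, then cap.
// The sentinel passes through the cap and is squashed to about -cap.
inline float mask_then_cap(float s, bool in_bounds, float cap) {
  s = in_bounds ? s : kMaskValue;
  return soft_cap(s, cap);
}

// Order after the fix: cap first, then mask.
// Out-of-bounds positions keep the full -5e4 sentinel.
inline float cap_then_mask(float s, bool in_bounds, float cap) {
  s = soft_cap(s, cap);
  return in_bounds ? s : kMaskValue;
}
```

Under this assumption, with a cap of 30 the old order turns an out-of-bounds logit into roughly -30 instead of -5e4. That is no more negative than a strongly negative valid logit, so the masked position could still receive weight in the softmax. In-bounds positions give the same result in either order.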