diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 8910e097..8d78c938 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1558,7 +1558,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg // normalize d normalize_d(o_frag, m, d); - const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size); + const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size); // write back write_o_reg_gmem( @@ -1872,7 +1872,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage // normalize d normalize_d(o_frag, m, d); - const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size); + const uint32_t num_kv_chunks = ceil_div(max(kv_len, 1), kv_chunk_size); // write_back write_o_reg_gmem(