diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 102d6d47..b8eff7d7 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -229,65 +229,61 @@ __global__ void BatchQKApplyRotaryPosIdsKernel( float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { // NOTE: q and q_rope may be the same ptr, so do k and k_rope uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - const uint32_t bdy = blockDim.y; - vec_t freq; + + const uint32_t idx = bx * blockDim.y + ty; + const uint32_t pos_idx = idx / (num_qo_heads + num_kv_heads); + if (pos_idx >= nnz) { + return; + } + + const IdType pos = pos_ids[pos_idx]; + + vec_t cos, sin; if (tx * vec_size < rotary_dim) { -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { + float freq; if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + freq = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); } else { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + freq = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); } - float smooth = freq[i] * smooth_a + smooth_b; + float smooth = freq * smooth_a + smooth_b; smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; - } - } - - vec_t cos, sin; + freq = (1 - smooth) * (freq * rope_rcp_scale) + smooth * freq; - if (bx * bdy + ty < nnz) { - const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; - - if (tx * vec_size < rotary_dim) { -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(pos) * freq[i]; - __sincosf(embed, &sin[i], &cos[i]); - } + const float embed = float(pos) * freq; + __sincosf(embed, &sin[i], &cos[i]); } + } -#pragma unroll 1 - for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); - DType* q_rope_ptr = - q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); - vec_t q_vec; - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); - } else { - q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); + const uint32_t head_idx = idx % (num_qo_heads + num_kv_heads); + if (head_idx < num_qo_heads) { + const uint32_t qo_head_idx = head_idx; + DType* q_ptr = q + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(pos_idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); } - -#pragma unroll 1 - for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { - DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); - DType* k_rope_ptr = - k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); - vec_t k_vec; - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); - } else { - k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } else { + const uint32_t kv_head_idx = head_idx - num_qo_heads; + DType* k_ptr = k + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(pos_idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } } @@ -610,7 +606,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds( constexpr uint32_t bdx = HEAD_DIM / vec_size; uint32_t num_threads = std::max(128U, bdx); uint32_t bdy = num_threads / bdx; - dim3 nblks((nnz + bdy - 1) / bdy); + dim3 nblks((nnz + bdy - 1) / bdy * (num_qo_heads + num_kv_heads)); dim3 nthrs(bdx, bdy); auto kernel = BatchQKApplyRotaryPosIdsKernel;