diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index b61b01d8..d33ad9a1 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -692,7 +692,8 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in cudaError_t status = BatchQKApplyRotaryInPlace( static_cast(q->data), static_cast(k->data), static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, + num_qo_heads, num_kv_heads, /*rotary_dim=*/head_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, /*interleave=*/false, rope_scale, rope_theta); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);