From 3dd940516f60fb2ccfc01daa91c23ef2a3b0284e Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 10 Nov 2024 15:43:51 -0800 Subject: [PATCH] hotfix: fix rope tvm wrapper (#601) The TVM wrapper was broken in #599 because of API changes; this PR fixes the issue by passing the rotary_dim argument (set to head_dim) to BatchQKApplyRotaryInPlace. --- src/tvm_wrapper.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index b61b01d8..d33ad9a1 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -692,7 +692,8 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in cudaError_t status = BatchQKApplyRotaryInPlace( static_cast(q->data), static_cast(k->data), static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, + num_qo_heads, num_kv_heads, /*rotary_dim=*/head_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, /*interleave=*/false, rope_scale, rope_theta); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);