Skip to content

Commit

Permalink
hotfix: fix rope tvm wrapper (#601)
Browse files Browse the repository at this point in the history
The TVM wrapper was broken in #599 because of API changes, this PR fixes
the issue.
  • Loading branch information
yzh119 authored Nov 10, 2024
1 parent eb9bc71 commit 3dd9405
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/tvm_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,8 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in
cudaError_t status = BatchQKApplyRotaryInPlace(
static_cast<dtype*>(q->data), static_cast<dtype*>(k->data),
static_cast<idtype*>(indptr->data), static_cast<idtype*>(offsets->data), batch_size,
num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h,
num_qo_heads, num_kv_heads, /*rotary_dim=*/head_dim, head_dim, q_stride_n, q_stride_h,
k_stride_n, k_stride_h,
/*interleave=*/false, rope_scale, rope_theta);
if (status != cudaSuccess) {
LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
Expand Down

0 comments on commit 3dd9405

Please # to comment.