From 74ffba1d1b946fcd3536b7637a4e1a999e5a5d3e Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 28 Jul 2024 21:39:24 -0700 Subject: [PATCH] feat: non-inplace rope operators (#405) As requested in #403, this PR implements non-inplace rope operators. --- docs/api/python/rope.rst | 2 + include/flashinfer/attention/prefill.cuh | 4 +- include/flashinfer/pos_enc.cuh | 174 +++++++++++++++++++++++ python/csrc/flashinfer_ops.cu | 2 + python/csrc/flashinfer_ops.h | 10 ++ python/csrc/rope.cu | 100 ++++++++++++- python/csrc/single_prefill.cu | 2 +- python/flashinfer/__init__.py | 7 +- python/flashinfer/rope.py | 133 +++++++++++++++++ python/tests/test_rope.py | 138 ++++++++++++++++-- 10 files changed, 555 insertions(+), 17 deletions(-) diff --git a/docs/api/python/rope.rst b/docs/api/python/rope.rst index b27ac7e9..23ea1172 100644 --- a/docs/api/python/rope.rst +++ b/docs/api/python/rope.rst @@ -12,3 +12,5 @@ Kernels for applying rotary embeddings. apply_rope_inplace apply_llama31_rope_inplace + apply_rope + apply_llama31_rope diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 34ae1180..a4f6cbc6 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1440,7 +1440,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ - partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, + partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, /*o_stride_h=*/head_dim, group_size); // write lse @@ -1732,7 +1732,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ - partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, + partition_kv ? 
num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, /*o_stride_h=*/head_dim, group_size); // write lse diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index efa5c8bf..15b4a8d9 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( } } +template +__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k, + DType* __restrict__ q_rope, DType* __restrict__ k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + float smooth_a, float smooth_b, float rope_rcp_scale, + float rope_rcp_theta) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + vec_t freq; +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + } + + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + } + + if (bx < batch_size * num_qo_heads) { + // apply rotary to q + const uint32_t batch_idx = bx / num_qo_heads; + const uint32_t qo_head_idx = bx % num_qo_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; +#pragma unroll 2 + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t q_vec; + if (i * bdy + ty < seq_len) { + DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, + q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, + /*q_stride_n=*/num_qo_heads * head_dim, + /*q_stride_h=*/head_dim); + if constexpr (interleave) { + q_vec = + vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); + } else { + q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } + } + } else { + // apply rotary to k + uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; + uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; +#pragma unroll 2 + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t k_vec; + if (i * bdy + ty < seq_len) { + DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, + k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, + /*kv_stride_n=*/num_kv_heads * head_dim, + /*kv_stride_h=*/head_dim); + if constexpr (interleave) { + k_vec = + vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); + } else { + k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + } + k_vec.cast_store(k_rope_ptr + +tx * vec_size); + } + } + } +} + #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) 
\ if (interleave) { \ const bool INTERLEAVE = true; \ @@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( return cudaSuccess; } +template +cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, + DType* __restrict__ q_rope, DType* __restrict__ k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, bool interleave, + float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = 0.f; + float smooth_b = 0.f; + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = BatchQKApplyRotaryKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + }); + + return cudaSuccess; +} + +template +cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k, + DType* __restrict__ q_rope, DType* __restrict__ k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length, cudaStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = BatchQKApplyRotaryKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + }); + + return cudaSuccess; +} + } // namespace flashinfer #endif // FLASHINFER_POS_ENC_CUH_ diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 79c34b21..9215209b 100644 
--- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -45,6 +45,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, "Apply Llama 3.1 style RoPE in-place"); + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); py::class_(m, diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 32617c69..0f526c26 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -83,6 +83,16 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta); + +std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 7fb9f483..572d7e9c 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -102,4 +102,102 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor std::string(cudaGetErrorString(status))); return true; }); -} \ No newline at end of file +} + +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. 
+ auto q_rope = torch::empty_like(q); + auto k_rope = torch::empty_like(k); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyRotary( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + + std::string(cudaGetErrorString(status))); + return true; + }); + + return {q_rope, k_rope}; +} + +std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + + // NOTE(Zihao): empty_like do not copy strides so it's okay to use it here. + auto q_rope = torch::empty_like(q); + auto k_rope = torch::empty_like(k); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyLlama31Rotary( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, + old_context_length, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + + std::string(cudaGetErrorString(status))); + return true; + }); + + return {q_rope, k_rope}; +} diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index 37d1a838..9dfc740a 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -71,7 +71,7 @@ std::vector single_prefill_with_kv_cache( TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone;
-
+
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
       return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index db818d98..ffe9a545 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -44,7 +44,12 @@
     chain_speculative_sampling,
 )
 from .norm import rmsnorm
-from .rope import apply_rope_inplace, apply_llama31_rope_inplace
+from .rope import (
+    apply_rope_inplace,
+    apply_llama31_rope_inplace,
+    apply_rope,
+    apply_llama31_rope,
+)
 from .group_gemm import SegmentGEMMWrapper
 from .quantization import packbits, segment_packbits
diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py
index 6bb67eb8..3aad81a2 100644
--- a/python/flashinfer/rope.py
+++ b/python/flashinfer/rope.py
@@ -147,3 +147,136 @@ def apply_llama31_rope_inplace(
         high_freq_factor,
         float(old_context_len),
     )
+
+
+def apply_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    indptr: torch.Tensor,
+    offsets: torch.Tensor,
+    interleave: bool = False,
+    rope_scale: float = 1,
+    rope_theta: float = 1e4,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
+
+    We use :attr:`indptr` to denote the start pointers of the segments in the batch:
+    the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
+    i-th segment is ``k[indptr[i]:indptr[i+1]]``. The first element of :attr:`indptr`
+    is always 0, and the last element of :attr:`indptr` is the total number of
+    queries/keys in the batch. Please see :ref:`Ragged Tensor tutorial ` for more
+    details about the ragged tensor.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the
+        last element of ``indptr``.
+    k : torch.Tensor
+        Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the
+        last element of ``indptr``.
+    indptr : torch.Tensor
+        Indptr tensor, shape: ``(batch_size + 1)``.
+    offsets : torch.Tensor
+        The relative position offsets of each query in the batch, shape: ``(batch_size)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``False``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+    rope_scale : float
+        The scaling factor used in the rope embedding, default: ``1``.
+    rope_theta : float
+        The theta value used in the rope embedding, default: ``1e4``.
+
+    Returns
+    -------
+    q_rope : torch.Tensor
+        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
+    k_rope : torch.Tensor
+        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
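+
+    Examples
+    --------
+    The following is an illustrative sketch (the shapes, head counts, and offsets
+    below are arbitrary, and a CUDA device is assumed):
+
+    >>> import torch
+    >>> import flashinfer
+    >>> # two requests of lengths 3 and 4 packed into a ragged batch (nnz = 7)
+    >>> nnz, num_qo_heads, num_kv_heads, head_dim = 7, 8, 4, 64
+    >>> q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
+    >>> k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
+    >>> indptr = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda:0")
+    >>> # both requests start at position 0 (no tokens in the KV-Cache yet)
+    >>> offsets = torch.tensor([0, 0], dtype=torch.int32, device="cuda:0")
+    >>> q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets)
+    >>> q_rope.shape
+    torch.Size([7, 8, 64])
+    >>> k_rope.shape
+    torch.Size([7, 4, 64])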
+ """ + return _kernels.apply_rope( + q, k, indptr, offsets, interleave, rope_scale, rope_theta + ) + + +def apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + interleave: bool = True, + rope_scale: float = 8, + rope_theta: float = 5e5, + low_freq_factor: float = 1, + high_freq_factor: float = 4, + old_context_len: int = 8192, +) -> None: + r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + RaggedTensor). + + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. + indptr : torch.Tensor + Indptr tensor, shape: ``(batch_size + 1)``. + offsets : torch.Tensor + The relative position offsets of each query in the batch, shape: ``(batch_size)``. + interleave : bool + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + rope_scale : float + The scaling factor used in the rope embedding, default: ``8``. + rope_theta : float + The theta value used in the rope embedding, default: ``5e5``. + low_freq_factor : float + The low frequency factor used in Llama 3.1 RoPE, default: ``1``. + high_freq_factor : float + The high frequency factor used in Llama 3.1 RoPE, default: ``4``. + old_context_len : int + The old context length used in Llama 3.1 RoPE, default: ``8192``. + + Returns + ------- + q_rope : torch.Tensor + The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. + k_rope : torch.Tensor + The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. 
+ """ + return _kernels.apply_llama31_rope( + q, + k, + indptr, + offsets, + interleave, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + float(old_context_len), + ) diff --git a/python/tests/test_rope.py b/python/tests/test_rope.py index e0676126..e49826e4 100644 --- a/python/tests/test_rope.py +++ b/python/tests/test_rope.py @@ -27,7 +27,7 @@ @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama_rope( +def test_llama_rope_inplace( batch_size, qkv_len, num_qo_heads, @@ -55,13 +55,13 @@ def test_llama_rope( freqs_cis = precompute_freqs_cis( head_dim, qkv_len + offset, 10000.0, use_scaled=False ).to("cuda:0") - q_rope, k_rope = apply_rotary_emb( + q_rope_ref, k_rope_ref = apply_rotary_emb( q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), freqs_cis[offset : offset + qkv_len], ) - q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) - k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) + k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) # flashinfer implementation flashinfer.apply_rope_inplace( @@ -70,10 +70,10 @@ def test_llama_rope( # compare np.testing.assert_allclose( - q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 ) np.testing.assert_allclose( - k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 ) @@ -83,7 +83,63 @@ def test_llama_rope( @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama31_rope( +def test_llama_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 10000.0, use_scaled=False + ).to("cuda:0") + q_rope_ref, k_rope_ref = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) + k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + q_rope, k_rope = flashinfer.apply_rope( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + # compare + np.testing.assert_allclose( + q_rope_ref.cpu().numpy(), q_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + k_rope_ref.cpu().numpy(), k_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) 
+@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama31_rope_inplace( batch_size, qkv_len, num_qo_heads, @@ -111,13 +167,13 @@ def test_llama31_rope( freqs_cis = precompute_freqs_cis( head_dim, qkv_len + offset, 5e5, use_scaled=True ).to("cuda:0") - q_rope, k_rope = apply_rotary_emb( + q_rope_ref, k_rope_ref = apply_rotary_emb( q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), freqs_cis[offset : offset + qkv_len], ) - q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) - k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) + k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) # flashinfer implementation flashinfer.apply_llama31_rope_inplace( @@ -126,13 +182,71 @@ def test_llama31_rope( # compare np.testing.assert_allclose( - q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + q_rope_ref.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + k_rope_ref.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama31_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 5e5, use_scaled=True + ).to("cuda:0") + q_rope_ref, k_rope_ref = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) + k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + q_rope, k_rope = flashinfer.apply_llama31_rope( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) + + # compare + np.testing.assert_allclose( + q_rope_ref.cpu().numpy(), q_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 ) np.testing.assert_allclose( - k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + k_rope_ref.cpu().numpy(), k_rope.cpu().numpy(), rtol=1e-3, atol=1e-3 ) if __name__ == "__main__": + test_llama_rope_inplace(2, 1, 8, 8, 1, 128) + test_llama31_rope_inplace(1, 1, 8, 8, 0, 128) test_llama_rope(2, 1, 8, 8, 1, 128) test_llama31_rope(1, 1, 8, 8, 0, 128)