diff --git a/benchmarks/bench_rope.py b/benchmarks/bench_rope.py new file mode 100644 index 00000000..d87af5f1 --- /dev/null +++ b/benchmarks/bench_rope.py @@ -0,0 +1,92 @@ +import argparse +from typing import cast + +import torch +from triton.testing import do_bench + +import flashinfer + + +def generate_cos_sin_f32_cache(max_seq_len, head_dim, theta=1e4): + position = torch.arange(max_seq_len).float().unsqueeze(1) + freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) + freqs = torch.cat([freqs, freqs], dim=-1).contiguous() + args = position * freqs + sin_cache = torch.sin(args) + cos_cache = torch.cos(args) + return cos_cache, sin_cache + + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 128]) + parser.add_argument("--append-len", nargs="+", type=int, default=[1, 128, 1024]) + parser.add_argument("--num-qo-heads", type=int, default=32) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--head-dim", type=int, default=128) + args = parser.parse_args() + + dtype = torch.float16 + num_qo_heads = args.num_qo_heads + num_kv_heads = args.num_kv_heads + head_dim = args.head_dim + + # Loop over each combination of batch_size, append_len, and use_cos_sin_cache + for batch_size in args.batch_sizes: + for append_len in args.append_len: + for use_cos_sin_cache in [False, True]: + # Define tensors with the correct dtype + + q = torch.randn( + (batch_size * append_len, num_qo_heads, head_dim), + dtype=dtype, + device="cuda", + ) + k = torch.randn( + (batch_size * append_len, num_kv_heads, head_dim), + dtype=dtype, + device="cuda", + ) + pos_ids = torch.repeat_interleave( + torch.arange(append_len, dtype=torch.int32, device=q.device), + batch_size, + ) + cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, head_dim) + cos_cache = cos_cache.to(q.device) + sin_cache = sin_cache.to(q.device) + + @torch.cuda.nvtx.range( f"apply_rope batch_size={batch_size}, append_len={append_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}" ) + def fn() -> None: + if use_cos_sin_cache: + flashinfer.apply_rope_with_cos_sin_cache( q, k, cos_cache, sin_cache, pos_ids ) + else: + flashinfer.apply_rope_pos_ids(q, k, pos_ids) + + # Run benchmarking + latency_ms = cast(float, do_bench(fn)) + throughput = ( q.numel() * q.element_size() * 2 + k.numel() * k.element_size() * 2 ) / (latency_ms * 1e-3) + print( f"batch_size: {batch_size:3},", f"append_len: {append_len:5},", f"num_qo_heads: {num_qo_heads:5},", f"num_kv_heads: {num_kv_heads:5},", f"head_dim: {head_dim:5},", f"use_cos_sin_cache: {use_cos_sin_cache},", f"latency: {latency_ms*1e3:2.0f}us,", f"throughput: {throughput*1e-9:7.3f}GB/s", ) + + print("---") + + torch.cuda.profiler.stop() + + +if __name__ == "__main__": + main() diff --git a/include/flashinfer/cutlass_utils.cuh b/include/flashinfer/cutlass_utils.cuh index f7811f95..f6d3ef03 100644 --- a/include/flashinfer/cutlass_utils.cuh +++ b/include/flashinfer/cutlass_utils.cuh @@ -16,9 +16,6 @@ #ifndef FLASHINFER_CUTLASS_UTILS_CUH_ #define FLASHINFER_CUTLASS_UTILS_CUH_ -#include -#include - #include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/collective_builder.hpp" diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 829065a4..102d6d47 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh
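The pos_enc.cuh hunks below thread a runtime `rotary_dim` through the RoPE kernels so that only a prefix of each head is rotated. As a reading aid, here is a hedged PyTorch sketch of the non-interleaved partial-rotary semantics the kernels implement (illustrative only, not part of this PR): channel `i < rotary_dim` is paired with channel `i ± rotary_dim/2`, and channels `[rotary_dim:]` pass through unchanged.

```python
# Hedged reference sketch of partial RoPE (non-interleaved layout); all names
# here are illustrative and not part of the FlashInfer API.
import torch


def ref_partial_rope(x, pos_ids, rotary_dim, theta=1e4):
    # x: (nnz, num_heads, head_dim), pos_ids: (nnz,)
    rot, tail = x[..., :rotary_dim].float(), x[..., rotary_dim:]
    # freq[i] = theta ** (-2 * (i % (rotary_dim / 2)) / rotary_dim), as in the kernel
    freqs = 1.0 / (
        theta ** (torch.arange(0, rotary_dim, 2, device=x.device).float() / rotary_dim)
    )
    angles = pos_ids.float()[:, None] * freqs[None, :]  # (nnz, rotary_dim // 2)
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[:, None, :]
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[:, None, :]
    lo, hi = rot[..., : rotary_dim // 2], rot[..., rotary_dim // 2 :]
    # vec[i] * cos + (i < rotary_dim/2 ? -pair : pair) * sin, pair = i +/- rotary_dim/2
    rotated = rot * cos + torch.cat([-hi, lo], dim=-1) * sin
    return torch.cat([rotated.to(x.dtype), tail], dim=-1)
```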
@@ -77,40 +77,45 @@ __device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num */ template __device__ __forceinline__ vec_t vec_apply_llama_rope( - const T* x, const vec_t& freq, int32_t offset) { - constexpr uint32_t head_dim = vec_size * bdx; + const T* x, const vec_t& freq, int32_t offset, + const uint32_t rotary_dim = vec_size * bdx) { vec_t permuted_vec, vec; vec.cast_load(x + threadIdx.x * vec_size); - permuted_vec.cast_load(x + ((threadIdx.x * vec_size < head_dim / 2) - ? threadIdx.x * vec_size + head_dim / 2 - : threadIdx.x * vec_size - head_dim / 2)); + if (threadIdx.x * vec_size < rotary_dim) { + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2) + ? threadIdx.x * vec_size + rotary_dim / 2 + : threadIdx.x * vec_size - rotary_dim / 2)); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(offset) * freq[i]; - float cos, sin; - __sincosf(embed, &sin, &cos); - vec[i] = vec[i] * cos + - ((threadIdx.x * vec_size < head_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin; + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = + vec[i] * cos + + ((threadIdx.x * vec_size < rotary_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin; + } } return vec; } template __device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin( - const T* x, const vec_t& cos, const vec_t& sin) { - constexpr uint32_t head_dim = vec_size * bdx; + const T* x, const vec_t& cos, const vec_t& sin, + const uint32_t rotary_dim = vec_size * bdx) { vec_t permuted_vec, vec; vec.cast_load(x + threadIdx.x * vec_size); - permuted_vec.cast_load(x + ((threadIdx.x * vec_size < head_dim / 2) - ? threadIdx.x * vec_size + head_dim / 2 - : threadIdx.x * vec_size - head_dim / 2)); + if (threadIdx.x * vec_size < rotary_dim) { + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2) + ? threadIdx.x * vec_size + rotary_dim / 2 + : threadIdx.x * vec_size - rotary_dim / 2)); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - vec[i] = - vec[i] * cos[i] + - ((threadIdx.x * vec_size < head_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i]; + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = + vec[i] * cos[i] + + ((threadIdx.x * vec_size < rotary_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i]; + } } return vec; } @@ -128,31 +133,37 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin( */ template __device__ __forceinline__ vec_t vec_apply_llama_rope_interleave( - const T* x, const vec_t& freq, int32_t offset) { + const T* x, const vec_t& freq, int32_t offset, + const uint32_t rotary_dim = vec_size * bdx) { vec_t vec, vec_before; vec.cast_load(x + threadIdx.x * vec_size); - vec_before = vec; + if (threadIdx.x * vec_size < rotary_dim) { + vec_before = vec; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(offset) * freq[i]; - float cos, sin; - __sincosf(embed, &sin, &cos); - vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = vec[i] * cos + ((i % 2 == 0) ? 
-vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; + } } return vec; } template __device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin_interleave( - const T* x, const vec_t& cos, const vec_t& sin) { + const T* x, const vec_t& cos, const vec_t& sin, + const uint32_t rotary_dim = vec_size * bdx) { vec_t vec, vec_before; vec.cast_load(x + threadIdx.x * vec_size); - vec_before = vec; + if (threadIdx.x * vec_size < rotary_dim) { + vec_before = vec; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i]; + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i]; + } } return vec; } @@ -162,9 +173,9 @@ template q_vec; if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin); + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); } else { - q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin); + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); } q_vec.cast_store(q_rope_ptr + tx * vec_size); } @@ -197,9 +210,9 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel( k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); vec_t k_vec; if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin); + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); } else { - k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin); + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); } k_vec.cast_store(k_rope_ptr + tx * vec_size); } @@ -210,26 +223,28 @@ template __global__ void BatchQKApplyRotaryPosIdsKernel( DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, - size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, float smooth_b, - float rope_rcp_scale, float rope_rcp_theta) { + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, + float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { // NOTE: q and q_rope may be the same ptr, so do k and k_rope uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; vec_t freq; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim)); - } else { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); - } + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + } - float smooth = freq[i] * smooth_a + smooth_b; - smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, 
min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + } } vec_t cos, sin; @@ -238,10 +253,12 @@ __global__ void BatchQKApplyRotaryPosIdsKernel( const uint32_t idx = bx * bdy + ty; const IdType pos = pos_ids[idx]; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(pos) * freq[i]; - __sincosf(embed, &sin[i], &cos[i]); + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(pos) * freq[i]; + __sincosf(embed, &sin[i], &cos[i]); + } } #pragma unroll 1 @@ -251,9 +268,9 @@ __global__ void BatchQKApplyRotaryPosIdsKernel( q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); vec_t q_vec; if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin); + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); } else { - q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin); + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); } q_vec.cast_store(q_rope_ptr + tx * vec_size); } @@ -265,9 +282,9 @@ __global__ void BatchQKApplyRotaryPosIdsKernel( k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); vec_t k_vec; if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin); + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); } else { - k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin); + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); } k_vec.cast_store(k_rope_ptr + tx * vec_size); } @@ -279,24 +296,26 @@ template freq; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim)); - } else { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); - } + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + } - float smooth = freq[i] * smooth_a + smooth_b; - smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + } } if (bx < batch_size * num_qo_heads) { @@ -315,10 +334,11 @@ __global__ void BatchQKApplyRotaryKernel( q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); if constexpr (interleave) { - q_vec = - vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); + q_vec = vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty, + rotary_dim); } else { - q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + q_vec = + vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty, rotary_dim); } q_vec.cast_store(q_rope_ptr + tx * vec_size); } @@ -339,10 +359,11 @@ __global__ void BatchQKApplyRotaryKernel( k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); if constexpr (interleave) { - k_vec = - 
vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); + k_vec = vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty, + rotary_dim); } else { - k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + k_vec = + vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty, rotary_dim); } k_vec.cast_store(k_rope_ptr + tx * vec_size); } @@ -362,10 +383,10 @@ __global__ void BatchQKApplyRotaryKernel( template cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache, - IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, - size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, - bool interleave, cudaStream_t stream = nullptr) { + IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, + size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, cudaStream_t stream = nullptr) { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); @@ -386,6 +407,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&rotary_dim, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, @@ -402,14 +424,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( } template -cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope, - IdType* __restrict__ pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, - size_t q_rope_stride_h, size_t k_rope_stride_n, - size_t k_rope_stride_h, bool interleave, float rope_scale, - float rope_theta, cudaStream_t stream = nullptr) { +cudaError_t BatchQKApplyRotaryPosIds( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; @@ -433,6 +453,7 @@ cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&rotary_dim, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, @@ -456,11 +477,11 @@ template cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, - size_t q_rope_stride_h, size_t k_rope_stride_n, - size_t k_rope_stride_h, bool interleave, float rope_scale, - float rope_theta, cudaStream_t stream = nullptr) { + uint32_t rotary_dim, 
uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, + size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, + float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; @@ -484,6 +505,7 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, (void*)&batch_size, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&rotary_dim, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, @@ -506,26 +528,26 @@ cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, template cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, - size_t k_stride_n, size_t k_stride_h, - bool interleave, float rope_scale, float rope_theta, - cudaStream_t stream = nullptr) { - return BatchQKApplyRotary(q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, - head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, - interleave, rope_scale, rope_theta, stream); - + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, + size_t k_stride_h, bool interleave, float rope_scale, + float rope_theta, cudaStream_t stream = nullptr) { + return BatchQKApplyRotary( + q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, stream); } template cudaError_t BatchQKApplyLlama31Rotary( DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { + uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, + size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, float old_context_length, + cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); @@ -549,6 +571,7 @@ cudaError_t BatchQKApplyLlama31Rotary( (void*)&batch_size, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&rotary_dim, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, @@ -571,11 +594,11 @@ cudaError_t BatchQKApplyLlama31Rotary( template cudaError_t BatchQKApplyLlama31RotaryPosIds( DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, 
uint32_t head_dim, size_t q_stride_n, - size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, - size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length, cudaStream_t stream = nullptr) { + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); @@ -599,6 +622,7 @@ cudaError_t BatchQKApplyLlama31RotaryPosIds( (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&rotary_dim, (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu index c07be244..369997c4 100644 --- a/python/csrc/flashinfer_rope_ops.cu +++ b/python/csrc/flashinfer_rope_ops.cu @@ -18,22 +18,24 @@ #include void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta); + torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta); void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length); + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta); + torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta); void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length); + torch::Tensor k_rope, torch::Tensor pos_ids, + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor cos_cache, diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 9f773284..a2239af3 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -20,8 +20,8 @@ using namespace flashinfer; void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta) { + torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, + bool 
interleave, float rope_scale, float rope_theta) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(indptr); @@ -57,9 +57,9 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, - rope_scale, rope_theta, torch_current_stream); + batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, + interleave, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " + std::string(cudaGetErrorString(status))); return true; @@ -67,8 +67,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T } void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta) { + torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(pos_ids); @@ -98,8 +98,8 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, cudaError_t status = BatchQKApplyRotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, + head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIds failed with error code " + std::string(cudaGetErrorString(status))); @@ -127,8 +127,8 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T CHECK_DIM(2, sin_cache); // sin_cache: (max_seq_len, D) CHECK_EQ(q.size(0), k.size(0)); CHECK_EQ(q.size(2), k.size(2)); - CHECK_EQ(cos_cache.size(1), q.size(2)); - CHECK_EQ(sin_cache.size(1), q.size(2)); + unsigned int rotary_dim = cos_cache.size(1); + CHECK_EQ(sin_cache.size(1), rotary_dim); CHECK_EQ(cos_cache.dtype(), torch::kFloat32); CHECK_EQ(sin_cache.dtype(), torch::kFloat32); unsigned int num_qo_heads = q.size(1); @@ -151,8 +151,8 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(cos_cache.data_ptr()), static_cast(sin_cache.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, + head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, 
torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " + @@ -163,8 +163,9 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length) { + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -200,9 +201,9 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, - rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, + batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, + interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " + std::string(cudaGetErrorString(status))); @@ -211,9 +212,10 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, } void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length) { + torch::Tensor k_rope, torch::Tensor pos_ids, + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(pos_ids); @@ -243,8 +245,8 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor cudaError_t status = BatchQKApplyLlama31RotaryPosIds( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), - static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, + head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryPosIds failed with error code " + diff --git a/python/csrc_aot/flashinfer_ops.cu b/python/csrc_aot/flashinfer_ops.cu index cc9f3935..ec519c78 100644 --- a/python/csrc_aot/flashinfer_ops.cu +++ b/python/csrc_aot/flashinfer_ops.cu @@ -129,22 +129,29 @@ torch::Tensor segment_packbits(torch::Tensor x, 
torch::Tensor input_indptr, //========== rope ========== void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, - torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta); + torch::Tensor indptr, torch::Tensor offsets, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta); void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length); + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta); + torch::Tensor k_rope, torch::Tensor pos_ids, unsigned int rotary_dim, + bool interleave, float rope_scale, float rope_theta); void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, - torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length); + torch::Tensor k_rope, torch::Tensor pos_ids, + unsigned int rotary_dim, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, + torch::Tensor k_rope, torch::Tensor cos_cache, + torch::Tensor sin_cache, torch::Tensor pos_ids, + bool interleave); //========== sampling ========== diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index e28a69d4..ef2f20b2 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -14,7 +14,7 @@ limitations under the License. 
""" -from typing import Tuple +from typing import Optional, Tuple import torch @@ -50,12 +50,22 @@ def _apply_rope( k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: get_rope_module().apply_rope( - q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, ) @@ -67,6 +77,7 @@ def _fake_apply_rope( k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -82,6 +93,7 @@ def _apply_llama31_rope( k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -96,6 +108,7 @@ def _apply_llama31_rope( k_rope, indptr, offsets, + rotary_dim, interleave, rope_scale, rope_theta, @@ -113,6 +126,7 @@ def _fake_apply_llama31_rope( k_rope: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -130,12 +144,13 @@ def _apply_rope_pos_ids( q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, ) -> None: get_rope_module().apply_rope_pos_ids( - q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta + q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) @@ -146,6 +161,7 @@ def _fake_apply_rope_pos_ids( q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -201,6 +217,7 @@ def _apply_llama31_rope_pos_ids( q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -214,6 +231,7 @@ def _apply_llama31_rope_pos_ids( q_rope, k_rope, pos_ids, + rotary_dim, interleave, rope_scale, rope_theta, @@ -230,6 +248,7 @@ def _fake_apply_llama31_rope_pos_ids( q_rope: torch.Tensor, k_rope: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: int, interleave: bool, rope_scale: float, rope_theta: float, @@ -245,6 +264,7 @@ def apply_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, @@ -271,6 +291,9 @@ def apply_rope_inplace( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -316,13 +339,18 @@ def apply_rope_inplace( -------- apply_rope """ - _apply_rope(q, k, q, k, indptr, offsets, interleave, rope_scale, rope_theta) + if rotary_dim is None: + rotary_dim = q.size(-1) + _apply_rope( + q, k, q, k, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta + ) def apply_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, @@ -347,6 +375,9 @@ def apply_rope_pos_ids_inplace( element of ``indptr``. 
pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -366,7 +397,11 @@ def apply_rope_pos_ids_inplace( -------- apply_rope_pos_ids """ - _apply_rope_pos_ids(q, k, q, k, pos_ids, interleave, rope_scale, rope_theta) + if rotary_dim is None: + rotary_dim = q.size(-1) + _apply_rope_pos_ids( + q, k, q, k, pos_ids, rotary_dim, interleave, rope_scale, rope_theta + ) def apply_llama31_rope_inplace( @@ -374,6 +409,7 @@ def apply_llama31_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, @@ -403,6 +439,9 @@ def apply_llama31_rope_inplace( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -454,6 +493,8 @@ def apply_llama31_rope_inplace( -------- apply_llama31_rope """ + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_llama31_rope( q, k, @@ -461,6 +502,7 @@ def apply_llama31_rope_inplace( k, indptr, offsets, + rotary_dim, interleave, rope_scale, rope_theta, @@ -474,6 +516,7 @@ def apply_llama31_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, @@ -501,6 +544,9 @@ def apply_llama31_rope_pos_ids_inplace( element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -526,12 +572,15 @@ def apply_llama31_rope_pos_ids_inplace( -------- apply_llama31_rope_pos_ids """ + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_llama31_rope_pos_ids( q, k, q, k, pos_ids, + rotary_dim, interleave, rope_scale, rope_theta, @@ -546,6 +595,7 @@ def apply_rope( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, @@ -572,6 +622,9 @@ def apply_rope( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. 
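For reference, a hedged usage sketch of the `rotary_dim` argument documented above (shapes and sizes are hypothetical; the call pattern mirrors tests/test_rope.py):

```python
import torch

import flashinfer

batch_size, seq_len, head_dim = 2, 4, 128
nnz = batch_size * seq_len
q = torch.randn(nnz, 32, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, 8, head_dim, dtype=torch.float16, device="cuda")
# ragged layout: indptr has batch_size + 1 entries; offsets holds each request's start position
indptr = torch.arange(0, nnz + 1, seq_len, dtype=torch.int32, device="cuda")
offsets = torch.zeros(batch_size, dtype=torch.int32, device="cuda")

# rotate only the first 64 of 128 head channels; channels [64:] are copied through
q_rope, k_rope = flashinfer.apply_rope(
    q, k, indptr, offsets, rotary_dim=64, interleave=False, rope_theta=1e4
)
```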
@@ -630,8 +683,19 @@ def apply_rope( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_rope( - q, k, q_rope, k_rope, indptr, offsets, interleave, rope_scale, rope_theta + q, + k, + q_rope, + k_rope, + indptr, + offsets, + rotary_dim, + interleave, + rope_scale, + rope_theta, ) return q_rope, k_rope @@ -640,6 +704,7 @@ def apply_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, @@ -664,6 +729,9 @@ def apply_rope_pos_ids( element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(batch_size + 1)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -692,8 +760,10 @@ def apply_rope_pos_ids( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_rope_pos_ids( - q, k, q_rope, k_rope, pos_ids, interleave, rope_scale, rope_theta + q, k, q_rope, k_rope, pos_ids, rotary_dim, interleave, rope_scale, rope_theta ) return q_rope, k_rope @@ -703,6 +773,7 @@ def apply_llama31_rope( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, @@ -732,6 +803,9 @@ def apply_llama31_rope( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -796,6 +870,8 @@ def apply_llama31_rope( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_llama31_rope( q, k, @@ -803,6 +879,7 @@ def apply_llama31_rope( k_rope, indptr, offsets, + rotary_dim, interleave, rope_scale, rope_theta, @@ -817,6 +894,7 @@ def apply_llama31_rope_pos_ids( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, + rotary_dim: Optional[int] = None, interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, @@ -844,6 +922,9 @@ def apply_llama31_rope_pos_ids( element of ``indptr``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. + rotary_dim : Optional[int] + The dimensions to apply RoPE, if ``None``, we apply RoPE to the entire head dimension, + otherwise, we apply RoPE to the first ``rotary_dim`` dimensions, default: ``None``. interleave : bool Whether to use interleaved layout in the last dimension, default: ``False``. @@ -877,12 +958,15 @@ def apply_llama31_rope_pos_ids( """ q_rope = torch.empty_like(q) k_rope = torch.empty_like(k) + if rotary_dim is None: + rotary_dim = q.size(-1) _apply_llama31_rope_pos_ids( q, k, q_rope, k_rope, pos_ids, + rotary_dim, interleave, rope_scale, rope_theta, @@ -910,9 +994,9 @@ def apply_rope_with_cos_sin_cache( k : torch.Tensor Key tensor, shape: ``(nnz, num_k_heads, head_dim)``. cos_cache : torch.Tensor - Cosine cache tensor, shape: ``(max_seq_len, head_dim)``. 
+ Cosine cache tensor, shape: ``(max_seq_len, rotary_dim)``. sin_cache : torch.Tensor - Sine cache tensor, shape: ``(max_seq_len, head_dim)``. + Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. interleave : bool @@ -931,6 +1015,10 @@ The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``. k_rope : torch.Tensor The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``. + + Note + ---- + The rotary dimension is determined by the cosine cache and sine cache. """ if cos_cache.dtype != torch.float32 or sin_cache.dtype != torch.float32: raise ValueError("cos_cache and sin_cache should be float32") @@ -960,10 +1048,10 @@ k : torch.Tensor Key tensor, shape: ``(nnz, num_k_heads, head_dim)``. cos_cache : torch.Tensor - Cosine cache tensor, shape: ``(max_seq_len, head_dim)``. + Cosine cache tensor, shape: ``(max_seq_len, rotary_dim)``. Expects float32 data type. sin_cache : torch.Tensor - Sine cache tensor, shape: ``(max_seq_len, head_dim)``. + Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``. Expects float32 data type. pos_ids : torch.Tensor Position indices, shape: ``(nnz)``. @@ -976,6 +1064,10 @@ * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half dimensions ``([..., head_dim//2:])``. + + Note + ---- + The rotary dimension is determined by the cosine cache and sine cache. """ if cos_cache.dtype != torch.float32 or sin_cache.dtype != torch.float32: raise ValueError("cos_cache and sin_cache should be float32") diff --git a/tests/test_rope.py b/tests/test_rope.py index 6af2a549..c270d1e9 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -28,6 +28,7 @@ @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("llama_version", ["llama", "llama31"]) +@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("inplace", [False, True]) def test_rope( batch_size, @@ -37,8 +38,10 @@ def test_rope( offset, head_dim, llama_version, + partial_rotary_factor, inplace, ): + rotary_dim = int(head_dim * partial_rotary_factor) nnz = batch_size * qkv_len qkv_packed = torch.randn( nnz, @@ -58,40 +61,74 @@ def test_rope( # reference implementation if llama_version == "llama": freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 10000.0, use_scaled=False + rotary_dim, qkv_len + offset, 10000.0, use_scaled=False ).to("cuda:0") else: freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 5e5, use_scaled=True + rotary_dim, qkv_len + offset, 5e5, use_scaled=True ).to("cuda:0") - q_rope_ref, k_rope_ref = apply_rotary_emb( - q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), - k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + q_rot_ref, k_rot_ref = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[..., :rotary_dim], + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[..., :rotary_dim], freqs_cis[offset : offset + qkv_len], ) - q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) - k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) + q_pass_ref = q.reshape(batch_size, qkv_len, num_qo_heads, head_dim)[ + ..., rotary_dim: + ] + k_pass_ref = k.reshape(batch_size, qkv_len, num_kv_heads, head_dim)[ + ..., rotary_dim:
+ ] + q_rope_ref = torch.cat([q_rot_ref, q_pass_ref], dim=-1).reshape( + nnz, num_qo_heads, head_dim + ) + k_rope_ref = torch.cat([k_rot_ref, k_pass_ref], dim=-1).reshape( + nnz, num_kv_heads, head_dim + ) # flashinfer implementation if llama_version == "llama": if inplace: flashinfer.apply_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=1e4, ) q_rope, k_rope = q, k else: q_rope, k_rope = flashinfer.apply_rope( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=1e4, ) else: if inplace: flashinfer.apply_llama31_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=5e5, ) q_rope, k_rope = q, k else: q_rope, k_rope = flashinfer.apply_llama31_rope( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=5e5, ) # compare @@ -106,6 +143,7 @@ def test_rope( @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("llama_version", ["llama", "llama31"]) +@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("inplace", [False, True]) def test_rope_pos_ids( batch_size, @@ -115,8 +153,10 @@ def test_rope_pos_ids( offset, head_dim, llama_version, + partial_rotary_factor, inplace, ): + rotary_dim = int(head_dim * partial_rotary_factor) nnz = batch_size * qkv_len qkv_packed = torch.randn( nnz, @@ -144,39 +184,73 @@ def test_rope_pos_ids( if inplace: q_clone, k_clone = q.clone(), k.clone() flashinfer.apply_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=1e4, ) q_rope, k_rope = q, k flashinfer.apply_rope_pos_ids_inplace( - q_clone, k_clone, pos_ids, interleave=True, rope_theta=1e4 + q_clone, + k_clone, + pos_ids, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=1e4, ) q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone else: q_rope, k_rope = flashinfer.apply_rope( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=1e4, ) q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( - q, k, pos_ids, interleave=True, rope_theta=1e4 + q, k, pos_ids, rotary_dim=rotary_dim, interleave=True, rope_theta=1e4 ) else: if inplace: q_clone, k_clone = q.clone(), k.clone() flashinfer.apply_llama31_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=5e5, ) q_rope, k_rope = q, k flashinfer.apply_llama31_rope_pos_ids_inplace( - q_clone, k_clone, pos_ids, interleave=True, rope_theta=5e5 + q_clone, + k_clone, + pos_ids, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=5e5, ) q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone else: q_rope, k_rope = flashinfer.apply_llama31_rope( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=True, + rope_theta=5e5, ) q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids( - q, k, pos_ids, interleave=True, rope_theta=5e5 + q, k, pos_ids, rotary_dim=rotary_dim, interleave=True, 
rope_theta=5e5 ) # compare @@ -191,6 +265,7 @@ def test_rope_pos_ids( @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) @pytest.mark.parametrize("llama_version", ["llama", "llama31"]) +@pytest.mark.parametrize("partial_rotary_factor", [0.25, 0.5, 0.75, 1.0]) @pytest.mark.parametrize("inplace", [False, True]) def test_rope_cos_sin_cache( batch_size, @@ -200,8 +275,10 @@ def test_rope_cos_sin_cache( offset, head_dim, llama_version, + partial_rotary_factor, inplace, ): + rotary_dim = int(head_dim * partial_rotary_factor) nnz = batch_size * qkv_len qkv_packed = torch.randn( nnz, @@ -221,10 +298,10 @@ def test_rope_cos_sin_cache( ).to("cuda:0") if llama_version == "llama": - cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, head_dim, theta=1e4) + cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, rotary_dim, theta=1e4) else: cos_cache, sin_cache = generate_cos_sin_f32_cache( - 4096, head_dim, theta=5e5, use_scaled=True + 4096, rotary_dim, theta=5e5, use_scaled=True ) cos_cache = cos_cache.to("cuda:0") sin_cache = sin_cache.to("cuda:0") @@ -234,21 +311,23 @@ def test_rope_cos_sin_cache( if llama_version == "llama": if inplace: - flashinfer.apply_rope_pos_ids_inplace(q, k, pos_ids, interleave=False) + flashinfer.apply_rope_pos_ids_inplace( + q, k, pos_ids, rotary_dim=rotary_dim, interleave=False + ) q_rope, k_rope = q, k else: q_rope, k_rope = flashinfer.apply_rope_pos_ids( - q, k, pos_ids, interleave=False + q, k, pos_ids, rotary_dim=rotary_dim, interleave=False ) else: if inplace: flashinfer.apply_llama31_rope_pos_ids_inplace( - q, k, pos_ids, interleave=False + q, k, pos_ids, rotary_dim=rotary_dim, interleave=False ) q_rope, k_rope = q, k else: q_rope, k_rope = flashinfer.apply_llama31_rope_pos_ids( - q, k, pos_ids, interleave=False + q, k, pos_ids, rotary_dim=rotary_dim, interleave=False ) if inplace: @@ -269,6 +348,6 @@ def test_rope_cos_sin_cache( if __name__ == "__main__": - test_rope(2, 1, 8, 8, 1, 128, "llama31", False) - test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", False) - test_rope_cos_sin_cache(99, 19, 16, 8, 99, 256, "llama31", False) + test_rope(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) + test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", 1.0, False) + test_rope_cos_sin_cache(99, 19, 16, 8, 99, 256, "llama31", 0.5, False)
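To close, a hedged end-to-end sketch of the cos/sin-cache path, where the rotary dimension is inferred from the cache width rather than passed explicitly (per the Note added to the docstrings and the `cos_cache.size(1)` check in python/csrc/rope.cu). The cache construction mirrors `generate_cos_sin_f32_cache` from benchmarks/bench_rope.py; sizes are illustrative:

```python
import torch

import flashinfer

head_dim, rotary_dim, max_seq_len = 128, 64, 4096

# float32 caches of width rotary_dim; the binding reads the rotary dimension
# from cos_cache.size(1), so no explicit rotary_dim argument is needed
pos = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
freqs = 1.0 / (1e4 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
angles = pos * torch.cat([freqs, freqs], dim=-1)
cos_cache, sin_cache = angles.cos().cuda(), angles.sin().cuda()

nnz = 8
q = torch.randn(nnz, 32, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, 8, head_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda")

# only the first rotary_dim channels of each head are rotated
q_rope, k_rope = flashinfer.apply_rope_with_cos_sin_cache(
    q, k, cos_cache, sin_cache, pos_ids
)
```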