From b78151383d4a75094195cba29aba45d694d5fdb7 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 4 Aug 2024 15:48:05 +0800 Subject: [PATCH] feat: support fused add rmsnorm (#419) ref https://github.com/sgl-project/sglang/pull/907 cc @yzh119 --- include/flashinfer/norm.cuh | 132 +++++++++++++++++++++++++++++----- python/csrc/flashinfer_ops.cu | 1 + python/csrc/flashinfer_ops.h | 5 +- python/csrc/norm.cu | 62 +++++++++++----- python/flashinfer/__init__.py | 55 +++++--------- python/flashinfer/norm.py | 33 +++++++-- python/tests/test_norm.py | 37 +++++++++- 7 files changed, 246 insertions(+), 79 deletions(-) diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 41da4218..02c24a7e 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -28,7 +28,7 @@ namespace flashinfer { namespace norm { template -__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y, +__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output, const uint32_t d, float eps) { const uint32_t bx = blockIdx.x; const uint32_t tx = threadIdx.x, ty = threadIdx.y; @@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric float sum_sq = 0.f; for (uint32_t i = 0; i < rounds; i++) { - vec_t x_vec; - x_vec.fill(0); + vec_t input_vec; + input_vec.fill(0); if ((i * num_threads + thread_id) * VEC_SIZE < d) { - x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; j++) { - sum_sq += float(x_vec[j]) * float(x_vec[j]); + sum_sq += float(input_vec[j]) * float(input_vec[j]); } } @@ -76,28 +76,28 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); for (uint32_t i = 0; i < rounds; i++) { - vec_t x_vec; - vec_t w_vec; - vec_t y_vec; - x_vec.fill(0); - w_vec.fill(0); + vec_t input_vec; + vec_t weight_vec; + vec_t output_vec; + input_vec.fill(0); + weight_vec.fill(0); if ((i * num_threads + thread_id) * VEC_SIZE < d) { - x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; j++) { - y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]); + output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]); } if ((i * num_threads + thread_id) * VEC_SIZE < d) { - y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } } } template -cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5, - cudaStream_t stream = 0) { +cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d, + float eps = 1e-5, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t block_size = std::min(1024, d / vec_size); @@ -105,7 +105,7 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps dim3 nblks(batch_size); dim3 nthrs(32, num_warps); const uint32_t smem_size = num_warps * sizeof(float); - void* args[] = {&x, &w, &y, &d, &eps}; + void* args[] = {&input, &weight, &output, &d, &eps}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = RMSNormKernel; @@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps return cudaSuccess; } +template +__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, + T* __restrict__ weight, const uint32_t d, float eps) { + const uint32_t bx = blockIdx.x; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t warp_size = 32; + const uint32_t num_warps = blockDim.y; + const uint32_t thread_id = tx + ty * warp_size; + const uint32_t num_threads = num_warps * warp_size; + const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); + extern __shared__ float smem[]; + + float sum_sq = 0.f; + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + input_vec.fill(0); + vec_t residual_vec; + residual_vec.fill(0); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + float x = float(input_vec[j]); + x += float(residual_vec[j]); + sum_sq += x * x; + residual_vec[j] = (T)x; + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } + } + + // first, warp reduce sum +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + + smem[ty] = sum_sq; + __syncthreads(); + // then, cross warp reduce sum using only the first warp + if (ty == 0) { + sum_sq = (tx < num_warps) ? smem[tx] : 0.f; +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + smem[0] = sum_sq; + } + __syncthreads(); + + float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + vec_t weight_vec; + vec_t residual_vec; + input_vec.fill(0); + weight_vec.fill(0); + residual_vec.fill(0); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]); + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } + } +} + +template +cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, + float eps = 1e-5, cudaStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + const uint32_t smem_size = num_warps * sizeof(float); + void* args[] = {&input, &residual, &weight, &d, &eps}; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = FusedAddRMSNormKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + + return cudaSuccess; +} + } // namespace norm } // namespace flashinfer diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 9215209b..49c0f518 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -42,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, "Apply Llama 3.1 style RoPE in-place"); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 60441f3d..02d6a127 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -78,7 +78,10 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso torch::Tensor uniform_samples, torch::Tensor target_probs, bool deterministic); -torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); +torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps); + +void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, + double eps); void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); diff --git a/python/csrc/norm.cu b/python/csrc/norm.cu index 7a3f9694..64be041a 100644 --- a/python/csrc/norm.cu +++ b/python/csrc/norm.cu @@ -20,26 +20,56 @@ using namespace flashinfer; -torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) { - CHECK_INPUT(x); - CHECK_INPUT(w); - auto device = x.device(); - CHECK_EQ(w.device(), device); - CHECK_DIM(2, x); // x: (batch_size, hidden_size) - CHECK_DIM(1, w); // w: (hidden_size) - CHECK_EQ(x.size(1), w.size(0)); - unsigned int batch_size = x.size(0); - unsigned int hidden_size = x.size(1); +torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) { + CHECK_INPUT(input); + CHECK_INPUT(weight); + auto device = input.device(); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, input); // input: (batch_size, hidden_size) + CHECK_DIM(1, weight); // weight: (hidden_size) + CHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - auto y = torch::empty_like(x); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { - cudaError_t status = norm::RMSNorm( - static_cast(x.data_ptr()), static_cast(w.data_ptr()), - static_cast(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + auto output = torch::empty_like(input); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + cudaError_t status = norm::RMSNorm(static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, + hidden_size, eps, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); - return y; + return output; +} + +void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, + double eps) { + CHECK_INPUT(input); + CHECK_INPUT(residual); + CHECK_INPUT(weight); + auto device = input.device(); + CHECK_EQ(residual.device(), device); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, input); // input: (batch_size, hidden_size) + CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) + CHECK_DIM(1, weight); // weight: (hidden_size) + CHECK_EQ(input.size(0), residual.size(0)); + CHECK_EQ(input.size(1), residual.size(1)); + CHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] { + cudaError_t status = norm::FusedAddRMSNorm(static_cast(input.data_ptr()), + static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, + hidden_size, eps, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " + + std::string(cudaGetErrorString(status))); + return true; + }); } diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index ffe9a545..f1429e76 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -14,44 +14,27 @@ limitations under the License. """ -from .decode import ( - single_decode_with_kv_cache, - BatchDecodeWithPagedKVCacheWrapper, - CUDAGraphBatchDecodeWithPagedKVCacheWrapper, -) -from .prefill import ( - single_prefill_with_kv_cache, - single_prefill_with_kv_cache_return_lse, - BatchPrefillWithRaggedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, -) -from .sparse import BlockSparseAttentionWrapper -from .cascade import ( - merge_state, - merge_state_in_place, - merge_states, - BatchDecodeWithSharedPrefixPagedKVCacheWrapper, - BatchPrefillWithSharedPrefixPagedKVCacheWrapper, -) -from .page import append_paged_kv_cache -from .sampling import ( - sampling_from_probs, - top_p_sampling_from_probs, - top_k_sampling_from_probs, - top_k_top_p_sampling_from_probs, - top_p_renorm_prob, - top_k_renorm_prob, - chain_speculative_sampling, -) -from .norm import rmsnorm -from .rope import ( - apply_rope_inplace, - apply_llama31_rope_inplace, - apply_rope, - apply_llama31_rope, -) +from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper, + BatchPrefillWithSharedPrefixPagedKVCacheWrapper, + merge_state, merge_state_in_place, merge_states) +from .decode import (BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + single_decode_with_kv_cache) from .group_gemm import SegmentGEMMWrapper +from .norm import fused_add_rmsnorm, rmsnorm +from .page import append_paged_kv_cache +from .prefill import (BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + single_prefill_with_kv_cache, + single_prefill_with_kv_cache_return_lse) from .quantization import packbits, segment_packbits +from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope, + apply_rope_inplace) +from .sampling import (chain_speculative_sampling, sampling_from_probs, + top_k_renorm_prob, top_k_sampling_from_probs, + top_k_top_p_sampling_from_probs, top_p_renorm_prob, + top_p_sampling_from_probs) +from .sparse import BlockSparseAttentionWrapper try: from ._build_meta import __version__ diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py index f9f44b6c..63a078ff 100644 --- a/python/flashinfer/norm.py +++ b/python/flashinfer/norm.py @@ -20,8 +20,8 @@ try: from . import _kernels except ImportError as e: - import os import logging + import os if os.environ.get("BUILD_DOC", "0") == "1": _kernels = None @@ -30,21 +30,42 @@ raise e -def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: +def rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: r"""Root mean square normalization. Parameters ---------- - x: torch.Tensor + input: torch.Tensor Input tensor, shape (batch_size, hidden_size). - w: torch.Tensor + weight: torch.Tensor Weight tensor, shape (hidden_size,). eps: float Epsilon for numerical stability. Returns ------- - y: torch.Tensor + output: torch.Tensor Normalized tensor, shape (batch_size, hidden_size). """ - return _kernels.rmsnorm(x, w, eps) + return _kernels.rmsnorm(input, weight, eps) + + +def fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +): + r"""Fused add root mean square normalization. + + Parameters + ---------- + input: torch.Tensor + Input tensor, shape (batch_size, hidden_size). + residual: torch.Tensor + Residual tensor, shape (batch_size, hidden_size). + weight: torch.Tensor + Weight tensor, shape (hidden_size,). + eps: float + Epsilon for numerical stability. + """ + _kernels.fused_add_rmsnorm(input, residual, weight, eps) diff --git a/python/tests/test_norm.py b/python/tests/test_norm.py index cb5dc477..79877cb0 100644 --- a/python/tests/test_norm.py +++ b/python/tests/test_norm.py @@ -15,8 +15,9 @@ """ import numpy -import torch import pytest +import torch + import flashinfer @@ -28,6 +29,18 @@ def _norm(x): return output * w +def fused_add_rms_norm(x, residual, weight, eps): + orig_dtype = x.dtype + x = x.to(torch.float32) + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x.to(orig_dtype) * weight + return x, residual + + @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192]) @pytest.mark.parametrize("dtype", [torch.float16]) @@ -43,5 +56,23 @@ def test_norm(batch_size, hidden_size, dtype): ) -if __name__ == "__main__": - test_norm(1, 111, torch.float16) +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + flashinfer.fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)