Skip to content

Commit

Permalink
feat: support fused add rmsnorm (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyncs authored Aug 4, 2024
1 parent 1c9ffb3 commit b781513
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 79 deletions.
132 changes: 115 additions & 17 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace flashinfer {
namespace norm {

template <uint32_t VEC_SIZE, typename T>
__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
Expand All @@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> x_vec;
x_vec.fill(0);
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
sum_sq += float(x_vec[j]) * float(x_vec[j]);
sum_sq += float(input_vec[j]) * float(input_vec[j]);
}
}

Expand All @@ -76,36 +76,36 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> x_vec;
vec_t<T, VEC_SIZE> w_vec;
vec_t<T, VEC_SIZE> y_vec;
x_vec.fill(0);
w_vec.fill(0);
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> output_vec;
input_vec.fill(0);
weight_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
cudaStream_t stream = 0) {
cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&x, &w, &y, &d, &eps};
void* args[] = {&input, &weight, &output, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = RMSNormKernel<VEC_SIZE, T>;
Expand All @@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
T* __restrict__ weight, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
float x = float(input_vec[j]);
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> residual_vec;
input_vec.fill(0);
weight_vec.fill(0);
residual_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &residual, &weight, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

return cudaSuccess;
}

} // namespace norm

} // namespace flashinfer
Expand Down
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
Expand Down
5 changes: 4 additions & 1 deletion python/csrc/flashinfer_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
Expand Down
62 changes: 46 additions & 16 deletions python/csrc/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,56 @@

using namespace flashinfer;

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
CHECK_INPUT(x);
CHECK_INPUT(w);
auto device = x.device();
CHECK_EQ(w.device(), device);
CHECK_DIM(2, x); // x: (batch_size, hidden_size)
CHECK_DIM(1, w); // w: (hidden_size)
CHECK_EQ(x.size(1), w.size(0));
unsigned int batch_size = x.size(0);
unsigned int hidden_size = x.size(1);
torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
CHECK_INPUT(input);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto y = torch::empty_like(x);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(
static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(w.data_ptr()),
static_cast<c_type*>(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
auto output = torch::empty_like(input);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
static_cast<c_type*>(output.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
return y;
return output;
}

void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(residual.device(), device);
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
55 changes: 19 additions & 36 deletions python/flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,27 @@
limitations under the License.
"""

from .decode import (
single_decode_with_kv_cache,
BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .prefill import (
single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse,
BatchPrefillWithRaggedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from .sparse import BlockSparseAttentionWrapper
from .cascade import (
merge_state,
merge_state_in_place,
merge_states,
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .page import append_paged_kv_cache
from .sampling import (
sampling_from_probs,
top_p_sampling_from_probs,
top_k_sampling_from_probs,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_k_renorm_prob,
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import (
apply_rope_inplace,
apply_llama31_rope_inplace,
apply_rope,
apply_llama31_rope,
)
from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
merge_state, merge_state_in_place, merge_states)
from .decode import (BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache)
from .group_gemm import SegmentGEMMWrapper
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse)
from .quantization import packbits, segment_packbits
from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope,
apply_rope_inplace)
from .sampling import (chain_speculative_sampling, sampling_from_probs,
top_k_renorm_prob, top_k_sampling_from_probs,
top_k_top_p_sampling_from_probs, top_p_renorm_prob,
top_p_sampling_from_probs)
from .sparse import BlockSparseAttentionWrapper

try:
from ._build_meta import __version__
Expand Down
33 changes: 27 additions & 6 deletions python/flashinfer/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
try:
from . import _kernels
except ImportError as e:
import os
import logging
import os

if os.environ.get("BUILD_DOC", "0") == "1":
_kernels = None
Expand All @@ -30,21 +30,42 @@
raise e


def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
def rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
r"""Root mean square normalization.
Parameters
----------
x: torch.Tensor
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
w: torch.Tensor
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
Returns
-------
y: torch.Tensor
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
return _kernels.rmsnorm(x, w, eps)
return _kernels.rmsnorm(input, weight, eps)


def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
r"""Fused add root mean square normalization.
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
"""
_kernels.fused_add_rmsnorm(input, residual, weight, eps)
Loading

0 comments on commit b781513

Please # to comment.