From 9ee26e73921c664a38ccf81d824d266d6a8501a6 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 27 Aug 2024 20:11:14 +1000
Subject: [PATCH] feat: add gelu_and_mul (#474) for gemma

---
 include/flashinfer/activation.cuh |  6 ++++++
 python/csrc/activation.cu         | 18 ++++++++++++++++++
 python/csrc/flashinfer_ops.cu     |  1 +
 python/csrc/flashinfer_ops.h      |  2 ++
 python/flashinfer/__init__.py     |  2 +-
 python/flashinfer/activation.py   | 30 ++++++++++++++++++++++++++++++
 python/tests/test_activation.py   | 13 +++++++++++--
 7 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/include/flashinfer/activation.cuh b/include/flashinfer/activation.cuh
index 80fb5516..67bee024 100644
--- a/include/flashinfer/activation.cuh
+++ b/include/flashinfer/activation.cuh
@@ -31,6 +31,12 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
   return val / (1.0f + __expf(-val));
 }
 
+// https://github.com/pytorch/pytorch/blob/f48038527792814b06dafa6d471acb04c837b972/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
+__device__ __forceinline__ float gelu_kernel(const float& val) {
+  constexpr float kAlpha = M_SQRT1_2;
+  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
+}
+
 template <typename T>
 __device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
   const float cdf =
diff --git a/python/csrc/activation.cu b/python/csrc/activation.cu
index 4830334b..ef3a781a 100644
--- a/python/csrc/activation.cu
+++ b/python/csrc/activation.cu
@@ -58,3 +58,21 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
     return true;
   });
 }
+
+void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    uint32_t vec_size = 16 / sizeof(c_type);
+    dim3 block(std::min(d / vec_size, 1024U));
+    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_kernel>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
+                                     static_cast<c_type*>(input.data_ptr()), d);
+
+    return true;
+  });
+}
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index cb4343a1..aa28b2b0 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -41,6 +41,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
+  m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
   m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
   m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
         "Apply Llama 3.1 style RoPE in-place");
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index ff5ec5e7..86c921f8 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -81,6 +81,8 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                         torch::Tensor offsets, bool interleave, float rope_scale,
                         float rope_theta);
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index a89f985d..1871cfb0 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -14,7 +14,7 @@
 limitations under the License.
 """
 
-from .activation import gelu_tanh_and_mul, silu_and_mul
+from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from .cascade import (
     BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py
index ce889813..eb9732b2 100644
--- a/python/flashinfer/activation.py
+++ b/python/flashinfer/activation.py
@@ -100,3 +100,33 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
         )
     _kernels.gelu_tanh_and_mul(out, input)
     return out
+
+
+def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    r"""Fused GeLU and Mul operation.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (..., 2 * hidden_size).
+
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Output tensor, shape (..., hidden_size).
+    """
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The pointers must be multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    _kernels.gelu_and_mul(out, input)
+    return out
diff --git a/python/tests/test_activation.py b/python/tests/test_activation.py
index 611c4aac..c370f384 100644
--- a/python/tests/test_activation.py
+++ b/python/tests/test_activation.py
@@ -20,7 +20,6 @@
 
 import flashinfer
 
-
 @pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
 @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
 @pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
@@ -32,7 +31,6 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
         y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
     )
 
-
@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
 @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
 @pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
@@ -43,3 +41,14 @@ def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
     numpy.testing.assert_allclose(
         y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
     )
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
+    y = flashinfer.activation.gelu_and_mul(x)
+    numpy.testing.assert_allclose(
+        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
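
A minimal usage sketch (not part of the patch) of the Python API added above, mirroring the reference computation in the new test_fused_gelu_mul; the shapes, dtype, and CUDA device below are illustrative assumptions:

import torch
import flashinfer

hidden_size = 4096  # illustrative; input.shape[-1] * itemsize must be a multiple of 16 bytes
x = torch.randn(2, 8, 2 * hidden_size, device="cuda", dtype=torch.float16)

# Fused kernel: erf-based ("none" approximation) GeLU on the first half of the
# last dimension, multiplied elementwise by the second half.
y = flashinfer.gelu_and_mul(x)  # shape (2, 8, hidden_size)

# Reference, as in the added test.
y_ref = x[..., hidden_size:] * torch.nn.functional.gelu(
    x[..., :hidden_size], approximate="none"
)
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)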