diff --git a/include/flashinfer/activation.cuh b/include/flashinfer/activation.cuh
index 56196ffb..80fb5516 100644
--- a/include/flashinfer/activation.cuh
+++ b/include/flashinfer/activation.cuh
@@ -17,6 +17,7 @@
 #ifndef FLASHINFER_ACTIVATION_CUH_
 #define FLASHINFER_ACTIVATION_CUH_
 
+#include "math.cuh"
 #include "utils.cuh"
 #include "vec_dtypes.cuh"
 
@@ -30,6 +31,13 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
   return val / (1.0f + __expf(-val));
 }
 
+template <typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
+  const float cdf =
+      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
+  return val * cdf;
+}
+
 template <typename T, float (*Activation)(const float&)>
 __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
   constexpr uint32_t vec_size = 16 / sizeof(T);
diff --git a/python/csrc/activation.cu b/python/csrc/activation.cu
index 7f780323..e4dcf7c0 100644
--- a/python/csrc/activation.cu
+++ b/python/csrc/activation.cu
@@ -40,3 +40,22 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
     return true;
   });
 }
+
+void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
+  int d = input.size(-1) / 2;
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    uint32_t vec_size = 16 / sizeof(c_type);
+    dim3 block(std::min(d / vec_size, 1024U));
+    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_tanh_kernel>
+        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
+                                     static_cast<c_type*>(input.data_ptr()), d);
+
+    return true;
+  });
+}
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index 42efb545..9bfa6503 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -40,6 +40,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
+  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
   m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
   m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
         "Apply Llama 3.1 style RoPE in-place");
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 43971b90..073b7219 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -78,6 +78,8 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                         torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index 2b8f187e..b1b37fc6 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -26,7 +26,7 @@
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
     single_decode_with_kv_cache,
 )
-from .activation import silu_and_mul
+from .activation import gelu_tanh_and_mul, silu_and_mul
 from .group_gemm import SegmentGEMMWrapper
 from .norm import fused_add_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py
index 945310ed..ce889813 100644
--- a/python/flashinfer/activation.py
+++ b/python/flashinfer/activation.py
@@ -14,9 +14,10 @@
 limitations under the License.
 """
 
-import torch
 from typing import Optional
 
+import torch
+
 # mypy: disable-error-code="attr-defined"
 try:
     from . import _kernels
@@ -69,3 +70,33 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
     )
     _kernels.silu_and_mul(out, input)
     return out
+
+
+def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    r"""Fused GeLU Tanh and Mul operation.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (..., 2 * hidden_size).
+
+    out: Optional[torch.Tensor]
+        The output tensor. If specified, the kernel updates this tensor in place.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Output tensor, shape (..., hidden_size).
+    """
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The size of the last dimension must be a multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    _kernels.gelu_tanh_and_mul(out, input)
+    return out
diff --git a/python/tests/test_activation.py b/python/tests/test_activation.py
index d459d3b1..611c4aac 100644
--- a/python/tests/test_activation.py
+++ b/python/tests/test_activation.py
@@ -31,3 +31,15 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
     numpy.testing.assert_allclose(
         y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
     )
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
+    y = flashinfer.activation.gelu_tanh_and_mul(x)
+    numpy.testing.assert_allclose(
+        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
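
Usage sketch (mirrors test_fused_gelu_tanh_mul above; assumes a CUDA device and a flashinfer build that contains this change):

    import torch
    import flashinfer

    hidden_size = 4096
    # The last dimension carries two halves of size hidden_size each.
    x = torch.randn(8, 2 * hidden_size, dtype=torch.float16, device="cuda")

    # Fused kernel: gelu(x[..., :hidden_size], tanh approximation) * x[..., hidden_size:].
    y = flashinfer.activation.gelu_tanh_and_mul(x)  # shape (8, hidden_size)

    # Reference computation, as in the test above.
    y_ref = x[..., hidden_size:] * torch.nn.functional.gelu(
        x[..., :hidden_size], approximate="tanh"
    )
    torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)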