feat: support fused gelu tanh mul (#434)
cc @yzh119 

```
pytest python/tests/test_activation.py
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /flashinfer/python
plugins: anyio-4.2.0
collected 630 items

python/tests/test_activation.py ................................................................................................................... [ 18%]
................................................................................................................................................... [ 41%]
................................................................................................................................................... [ 64%]
................................................................................................................................................... [ 88%]
..........................................................................                                                                          [100%]

============================================================= 630 passed in 146.89s (0:02:26) =============================================================
```

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
zhyncs and yzh119 authored Aug 10, 2024
1 parent 949c328 commit 2c9d1c3
Showing 7 changed files with 75 additions and 2 deletions.
8 changes: 8 additions & 0 deletions include/flashinfer/activation.cuh
@@ -17,6 +17,7 @@
#ifndef FLASHINFER_ACTIVATION_CUH_
#define FLASHINFER_ACTIVATION_CUH_

#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

@@ -30,6 +31,13 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
return val / (1.0f + __expf(-val));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
const float cdf =
0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
return val * cdf;
}

template <typename T, float (*Activation)(const float&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = 16 / sizeof(T);
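For reference, `gelu_tanh_kernel` implements the standard tanh approximation of GELU; the constant 0.7978845608028654 is sqrt(2/pi) and 0.044715 is the usual cubic coefficient. A minimal per-element Python sketch of the same math (illustrative only, not part of the library; the CUDA kernel runs vectorized through `act_and_mul_kernel`):

```python
# Per-element sketch of what gelu_tanh_kernel computes (illustrative, not library code).
import math

def gelu_tanh(x: float) -> float:
    # 0.7978845608028654 == sqrt(2 / pi)
    cdf = 0.5 * (1.0 + math.tanh(0.7978845608028654 * (x + 0.044715 * x ** 3)))
    return x * cdf
```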
19 changes: 19 additions & 0 deletions python/csrc/activation.cu
@@ -40,3 +40,22 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
return true;
});
}

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
dim3 block(std::min(d / vec_size, 1024U));
flashinfer::activation::act_and_mul_kernel<c_type,
flashinfer::activation::gelu_tanh_kernel>
<<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
static_cast<c_type*>(input.data_ptr()), d);

return true;
});
}
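As with the existing `silu_and_mul` launcher, the last dimension is split in half (`d = input.size(-1) / 2`), one thread block is launched per token, and each thread processes a 16-byte vector with the block size capped at 1024. A small sketch of that launch geometry for the fp16 path (`launch_config` is a hypothetical helper, shown only to make the arithmetic explicit):

```python
# Mirrors the grid/block computation in gelu_tanh_and_mul above (fp16 path, illustrative).
def launch_config(num_tokens: int, d: int, itemsize: int = 2) -> tuple[int, int]:
    vec_size = 16 // itemsize          # 8 half-precision elements per 16-byte load
    grid = num_tokens                  # one thread block per token
    block = min(d // vec_size, 1024)   # threads per block, capped at 1024
    return grid, block

print(launch_config(num_tokens=32, d=4096))  # (32, 512)
```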
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -40,6 +40,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -78,6 +78,8 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -26,7 +26,7 @@
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache,
)
from .activation import silu_and_mul
from .activation import gelu_tanh_and_mul, silu_and_mul
from .group_gemm import SegmentGEMMWrapper
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
33 changes: 32 additions & 1 deletion python/flashinfer/activation.py
@@ -14,9 +14,10 @@
limitations under the License.
"""

import torch
from typing import Optional

import torch

# mypy: disable-error-code="attr-defined"
try:
from . import _kernels
@@ -69,3 +70,33 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
)
_kernels.silu_and_mul(out, input)
return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU Tanh and Mul operation.
Parameters
----------
input: torch.Tensor
Input tensor, shape (..., 2 * hidden_size).
out: Optional[torch.Tensor]
The output tensor. If specified, the kernel will update this tensor in-place.
Returns
-------
output: torch.Tensor
Output tensor, shape (..., hidden_size).
"""
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
_kernels.gelu_tanh_and_mul(out, input)
return out
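A quick usage sketch for the new entry point (example shapes; assumes a CUDA device and half-precision input): the first half of the last dimension is passed through tanh-GELU and multiplied element-wise by the second half, matching the reference used in the test below.

```python
# Usage sketch for the new op (example shapes, illustrative only).
import torch
import flashinfer

x = torch.randn(8, 2 * 4096, dtype=torch.float16, device="cuda")
y = flashinfer.gelu_tanh_and_mul(x)  # shape: (8, 4096)

# Plain-PyTorch equivalent, as in the test below:
d = x.shape[-1] // 2
y_ref = torch.nn.functional.gelu(x[..., :d], approximate="tanh") * x[..., d:]
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
```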
12 changes: 12 additions & 0 deletions python/tests/test_activation.py
@@ -31,3 +31,15 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
y = flashinfer.activation.gelu_tanh_and_mul(x)
numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)
