feat: add gelu_and_mul (#474)
for gemma
zhyncs authored Aug 27, 2024
1 parent 2a6963f commit 9ee26e7
Showing 7 changed files with 69 additions and 3 deletions.
6 changes: 6 additions & 0 deletions include/flashinfer/activation.cuh
@@ -31,6 +31,12 @@ __device__ __forceinline__ float silu_kernel(const float& val) {
  return val / (1.0f + __expf(-val));
}

// https://github.com/pytorch/pytorch/blob/f48038527792814b06dafa6d471acb04c837b972/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
__device__ __forceinline__ float gelu_kernel(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& val) {
  const float cdf =
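For reference, gelu_kernel implements the exact (erf-based) GeLU, 0.5 * x * (1 + erf(x / sqrt(2))), matching the PyTorch CUDA kernel linked in the comment above. A minimal PyTorch sketch of the same formula (the function name here is illustrative, not part of the library):

import math

import torch

def gelu_exact(x: torch.Tensor) -> torch.Tensor:
    # Exact GeLU: 0.5 * x * (1 + erf(x / sqrt(2))),
    # i.e. torch.nn.functional.gelu(x, approximate="none").
    return 0.5 * x * (1.0 + torch.erf(x * math.sqrt(0.5)))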
18 changes: 18 additions & 0 deletions python/csrc/activation.cu
@@ -58,3 +58,21 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input) {
    return true;
  });
}

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    flashinfer::activation::act_and_mul_kernel<c_type, flashinfer::activation::gelu_kernel>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()),
                                     static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
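The wrapper launches one thread block per token (grid = num_tokens), uses 16-byte vectorized accesses (vec_size = 16 / sizeof(c_type)), and caps the block size at 1024 threads. Semantically, the kernel applies GeLU to the first half of the last dimension and multiplies it elementwise by the second half; a pure-PyTorch sketch of that behavior (a reading aid under those assumptions, not the library implementation):

import torch

def gelu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x has shape (..., 2 * d): exact GeLU on the first half,
    # elementwise product with the second half, result shape (..., d).
    d = x.shape[-1] // 2
    return torch.nn.functional.gelu(x[..., :d], approximate="none") * x[..., d:]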
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -41,6 +41,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -81,6 +81,8 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                        torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -14,7 +14,7 @@
limitations under the License.
"""

from .activation import gelu_tanh_and_mul, silu_and_mul
from .activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from .cascade import (
    BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
30 changes: 30 additions & 0 deletions python/flashinfer/activation.py
@@ -100,3 +100,33 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        )
    _kernels.gelu_tanh_and_mul(out, input)
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    r"""Fused GeLU and Mul operation.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (..., 2 * hidden_size).
    out: Optional[torch.Tensor]
        The output tensor; if specified, the kernel will update this tensor in place.

    Returns
    -------
    output: torch.Tensor
        Output tensor, shape (..., hidden_size).
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The size of the last dimension must be a multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    _kernels.gelu_and_mul(out, input)
    return out
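A short usage sketch of the new Python entry point (shapes follow the docstring above; the hidden size of 4096 and batch size of 8 are illustrative only):

import torch

import flashinfer

x = torch.randn(8, 2 * 4096, dtype=torch.float16, device="cuda")  # (..., 2 * hidden_size)
y = flashinfer.gelu_and_mul(x)  # (..., hidden_size)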
13 changes: 11 additions & 2 deletions python/tests/test_activation.py
@@ -20,7 +20,6 @@

import flashinfer


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
@@ -32,7 +31,6 @@ def test_fused_silu_mul(dim, batch_size, seq_len):
        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
    )


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
@@ -43,3 +41,14 @@ def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
    numpy.testing.assert_allclose(
        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
    )

@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
    y = flashinfer.activation.gelu_and_mul(x)
    numpy.testing.assert_allclose(
        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
    )
