diff --git a/docs/api/python/activation.rst b/docs/api/python/activation.rst
new file mode 100644
index 00000000..9d2aa735
--- /dev/null
+++ b/docs/api/python/activation.rst
@@ -0,0 +1,18 @@
+.. _apiactivation:
+
+flashinfer.activation
+=====================
+
+.. currentmodule:: flashinfer.activation
+
+This module provides a set of activation operations for up/gate layers in transformer MLPs.
+
+Up/Gate output activation
+-------------------------
+
+.. autosummary::
+   :toctree: ../../generated
+
+   silu_and_mul
+   gelu_tanh_and_mul
+   gelu_and_mul
diff --git a/docs/api/python/norm.rst b/docs/api/python/norm.rst
index 9a9e0d49..c53b0112 100644
--- a/docs/api/python/norm.rst
+++ b/docs/api/python/norm.rst
@@ -11,3 +11,6 @@ Kernels for normalization layers.
    :toctree: _generate
 
    rmsnorm
+   fused_add_rmsnorm
+   gemma_rmsnorm
+   gemma_fused_add_rmsnorm
diff --git a/docs/api/python/page.rst b/docs/api/python/page.rst
index 66e64f68..7d20ebb5 100644
--- a/docs/api/python/page.rst
+++ b/docs/api/python/page.rst
@@ -14,3 +14,4 @@ Append new K/V tensors to Paged KV-Cache
    :toctree: ../../generated
 
    append_paged_kv_cache
+   get_batch_indices_positions
diff --git a/docs/api/python/rope.rst b/docs/api/python/rope.rst
index 636e069c..c113e2f0 100644
--- a/docs/api/python/rope.rst
+++ b/docs/api/python/rope.rst
@@ -14,3 +14,9 @@ Kernels for applying rotary embeddings.
    apply_llama31_rope_inplace
    apply_rope
    apply_llama31_rope
+   apply_rope_pos_ids
+   apply_rope_pos_ids_inplace
+   apply_llama31_rope_pos_ids
+   apply_llama31_rope_pos_ids_inplace
+   apply_rope_with_cos_sin_cache
+   apply_rope_with_cos_sin_cache_inplace
diff --git a/docs/index.rst b/docs/index.rst
index 0a4dd61e..d8af2d44 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -36,4 +36,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
    api/python/gemm
    api/python/norm
    api/python/rope
+   api/python/activation
    api/python/quantization
diff --git a/python/flashinfer/activation.py b/python/flashinfer/activation.py
index cb81b4a8..356ff1ab 100644
--- a/python/flashinfer/activation.py
+++ b/python/flashinfer/activation.py
@@ -111,6 +111,8 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
 def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
     r"""Fused SiLU and Mul operation.
 
+    ``silu(input[..., :hidden_size]) * input[..., hidden_size:]``
+
     Parameters
     ----------
     input: torch.Tensor
@@ -141,6 +143,8 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
 def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
     r"""Fused GeLU Tanh and Mul operation.
 
+    ``gelu(input[..., :hidden_size], approximate="tanh") * input[..., hidden_size:]``
+
     Parameters
     ----------
     input: torch.Tensor
@@ -171,6 +175,8 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
 def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
     r"""Fused GeLU and Mul operation.
 
+    ``gelu(input[..., :hidden_size]) * input[..., hidden_size:]``
+
     Parameters
     ----------
     input: torch.Tensor
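The formula lines added to ``activation.py`` above can be read as the following plain-PyTorch sketch (illustration only, assuming the last input dimension is ``2 * hidden_size``; the ``*_ref`` names are hypothetical and this is not the fused FlashInfer kernel):

# Illustrative reference for the docstring formulas above (not the FlashInfer kernels).
# Assumes input has shape (..., 2 * hidden_size); all *_ref names are hypothetical.
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    hidden_size = x.shape[-1] // 2
    return F.silu(x[..., :hidden_size]) * x[..., hidden_size:]

def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    hidden_size = x.shape[-1] // 2
    # GeLU with the tanh approximation, applied to the gate half.
    return F.gelu(x[..., :hidden_size], approximate="tanh") * x[..., hidden_size:]

def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    hidden_size = x.shape[-1] // 2
    # Exact (erf-based) GeLU applied to the gate half.
    return F.gelu(x[..., :hidden_size]) * x[..., hidden_size:]

x = torch.randn(4, 8192)
assert silu_and_mul_ref(x).shape == (4, 4096)

Here ``hidden_size`` is half of the last input dimension, matching the convention used by the docstrings above.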
diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py
index df9a77c7..3cbb081a 100644
--- a/python/flashinfer/norm.py
+++ b/python/flashinfer/norm.py
@@ -50,6 +50,8 @@ def rmsnorm(
 ) -> torch.Tensor:
     r"""Root mean square normalization.
 
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
     Parameters
     ----------
     input: torch.Tensor
@@ -92,6 +94,12 @@ def fused_add_rmsnorm(
 ) -> None:
     r"""Fused add root mean square normalization.
 
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+
     Parameters
     ----------
     input: torch.Tensor
@@ -119,7 +127,9 @@ def gemma_rmsnorm(
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    r"""Gemma Root mean square normalization.
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
 
     Parameters
     ----------
@@ -163,7 +173,13 @@ def _gemma_rmsnorm_fake(
 def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
-    r"""Gemma Fused add root mean square normalization.
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``
 
     Parameters
     ----------
diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py
index 581d6207..75f2640e 100644
--- a/python/flashinfer/page.py
+++ b/python/flashinfer/page.py
@@ -151,11 +151,15 @@ def get_batch_indices_positions(
     >>> positions # the rightmost column index of each row
     tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)
 
-    Notes
-    -----
+    Note
+    ----
     This function is similar to `CSR2COO `_ conversion in cuSPARSE library,
     with the difference that we are converting from a ragged tensor
     (which don't require a column indices array) to a COO format.
+
+    See Also
+    --------
+    append_paged_kv_cache
     """
     batch_size = append_indptr.size(0) - 1
     batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
@@ -305,6 +309,10 @@ def append_paged_kv_cache(
     The function assumes that the space for appended k/v have already been allocated,
     which means :attr:`kv_indices`, :attr:`kv_indptr`, :attr:`kv_last_page_len` has
     incorporated appended k/v.
+
+    See Also
+    --------
+    get_batch_indices_positions
     """
     _check_kv_layout(kv_layout)
     _append_paged_kv_cache_kernel(
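For reference, the RMSNorm formulas documented in ``norm.py`` above correspond to this plain-PyTorch sketch (illustrative only; the ``*_ref`` helpers are hypothetical names, and unlike the real ``fused_add_rmsnorm``/``gemma_fused_add_rmsnorm`` kernels, which update ``input`` and ``residual`` in place, the sketch returns new tensors):

# Illustrative reference for the norm.py docstring formulas (not the fused CUDA kernels).
import torch

def _rms(x: torch.Tensor, eps: float) -> torch.Tensor:
    # Root mean square over the hidden (last) dimension.
    return torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)

def rmsnorm_ref(x, weight, eps=1e-6):
    # out[i] = (input[i] / RMS(input)) * weight[i]
    return (x.float() / _rms(x, eps)).to(x.dtype) * weight

def gemma_rmsnorm_ref(x, weight, eps=1e-6):
    # Gemma-style: scale by (weight + 1) instead of weight.
    return (x.float() / _rms(x, eps)).to(x.dtype) * (weight + 1)

def fused_add_rmsnorm_ref(x, residual, weight, eps=1e-6):
    # Step 1: residual += input; Step 2: input = rmsnorm(residual).
    # Returned as new tensors here; the real kernel mutates its arguments.
    residual = residual + x
    return rmsnorm_ref(residual, weight, eps), residual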
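And as context for the ``See Also`` cross-references added in ``page.py``, the ragged-to-COO conversion performed by ``get_batch_indices_positions`` can be approximated on CPU as below (an illustrative sketch, not the CUDA implementation; the position convention, where each request's appended tokens occupy the last positions of its sequence, is inferred from the docstring example above):

# Rough CPU sketch of the CSR->COO-style expansion described in the docstring
# (illustrative only; the library computes this with a CUDA kernel). Assumes the
# appended tokens of request i occupy the last (indptr[i+1] - indptr[i]) positions
# of that request's sequence.
import torch

def batch_indices_positions_ref(append_indptr, seq_lens, nnz):
    batch_indices = torch.empty(nnz, dtype=torch.int32)
    positions = torch.empty(nnz, dtype=torch.int32)
    batch_size = append_indptr.numel() - 1
    for i in range(batch_size):
        start, end = int(append_indptr[i]), int(append_indptr[i + 1])
        batch_indices[start:end] = i  # row index of each entry (CSR2COO)
        append_len = end - start
        positions[start:end] = torch.arange(
            int(seq_lens[i]) - append_len, int(seq_lens[i]), dtype=torch.int32
        )
    return batch_indices, positions

For instance, ``append_indptr = torch.tensor([0, 1, 3, 6, 10])`` with ``seq_lens = torch.tensor([5, 5, 5, 5])`` is one set of inputs that reproduces the ``positions`` tensor shown in the docstring excerpt above.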