feat: add MultiLevelCascadeAttentionWrapper API (#462)
Our existing cascade inference APIs all assume that the shared-prefix kv-cache is
stored in standalone tensors, which is not the case for real-world LLM serving.

This PR adds a more general `MultiLevelCascadeAttentionWrapper` API which not
only supports multi-level cascade inference, but also stores the kv-cache of
all levels in a unified paged kv-cache, so it can seamlessly integrate with
existing LLM serving frameworks.
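
For illustration, a minimal two-level sketch of the intended usage (mirroring the
docstring example added in this PR; buffer and batch sizes are arbitrary):

    import torch
    import flashinfer

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    # level 0 indexes the shared-prefix pages, level 1 the per-request unique pages,
    # all living in the same unified paged kv-cache
    wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(2, workspace_buffer, "NHD")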

Tutorials, tests, and examples are updated accordingly.

The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and
`BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated starting
from v0.2.0.
yzh119 authored Aug 22, 2024
1 parent c1f576a commit 1e37989
Showing 6 changed files with 306 additions and 55 deletions.
4 changes: 4 additions & 0 deletions docs/api/python/cascade.rst
@@ -25,6 +25,10 @@ Cascade Attention
Cascade Attention Wrapper Classes
---------------------------------

.. autoclass:: MultiLevelCascadeAttentionWrapper
:members:


.. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
:members:

4 changes: 2 additions & 2 deletions docs/conf.py
@@ -18,8 +18,8 @@
author = "FlashInfer Contributors"
copyright = "2023-2024, {}".format(author)

version = "0.1.4"
release = "0.1.4"
version = "0.1.5"
release = "0.1.5"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
18 changes: 18 additions & 0 deletions docs/tutorials/kv_layout.rst
@@ -41,6 +41,24 @@ shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.

We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.

.. _cascade-qo-indptr-layout:

Multi-level Cascade Inference Query/Output Layout
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When using multi-level `cascade inference <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_,
the query and output of each level are stored in ragged tensors. Each level's ``qo_indptr`` array stores
the interval information of each node in the cascade tree at that level. The figure below shows the
``qo_indptr`` for each level in cascade inference:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/cascade_qo_indptr.png
:width: 800
:align: center
:alt: The ``qo_indptr`` for each level in cascade inference.

Note that each level's ``qo_indptr`` array should start from 0, and the last element of each ``qo_indptr``
array should equal the total length of all query/output tensors.
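
As a concrete sketch (a hypothetical two-level cascade over a batch of 7 requests that
all share a single prefix), the per-level ``qo_indptr`` arrays could be built as follows:

.. code-block:: python

    import torch

    batch_size = 7
    # top level: one cascade-tree node covers all 7 queries attending to the shared prefix
    qo_indptr_top = torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0")
    # bottom level: one node per request, each attending to its own unique kv-cache
    qo_indptr_bottom = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0")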

FlashInfer APIs
~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
@@ -15,6 +15,7 @@
"""

from .cascade import (
MultiLevelCascadeAttentionWrapper,
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
merge_state,
243 changes: 241 additions & 2 deletions python/flashinfer/cascade.py
@@ -177,12 +177,238 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
return _kernels.merge_states(v, s)


class MultiLevelCascadeAttentionWrapper:
r"""Attention wrapper for memory efficient multi-level cascade inference, this API assumes all
levels KV-Cache are stored in a unified paged table.
Check :ref:`our tutorial<page-layout>` for page table layout, and
`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for query/output layout.
The idea of cascade inference is introduced in our `blog post <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_.
Example
-------
>>> import torch
>>> import flashinfer
>>> num_layers = 32
>>> num_qo_heads = 64
>>> num_kv_heads = 8
>>> head_dim = 128
>>> page_size = 16
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
... 2, workspace_buffer, "NHD"
... )
>>> batch_size = 7
>>> shared_kv_num_pages = 512
>>> unique_kv_num_pages = 128
>>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
>>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
>>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
>>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
>>> unique_kv_page_indptr = torch.tensor(
... [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
... )
>>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
>>> # 1 <= kv_last_page_len <= page_size
>>> unique_kv_last_page_len = torch.tensor(
... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
... )
>>> kv_cache_at_layer = [
... torch.randn(
... total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
... ) for _ in range(num_layers)
... ]
>>> qo_indptr_arr = [
... torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"), # top-level for shared KV-Cache
... torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0") # bottom-level for unique KV-Cache
... ]
>>> # create auxiliary data structures for multi-level cascade attention
>>> wrapper.begin_forward(
... qo_indptr_arr,
... [shared_kv_page_indptr, unique_kv_page_indptr],
... [shared_kv_page_indices, unique_kv_page_indices],
... [shared_kv_last_page_len, unique_kv_last_page_len],
... num_qo_heads,
... num_kv_heads,
... head_dim,
... page_size,
... )
>>> outputs = []
>>> for i in range(num_layers):
... q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
... # compute multi-level cascade attention, reuse auxiliary data structures for all layers
... o = wrapper.forward(q, kv_cache_at_layer[i])
... outputs.append(o)
...
>>> # clear auxiliary data structures
>>> wrapper.end_forward()
>>> outputs[0].shape
torch.Size([7, 64, 128])
"""

def __init__(
self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
) -> None:
r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
Parameters
----------
num_levels : int
The number of levels in the cascade attention.
float_workspace_buffer : torch.Tensor
The user reserved float workspace buffer used to store intermediate attention results
in the split-k algorithm. The recommended size is 128MB, the device of the workspace
buffer should be the same as the device of the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
"""
self._batch_prefill_wrappers = [
BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
for _ in range(num_levels)
]
self._kv_layout = kv_layout

def reset_workspace_buffer(
self,
float_workspace_buffer: torch.Tensor,
int_workspace_buffers: list[torch.Tensor],
) -> None:
r"""Reset the workspace buffer.
Parameters
----------
float_workspace_buffer : torch.Tensor
The new float workspace buffer, the device of the new float workspace buffer should
be the same as the device of the input tensors.
int_workspace_buffers : list[torch.Tensor]
The new int workspace buffers, one per level; the device of each new int workspace
buffer should be the same as the device of the input tensors.
"""
for wrapper, int_workspace_buffer in zip(
self._batch_prefill_wrappers, int_workspace_buffers
):
wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)

def begin_forward(
self,
qo_indptr_arr: list[torch.Tensor],
paged_kv_indptr_arr: list[torch.Tensor],
paged_kv_indices_arr: list[torch.Tensor],
paged_kv_last_page_len: list[torch.Tensor],
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
):
r"""Create auxiliary data structures for multi-level cascade attention for multiple
forward calls within the same decode step.

Parameters
----------
qo_indptr_arr : list[torch.Tensor]
An array of qo indptr tensors for each level, the array length should be equal to
the number of levels. Check
:ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the query/output layout.
The last element of each tensor should be the total number of queries/outputs.
paged_kv_indptr_arr : list[torch.Tensor]
An array of paged kv-cache indptr tensors for each level, the array length should be
equal to the number of levels.
paged_kv_indices_arr : list[torch.Tensor]
An array of paged kv-cache indices tensors for each level, the array length should be
equal to the number of levels.
paged_kv_last_page_len : list[torch.Tensor]
An array of paged kv-cache last page length tensors for each level, the array length
should be equal to the number of levels.
num_qo_heads : int
The number of query/output heads.
num_kv_heads : int
The number of key/value heads.
head_dim : int
The dimension of the heads.
page_size : int
The page size of the paged kv-cache.
"""
for (
wrapper,
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
) in zip(
self._batch_prefill_wrappers,
qo_indptr_arr,
paged_kv_indptr_arr,
paged_kv_indices_arr,
paged_kv_last_page_len,
):
wrapper.begin_forward(
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
)

def end_forward(self):
r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
for wrapper in self._batch_prefill_wrappers:
wrapper.end_forward()

def forward(
self,
q: torch.Tensor,
paged_kv_cache: torch.Tensor,
**kwargs,
):
r"""Compute multi-level cascade attention.
Parameters
----------
q : torch.Tensor
The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
The paged KV-Cache stored as a tuple of tensors or a single tensor:
* a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape:
``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``.
* a single 5-D tensor with shape:
``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
:attr:`kv_layout` is ``NHD``, and
``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
:attr:`kv_layout` is ``HND``, where ``paged_kv_cache[:, 0]`` is the key-cache and
``paged_kv_cache[:, 1]`` is the value-cache.
"""
out, lse = self._batch_prefill_wrappers[-1].forward_return_lse(
q, paged_kv_cache, **kwargs
)
# NOTE(Zihao): causal mask should be False for all levels except the last level
kwargs["causal"] = False
for wrapper in self._batch_prefill_wrappers[:-1]:
out_i, lse_i = wrapper.forward_return_lse(q, paged_kv_cache, **kwargs)
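# fold this level's partial attention state (output and log-sum-exp) into the running state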
merge_state_in_place(out, lse, out_i, lse_i)

return out
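
# For intuition only: an unfused reference sketch of the state-merge computation
# performed by merge_state_in_place above; the fused CUDA kernel is the actual
# implementation, and this helper is not part of the public API.
def _merge_states_reference(v_a, s_a, v_b, s_b):
    # v_*: partial attention outputs, shape [..., num_heads, head_dim]
    # s_*: per-head log-sum-exp of the attention scores, shape [..., num_heads]
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp(s_a - s_max)
    w_b = torch.exp(s_b - s_max)
    # weighted average of the partial outputs, renormalized by the combined softmax mass
    v = (w_a[..., None] * v_a + w_b[..., None] * v_b) / (w_a + w_b)[..., None]
    s = s_max + torch.log(w_a + w_b)
    return v, s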


class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch
of requests.
of requests. The shared-prefix KV-Cache is stored in standalone tensors, and the
unique KV-Cache of each request is stored in a paged KV-Cache data structure.
Check :ref:`our tutorial<page-layout>` for page table layout.
It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
multi-level cascade inference, where the KV-Cache of each level is stored in a unified
page table. This API will be deprecated in the future.

Example
-------
>>> import torch
@@ -328,6 +554,11 @@ def begin_forward(
The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
is not equal to ``num_kv_heads``, the function will use
`grouped query attention <https://arxiv.org/abs/2305.13245>`_.

See Also
--------
MultiLevelCascadeAttentionWrapper
"""
self._batch_decode_wrapper.begin_forward(
unique_kv_indptr,
@@ -433,6 +664,10 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
Check :ref:`our tutorial<page-layout>` for paged kv-cache layout.
It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
multi-level cascade inference, where the KV-Cache of each level is stored in a unified
page table. This API will be deprecated in the future.

Example
-------
>>> import torch
@@ -533,7 +768,7 @@ def __init__(
self._kv_layout = kv_layout

def reset_workspace_buffer(
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
) -> None:
r"""Reset the workspace buffer.
@@ -671,6 +906,10 @@ def forward(
-------
V : torch.Tensor
The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``.

See Also
--------
MultiLevelCascadeAttentionWrapper
"""
V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
q,