diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py
index a064f235..c72cf7b8 100644
--- a/python/flashinfer/cascade.py
+++ b/python/flashinfer/cascade.py
@@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
     ...
     >>> outputs[0].shape
     torch.Size([7, 64, 128])
+
+    See Also
+    --------
+    BatchPrefillWithPagedKVCacheWrapper
     """
 
     def __init__(
-        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+        self,
+        num_levels,
+        float_workspace_buffer: torch.Tensor,
+        kv_layout: str = "NHD",
+        use_cuda_graph: bool = False,
+        qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
     ) -> None:
         r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
 
@@ -298,14 +310,59 @@ def __init__(
             buffer should be the same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        use_cuda_graph : bool
+            Whether to use CUDA graph to capture the kernels, if enabled, the auxiliary data structures
+            will be stored in provided buffers.
+        qo_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of qo indptr buffers for each level, the array length should be equal to
+            the number of levels.
+            The last element of each tensor should be the total number of queries/outputs.
+        paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indptr buffers for each level, the array length should be
+            equal to the number of levels.
+        paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indices buffers for each level, the array length should be
+            equal to the number of levels.
+        paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache last page length buffers for each level, the array length
+            should be equal to the number of levels.
         """
-        self._batch_prefill_wrappers = [
-            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
-            for _ in range(num_levels)
-        ]
+        self._use_cuda_graph = use_cuda_graph
+        if use_cuda_graph:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(
+                    float_workspace_buffer,
+                    kv_layout,
+                    use_cuda_graph=True,
+                    qo_indptr_buf=qo_indptr_buf,
+                    paged_kv_indptr_buf=paged_kv_indptr_buf,
+                    paged_kv_indices_buf=paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
+                )
+                for (
+                    qo_indptr_buf,
+                    paged_kv_indptr_buf,
+                    paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf,
+                ) in zip(
+                    qo_indptr_buf_arr,
+                    paged_kv_indptr_buf_arr,
+                    paged_kv_indices_buf_arr,
+                    paged_kv_last_page_len_buf_arr,
+                )
+            ]
+        else:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+                for _ in range(num_levels)
+            ]
         self._num_levels = num_levels
         self._kv_layout = kv_layout
 
+    @property
+    def is_cuda_graph_enabled(self) -> bool:
+        return self._use_cuda_graph
+
     def reset_workspace_buffer(
         self,
         float_workspace_buffer: torch.Tensor,
@@ -912,7 +969,7 @@ def forward(
         k_shared: torch.Tensor,
         v_shared: torch.Tensor,
         unique_kv_cache: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         allow_fp16_qk_reduction: bool = False,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py
index 4ef29206..a39dc014 100644
--- a/python/flashinfer/prefill.py
+++ b/python/flashinfer/prefill.py
@@ -747,7 +747,7 @@ def __init__(
 
         use_cuda_graph : bool
             Whether to enable CUDA graph capture for the prefill kernels, if enabled, the
-            auxiliary data structures will be stored as provided buffers. The ``batch_size``
+            auxiliary data structures will be stored in provided buffers. The ``batch_size``
             cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
 
         qo_indptr_buf : Optional[torch.Tensor]
@@ -1095,7 +1095,7 @@ def forward(
         self,
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         k_scale: Optional[float] = None,
@@ -1240,7 +1240,7 @@ def forward_return_lse(
         self,
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         k_scale: Optional[float] = None,
@@ -1491,7 +1491,7 @@ def plan(
         head_dim: int,
         custom_mask: Optional[torch.Tensor] = None,
         packed_custom_mask: Optional[torch.Tensor] = None,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
@@ -1683,7 +1683,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
@@ -1812,7 +1812,7 @@ def forward_return_lse(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py
index dbee5542..af478fc3 100644
--- a/tests/test_shared_prefix_kernels.py
+++ b/tests/test_shared_prefix_kernels.py
@@ -29,7 +29,7 @@ def ceil_div(a, b):
 @pytest.mark.parametrize("unique_kv_len", [37, 17])
 @pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
 @pytest.mark.parametrize("num_heads", [8, 16])
-@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("head_dim", [128, 256])
 @pytest.mark.parametrize("page_size", [1, 16])
 def test_batch_attention_with_shared_prefix_paged_kv_cache(
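
For reference, a minimal sketch (not part of the patch) of how the new CUDA-graph path of MultiLevelCascadeAttentionWrapper might be driven. The constructor keyword arguments and the is_cuda_graph_enabled property come from the diff above; the buffer dtypes, buffer sizes, num_levels/batch_size/max_num_pages values, and the 128 MB workspace size are illustrative assumptions only.

import torch
from flashinfer.cascade import MultiLevelCascadeAttentionWrapper

num_levels = 2       # e.g. a shared-prefix level plus a per-request level (illustrative)
batch_size = 4       # illustrative; batch size must stay fixed while CUDA graph is enabled
max_num_pages = 128  # illustrative upper bound on pages referenced per level

# Float workspace buffer, same as in the non-CUDA-graph path.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One pre-allocated buffer set per level; these are forwarded to the underlying
# BatchPrefillWithPagedKVCacheWrapper instances. Shapes/dtypes here are assumptions.
qo_indptr_bufs = [
    torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_indptr_bufs = [
    torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_indices_bufs = [
    torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_last_page_len_bufs = [
    torch.empty(batch_size, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]

wrapper = MultiLevelCascadeAttentionWrapper(
    num_levels,
    workspace_buffer,
    kv_layout="NHD",
    use_cuda_graph=True,
    qo_indptr_buf_arr=qo_indptr_bufs,
    paged_kv_indptr_buf_arr=kv_indptr_bufs,
    paged_kv_indices_buf_arr=kv_indices_bufs,
    paged_kv_last_page_len_buf_arr=kv_last_page_len_bufs,
)
assert wrapper.is_cuda_graph_enabled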