diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 18f07846..86fb4a5f 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -15,7 +15,7 @@ """ import math -from typing import Optional, Tuple +from typing import Optional, Tuple, List import torch # mypy: disable-error-code="attr-defined" @@ -273,7 +273,7 @@ def __init__( def reset_workspace_buffer( self, float_workspace_buffer: torch.Tensor, - int_workspace_buffers: list[torch.Tensor], + int_workspace_buffers: List[torch.Tensor], ) -> None: r"""Reset the workspace buffer. @@ -283,8 +283,8 @@ def reset_workspace_buffer( The new float workspace buffer, the device of the new float workspace buffer should be the same as the device of the input tensors. - int_workspace_buffer : torch.Tensor - The new int workspace buffer, the device of the new int workspace buffer should + int_workspace_buffers : List[torch.Tensor] + The array of new int workspace buffer, the device of the new int workspace buffer should be the same as the device of the input tensors. """ for wrapper, int_workspace_buffer in zip( @@ -294,10 +294,10 @@ def reset_workspace_buffer( def plan( self, - qo_indptr_arr: list[torch.Tensor], - paged_kv_indptr_arr: list[torch.Tensor], - paged_kv_indices_arr: list[torch.Tensor], - paged_kv_last_page_len: list[torch.Tensor], + qo_indptr_arr: List[torch.Tensor], + paged_kv_indptr_arr: List[torch.Tensor], + paged_kv_indices_arr: List[torch.Tensor], + paged_kv_last_page_len: List[torch.Tensor], num_qo_heads: int, num_kv_heads: int, head_dim: int, @@ -318,17 +318,17 @@ def plan( Parameters ---------- - qo_indptr_arr : list[torch.Tensor] + qo_indptr_arr : List[torch.Tensor] An array of qo indptr tensors for each level, the array length should be equal to the number of levels. The last element of each tensor should be the total number of queries/outputs. - paged_kv_indptr_arr : list[torch.Tensor] + paged_kv_indptr_arr : List[torch.Tensor] An array of paged kv-cache indptr tensors for each level, the array length should be equal to the number of levels. - paged_kv_indices_arr : list[torch.Tensor] + paged_kv_indices_arr : List[torch.Tensor] An array of paged kv-cache indices tensors for each level, the array length should be equal to the number of levels. - paged_kv_last_page_len : list[torch.Tensor] + paged_kv_last_page_len : List[torch.Tensor] An array of paged kv-cache last page length tensors for each level, the array length should be equal to the number of levels. num_qo_heads : int