Skip to content

Commit

Permalink
bugfix: fix the python 3.8 type error (#486)
Browse files Browse the repository at this point in the history
Bugfix to #484, thanks @wzhao18 for spotting this error.
  • Loading branch information
yzh119 authored Sep 1, 2024
1 parent eebbea0 commit 77bff3f
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python/flashinfer/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

import math
from typing import Optional, Tuple
from typing import Optional, Tuple, List
import torch

# mypy: disable-error-code="attr-defined"
Expand Down Expand Up @@ -273,7 +273,7 @@ def __init__(
def reset_workspace_buffer(
self,
float_workspace_buffer: torch.Tensor,
int_workspace_buffers: list[torch.Tensor],
int_workspace_buffers: List[torch.Tensor],
) -> None:
r"""Reset the workspace buffer.
Expand All @@ -283,8 +283,8 @@ def reset_workspace_buffer(
The new float workspace buffer, the device of the new float workspace buffer should
be the same as the device of the input tensors.
int_workspace_buffer : torch.Tensor
The new int workspace buffer, the device of the new int workspace buffer should
int_workspace_buffers : List[torch.Tensor]
The array of new int workspace buffer, the device of the new int workspace buffer should
be the same as the device of the input tensors.
"""
for wrapper, int_workspace_buffer in zip(
Expand All @@ -294,10 +294,10 @@ def reset_workspace_buffer(

def plan(
self,
qo_indptr_arr: list[torch.Tensor],
paged_kv_indptr_arr: list[torch.Tensor],
paged_kv_indices_arr: list[torch.Tensor],
paged_kv_last_page_len: list[torch.Tensor],
qo_indptr_arr: List[torch.Tensor],
paged_kv_indptr_arr: List[torch.Tensor],
paged_kv_indices_arr: List[torch.Tensor],
paged_kv_last_page_len: List[torch.Tensor],
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
Expand All @@ -318,17 +318,17 @@ def plan(
Parameters
----------
qo_indptr_arr : list[torch.Tensor]
qo_indptr_arr : List[torch.Tensor]
An array of qo indptr tensors for each level, the array length should be equal to
the number of levels.
The last element of each tensor should be the total number of queries/outputs.
paged_kv_indptr_arr : list[torch.Tensor]
paged_kv_indptr_arr : List[torch.Tensor]
An array of paged kv-cache indptr tensors for each level, the array length should be
equal to the number of levels.
paged_kv_indices_arr : list[torch.Tensor]
paged_kv_indices_arr : List[torch.Tensor]
An array of paged kv-cache indices tensors for each level, the array length should be
equal to the number of levels.
paged_kv_last_page_len : list[torch.Tensor]
paged_kv_last_page_len : List[torch.Tensor]
An array of paged kv-cache last page length tensors for each level, the array length
should be equal to the number of levels.
num_qo_heads : int
Expand Down

0 comments on commit 77bff3f

Please # to comment.