From 99b864d75a90fb82def5fdbe698490614f0266c7 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 11 Mar 2024 04:06:50 +0000 Subject: [PATCH 1/2] upd --- python/generate_batch_paged_prefill_inst.py | 2 +- python/setup.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index ab6e4500..57227a51 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -36,9 +36,9 @@ def get_cu_file_str( dtype_in, dtype_out, idtype, + page_size_choices=[1, 8, 16, 32], ): num_frags_x_choices = [1, 2] - page_size_choices = [1, 8, 16, 32] insts = "\n".join( [ """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( diff --git a/python/setup.py b/python/setup.py index 03653e5f..bb3ef10e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -265,6 +265,7 @@ def get_instantiation_cu() -> List[str]: dtype, dtype, idtype, + page_size_choices=[1, 16, 32], ) write_if_different(root / prefix / fname, content) @@ -378,9 +379,8 @@ def __init__(self, *args, **kwargs) -> None: str(root.resolve() / "include"), ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": ["-O3", "-std=c++17", "--threads", "8", "-gencode", "arch=compute_80,code=sm_80", - "-gencode", "arch=compute_89,code=sm_89", "-gencode", "arch=compute_90,code=sm_90"], + "cxx": ["-O3"], + "nvcc": ["-O3", "-std=c++17", "--threads", "8", "-Xfatbin", "-compress-all"], }, ) ) From 7a60f514169539300482ba40fca8b47bd676e2b4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 11 Mar 2024 04:19:57 +0000 Subject: [PATCH 2/2] upd --- .github/workflows/release_wheel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml index c09f0482..ce5989ad 100644 --- a/.github/workflows/release_wheel.yml +++ b/.github/workflows/release_wheel.yml @@ -18,7 +18,7 @@ on: # required: true env: - TORCH_CUDA_ARCH_LIST: "8.0 8.6 8.9 9.0+PTX" + TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX" jobs: build: