diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index e3fbec81..cc48182c 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -15,9 +15,20 @@ """ import pathlib +import re + +from torch.utils.cpp_extension import _get_cuda_arch_flags + + +def _get_workspace_dir_name() -> pathlib.Path: + flags = _get_cuda_arch_flags() + arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags))))) + # e.g.: $HOME/.cache/flashinfer/75_80_89_90/ + return pathlib.Path.home() / ".cache" / "flashinfer" / arch + # use pathlib -FLASHINFER_WORKSPACE_DIR = pathlib.Path.home() / ".flashinfer" +FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name() FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" _project_root = pathlib.Path(__file__).resolve().parent.parent.parent