diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index e84ed3f0..23cce707 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -82,7 +82,7 @@ def info(self, msg): def check_cuda_arch(): # cuda arch check for fp8 at the moment. for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): - arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) + arch = int(re.search(r"compute_(\d+)", cuda_arch_flags).group(1)) if arch < 75: raise RuntimeError("FlashInfer requires sm75+")