From 0aa47269f9f06f20e4a15662931972c9a2de482f Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Thu, 10 Oct 2024 15:13:36 -0700 Subject: [PATCH] fix: AOT compiler flags on non-sm90 (#522) Previously, non-sm90 cards were incorrectly compiled as sm90 because the compiler flags dict was only shallow-copied: extending the "nvcc" list of the sm90 copy also modified the list shared with the common flags. --- flashinfer-aot/setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 5c541026..80fd4ea9 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -16,6 +16,7 @@ from typing import List, Tuple +import copy import pathlib import os import re @@ -372,7 +373,7 @@ def __init__(self, *args, **kwargs) -> None: "-use_fast_math", ], } - extra_compile_args_sm90 = extra_compile_args.copy() + extra_compile_args_sm90 = copy.deepcopy(extra_compile_args) extra_compile_args_sm90["nvcc"].extend( "-gencode arch=compute_90a,code=sm_90a".split() )