Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Cutlass #309

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
54 changes: 54 additions & 0 deletions python/jittor/compile_extern.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,59 @@ def inner(self, *args, **kw):
if k == "mpi_test": continue
setattr(core.Var, k, wrapper(mpi_ops.__dict__[k]))

def install_cutlass(root_folder):
url = "https://cloud.tsinghua.edu.cn/f/8fc42499904f43e39141/?dl=1"

filename = "cutlass.zip"
fullname = os.path.join(root_folder, filename)
dirname = os.path.join(root_folder, "cutlass")
true_md5 = "41bb524a6bad4612d6017ed4b11f1d28"

if os.path.exists(fullname):
md5 = run_cmd('md5sum '+fullname).split()[0]
if md5 != true_md5:
os.remove(fullname)
if os.path.isdir(dirname):
shutil.rmtree(dirname)
if not os.path.isdir(os.path.join(dirname, "include")):
if not os.path.isfile(os.path.join(root_folder, filename)):
LOG.i("Downloading cutlass...")
download_url_to_local(url, filename, root_folder, true_md5)

if core.get_device_count() == 0:
return
shutil.unpack_archive(fullname, root_folder)
return dirname

def setup_cutlass():
use_cutlass = os.environ.get("use_cutlass", "0")=="1"
if not has_cuda:
use_cutlass = False
return
if not use_cutlass: return
cutlass_include_path = os.environ.get("cutlass_include_path")
print(cutlass_include_path)
if cutlass_include_path is None:
LOG.v("setup cutlass...")
from pathlib import Path
cutlass_path = os.path.join(str(Path.home()), ".cache", "jittor", "cutlass")

make_cache_dir(cutlass_path)
cutlass_home = install_cutlass(cutlass_path)
if cutlass_home is None: return
os.environ['cutlass_include_path'] = cutlass_home
cutlass_include_path = os.path.join(cutlass_home, "include")
cutlass_tool_include_path = os.path.join(cutlass_home, "tools", "util", "include")
all_dir = f" -I\"{cutlass_include_path}\" -I\"{cutlass_tool_include_path}\""
cutlass_src_dir = os.path.join(jittor_path, "extern", "cuda", "cutlass")
cutlass_src_files = []
for r, _, f in os.walk(cutlass_src_dir):
for fname in f:
cutlass_src_files.append(os.path.join(r, fname))
cutlass_ops = compile_custom_ops(cutlass_src_files,
extra_flags=f" {all_dir} ")
LOG.vv("Get cutlass_ops: "+str(dir(cutlass_ops)))

in_mpi = inside_mpi()
FIX_TORCH_ERROR = 0
if os.name != 'nt' and not in_mpi:
Expand All @@ -581,6 +634,7 @@ def inner(self, *args, **kw):
setup_nccl()

setup_cutt()
setup_cutlass()

try:
setup_mkl()
Expand Down
Loading