Skip to content

Commit

Permalink
PyTorch nightly for ROCm and Intel
Browse files Browse the repository at this point in the history
  • Loading branch information
Disty0 committed Jan 30, 2025
1 parent b081b3a commit 115aac6
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,23 +619,30 @@ def install_rocm_zluda():
# Python 3.12 will cause compatibility issues with other dependencies
# ROCm supports Python 3.12 so don't block it but don't advertise it in the error message
check_python(supported_minors=[9, 10, 11, 12], reason='ROCm backend requires Python 3.9, 3.10 or 3.11')

if rocm.version is None or float(rocm.version) >= 6.2: # assume the latest if version check fails
# use rocm 6.2.4 instead of 6.2 as torch==2.6.0+rocm6.2 doesn't exists
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+rocm6.2.4 torchvision==0.21.0+rocm6.2.4 --index-url https://download.pytorch.org/whl/rocm6.2.4')
elif rocm.version == "6.1":
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+rocm6.1 torchvision==0.21.0+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1')
elif rocm.version == "6.0":
# lock to 2.4.1 instead of 2.5.1 for performance reasons
# there are no support for torch 2.6.0 for rocm 6.0
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.4.1+rocm6.0 torchvision==0.19.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0')
elif float(rocm.version) < 5.5: # oldest supported version is 5.5
log.warning(f"ROCm: unsupported version={rocm.version}")
log.warning("ROCm: minimum supported version=5.5")
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/rocm5.5')
if args.use_nightly:
if rocm.version is None or float(rocm.version) >= 6.3: # assume the latest if version check fails
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.3')
elif rocm.version == "6.2": # use rocm 6.2.4 instead of 6.2 as torch+rocm6.2 doesn't exists
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4')
else: # oldest rocm version on nightly is 6.1
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm6.1')
else:
# older rocm (5.7) uses torch 2.3 or older
torch_command = os.environ.get('TORCH_COMMAND', f'torch torchvision --index-url https://download.pytorch.org/whl/rocm{rocm.version}')
if rocm.version is None or float(rocm.version) >= 6.2: # assume the latest if version check fails
# use rocm 6.2.4 instead of 6.2 as torch==2.6.0+rocm6.2 doesn't exists
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+rocm6.2.4 torchvision==0.21.0+rocm6.2.4 --index-url https://download.pytorch.org/whl/rocm6.2.4')
elif rocm.version == "6.1":
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+rocm6.1 torchvision==0.21.0+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1')
elif rocm.version == "6.0":
# lock to 2.4.1 instead of 2.5.1 for performance reasons
# there are no support for torch 2.6.0 for rocm 6.0
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.4.1+rocm6.0 torchvision==0.19.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0')
elif float(rocm.version) < 5.5: # oldest supported version is 5.5
log.warning(f"ROCm: unsupported version={rocm.version}")
log.warning("ROCm: minimum supported version=5.5")
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/rocm5.5')
else:
# older rocm (5.7) uses torch 2.3 or older
torch_command = os.environ.get('TORCH_COMMAND', f'torch torchvision --index-url https://download.pytorch.org/whl/rocm{rocm.version}')

if os.environ.get('TRITON_COMMAND', None) is None:
os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-rocm as a dependency instead
Expand Down Expand Up @@ -697,14 +704,19 @@ def install_ipex(torch_command):
# XPU PyTorch doesn't support Flash Atten or Memory Atten yet so Battlemage goes OOM without this
os.environ.setdefault('IPEX_FORCE_ATTENTION_SLICE', '1')

if "linux" in sys.platform:
# default to US server. If The China server is needed, change .../release-whl/stable/xpu/us/ to .../release-whl/stable/xpu/cn/
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/')
if args.use_nightly:
torch_command = os.environ.get('TORCH_COMMAND', '--pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/xpu')
if os.environ.get('TRITON_COMMAND', None) is None:
os.environ.setdefault('TRITON_COMMAND', '--pre pytorch-triton-xpu==3.1.0+91b14bf559 --index-url https://download.pytorch.org/whl/nightly/xpu')
# os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.15.1 intel-extension-for-tensorflow[xpu]==2.15.0.2')
os.environ.setdefault('TRITON_COMMAND', 'skip') # pytorch auto installs pytorch-triton-rocm as a dependency instead
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+xpu torchvision==0.21.0+xpu --index-url https://download.pytorch.org/whl/xpu')
if "linux" in sys.platform:
# default to US server. If The China server is needed, change .../release-whl/stable/xpu/us/ to .../release-whl/stable/xpu/cn/
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cxx11.abi torchvision==0.20.1+cxx11.abi intel-extension-for-pytorch==2.5.10+xpu oneccl_bind_pt==2.5.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/')
if os.environ.get('TRITON_COMMAND', None) is None:
os.environ.setdefault('TRITON_COMMAND', '--pre pytorch-triton-xpu==3.1.0+91b14bf559 --index-url https://download.pytorch.org/whl/nightly/xpu')
# os.environ.setdefault('TENSORFLOW_PACKAGE', 'tensorflow==2.15.1 intel-extension-for-tensorflow[xpu]==2.15.0.2')
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.6.0+xpu torchvision==0.21.0+xpu --index-url https://download.pytorch.org/whl/xpu')

install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2024.6.0'), 'openvino', ignore=True)
install('nncf==2.7.0', ignore=True, no_deps=True) # requires older pandas
Expand Down

0 comments on commit 115aac6

Please # to comment.