Skip to content

Commit

Permalink
[CI] Fix CUDA version for torch 2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 6, 2024
1 parent cc408f9 commit cf0f4c3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ jobs:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124}[env['MATRIX_TORCH_VERSION']]; \
minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.7.1"
__version__ = "2.7.1.post1"

from flash_attn.flash_attn_interface import (
flash_attn_func,
Expand Down

0 comments on commit cf0f4c3

Please # to comment.