Skip to content

Commit 348897a

Browse files
authored
Fix PyTorch version to 2.0.1 in workflow (#1377)
1 parent 9d9072a commit 348897a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

.github/workflows/publish.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ jobs:
4949
matrix:
5050
os: ['ubuntu-20.04']
5151
python-version: ['3.8', '3.9', '3.10', '3.11']
52+
pytorch-version: ['2.0.1']
5253
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
5354

5455
steps:
@@ -69,9 +70,9 @@ jobs:
6970
run: |
7071
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
7172
72-
- name: Install PyTorch-cu${{ matrix.cuda-version }}
73+
- name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
7374
run: |
74-
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
75+
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
7576
7677
- name: Build wheel
7778
shell: bash

.github/workflows/scripts/pytorch-install.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#!/bin/bash
22

33
python_executable=python$1
4-
cuda_version=$2
4+
pytorch_version=$2
5+
cuda_version=$3
56

67
# Install torch
78
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
8-
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
9+
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --index-url https://download.pytorch.org/whl/cu${cuda_version//./}
910

1011
# Print version information
1112
$python_executable --version

0 commit comments

Comments
 (0)