Skip to content

Commit 3c4582e

Browse files
committed
better loading
1 parent 7dc7496 commit 3c4582e

File tree

6 files changed

+30
-26
lines changed

6 files changed

+30
-26
lines changed

.travis.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,10 @@ install:
7777

7878
script:
7979
- if [ "${PYTHON_VERSION}" != "3.5" ]; then pip install flake8 && flake8 .; fi
80-
- if [ "${PYTHON_VERSION}" = "3.5" ]; then pip install zzip; fi
8180
- python setup.py test
8281
after_success:
83-
- python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}/${IDX}
84-
- python script/rename_wheel.py
82+
- python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}
83+
- python script/rename_wheel.py ${IDX}
8584
- pip install codecov && codecov
8685
deploy:
8786
provider: s3
@@ -90,8 +89,8 @@ deploy:
9089
access_key_id: AKIAJB7S6NJ5OM5MAAGA
9190
secret_access_key: ${S3_SECRET_ACCESS_KEY}
9291
bucket: pytorch-scatter
93-
local_dir: dist/torch-${TORCH_VERSION}/${IDX}
94-
upload_dir: whl/torch-${TORCH_VERSION}/${IDX}
92+
local_dir: dist/torch-${TORCH_VERSION}
93+
upload_dir: whl/torch-${TORCH_VERSION}
9594
acl: public_read
9695
on:
9796
repo: rusty1s/pytorch_scatter

script/rename_wheel.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
1+
import sys
12
import os
23
import os.path as osp
34
import glob
5+
import shutil
6+
7+
idx = sys.argv[1]
8+
assert idx in ['cpu', 'cu92', 'cu100', 'cu101']
49

510
dist_dir = osp.join(osp.dirname(osp.abspath(__file__)), '..', 'dist')
611
wheels = glob.glob(osp.join('dist', '**', '*.whl'), recursive=True)
712

813
for wheel in wheels:
9-
idx = wheel.split(osp.sep)[-2]
10-
if idx not in ['cpu', 'cu92', 'cu100', 'cu101']:
11-
continue
12-
name = wheel.split(osp.sep)[-1]
13-
if idx in name:
14+
if idx in wheel:
1415
continue
1516

16-
names = name.split('-')
17-
name = '-'.join(names[:-4] + [names[-4] + '%2B' + idx] + names[-2:])
18-
new_wheel = osp.join(*wheel.split(osp.sep)[:-1], name)
19-
os.rename(wheel, new_wheel)
17+
paths = wheel.split(osp.sep)
18+
names = paths[-1].split('-')
19+
20+
name = '-'.join(names[:-4] + ['latest+' + idx] + names[-3:])
21+
shutil.copyfile(wheel, osp.join(*paths[:-1], name))
22+
23+
name = '-'.join(names[:-4] + [names[-4] + '+' + idx] + names[-3:])
24+
os.rename(wheel, osp.join(*paths[:-1], name))

torch_scatter/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# flake8: noqa
22

3-
import importlib
3+
import glob
44
import os.path as osp
55

66
import torch
@@ -9,8 +9,8 @@
99
expected_torch_version = (1, 4)
1010

1111
try:
12-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
13-
'_version', [osp.dirname(__file__)]).origin)
12+
torch.ops.load_library(
13+
glob.glob(osp.join(osp.dirname(__file__), '_version.*'))[0])
1414
except OSError as e:
1515
if 'undefined symbol' in str(e):
1616
major, minor = [int(x) for x in torch.__version__.split('.')[:2]]

torch_scatter/scatter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import importlib
1+
import glob
22
import os.path as osp
33
from typing import Optional, Tuple
44

55
import torch
66

77
from .utils import broadcast
88

9-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
10-
'_scatter', [osp.dirname(__file__)]).origin)
9+
torch.ops.load_library(
10+
glob.glob(osp.join(osp.dirname(__file__), '_scatter.*'))[0])
1111

1212

1313
@torch.jit.script

torch_scatter/segment_coo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import importlib
1+
import glob
22
import os.path as osp
33
from typing import Optional, Tuple
44

55
import torch
66

7-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
8-
'_segment_coo', [osp.dirname(__file__)]).origin)
7+
torch.ops.load_library(
8+
glob.glob(osp.join(osp.dirname(__file__), '_segment_coo.*'))[0])
99

1010

1111
@torch.jit.script

torch_scatter/segment_csr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import importlib
1+
import glob
22
import os.path as osp
33
from typing import Optional, Tuple
44

55
import torch
66

7-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
8-
'_segment_csr', [osp.dirname(__file__)]).origin)
7+
torch.ops.load_library(
8+
glob.glob(osp.join(osp.dirname(__file__), '_segment_csr.*'))[0])
99

1010

1111
@torch.jit.script

0 commit comments

Comments
 (0)