|
1 | 1 | # flake8: noqa |
2 | 2 |
|
3 | | -import glob |
| 3 | +import importlib |
4 | 4 | import os.path as osp |
5 | 5 |
|
6 | 6 | import torch |
|
9 | 9 | expected_torch_version = (1, 4) |
10 | 10 |
|
11 | 11 | try: |
12 | | - torch.ops.load_library( |
13 | | - glob.glob(osp.join(osp.dirname(__file__), '_version.*'))[0]) |
| 12 | + torch.ops.load_library(importlib.machinery.PathFinder().find_spec( |
| 13 | + '_version', [osp.dirname(__file__)]).origin) |
14 | 14 | except OSError as e: |
15 | 15 | if 'undefined symbol' in str(e): |
16 | 16 | major, minor = [int(x) for x in torch.__version__.split('.')[:2]] |
17 | 17 | t_major, t_minor = expected_torch_version |
18 | 18 | if major != t_major or (major == t_major and minor != t_minor): |
19 | 19 | raise RuntimeError( |
20 | | - 'Expected PyTorch version {}.{} but found version ' |
21 | | - '{}.{}.'.format(t_major, t_minor, major, minor)) |
| 20 | + f'Expected PyTorch version {t_major}.{t_minor} but found ' |
| 21 | + f'version {major}.{minor}.') |
22 | 22 | raise OSError(e) |
23 | 23 |
|
24 | 24 | from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min, |
|
43 | 43 |
|
44 | 44 | if t_major != major or t_minor != minor: |
45 | 45 | raise RuntimeError( |
46 | | - 'Detected that PyTorch and torch_scatter were compiled with ' |
47 | | - 'different CUDA versions. PyTorch has CUDA version={}.{} and ' |
48 | | - 'torch_scatter has CUDA version={}.{}. Please reinstall the ' |
49 | | - 'torch_scatter that matches your PyTorch install.'.format( |
50 | | - t_major, t_minor, major, minor)) |
| 46 | + f'Detected that PyTorch and torch_scatter were compiled with ' |
| 47 | + f'different CUDA versions. PyTorch has CUDA version ' |
| 48 | + f'{t_major}.{t_minor} and torch_scatter has CUDA version ' |
| 49 | + f'{major}.{minor}. Please reinstall the torch_scatter that ' |
| 50 | + f'matches your PyTorch install.') |
51 | 51 |
|
52 | 52 | __all__ = [ |
53 | 53 | 'scatter_sum', |
|
0 commit comments