Skip to content

Commit be37608

Browse files
committed
version check against PyTorch's CUDA version
1 parent 11da8e8 commit be37608

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def write_version_file():
5454
with open(version_path, 'w') as f:
5555
f.write("__version__ = '{}'\n".format(version))
5656
f.write("git_version = {}\n".format(repr(sha)))
57+
f.write("from torchvision import _C\n")
58+
f.write("if hasattr(_C, 'CUDA_VERSION'):\n")
59+
f.write(" cuda = _C.CUDA_VERSION\n")
5760

5861

5962
write_version_file()

torchvision/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,31 @@ def get_image_backend():
3333
Gets the name of the package used to load images
3434
"""
3535
return _image_backend
36+
37+
38+
def _check_cuda_matches():
39+
"""
40+
Make sure that CUDA versions match between the pytorch install and torchvision install
41+
"""
42+
import torch
43+
from torchvision import _C
44+
if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None:
45+
tv_version = str(_C.CUDA_VERSION)
46+
if int(tv_version) < 10000:
47+
tv_major = int(tv_version[0])
48+
tv_minor = int(tv_version[2])
49+
else:
50+
tv_major = int(tv_version[0:2])
51+
tv_minor = int(tv_version[3])
52+
t_version = torch.version.cuda
53+
t_version = t_version.split('.')
54+
t_major = int(t_version[0])
55+
t_minor = int(t_version[1])
56+
if t_major != tv_major or t_minor != tv_minor:
57+
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
58+
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
59+
"Please reinstall the torchvision that matches your PyTorch install."
60+
.format(t_major, t_minor, tv_major, tv_minor))
61+
62+
63+
_check_cuda_matches()

torchvision/csrc/vision.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
#include "ROIPool.h"
33
#include "nms.h"
44

5+
#ifdef WITH_CUDA
6+
#include <cuda.h>
7+
#endif
8+
59
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
610
m.def("nms", &nms, "non-maximum suppression");
711
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
812
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
913
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
1014
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
15+
#ifdef WITH_CUDA
16+
m.attr("CUDA_VERSION") = CUDA_VERSION;
17+
#endif
1118
}

0 commit comments

Comments
 (0)