File tree Expand file tree Collapse file tree 3 files changed +38
-0
lines changed Expand file tree Collapse file tree 3 files changed +38
-0
lines changed Original file line number Diff line number Diff line change @@ -54,6 +54,9 @@ def write_version_file():
54
54
with open (version_path , 'w' ) as f :
55
55
f .write ("__version__ = '{}'\n " .format (version ))
56
56
f .write ("git_version = {}\n " .format (repr (sha )))
57
+ f .write ("from torchvision import _C\n " )
58
+ f .write ("if hasattr(_C, 'CUDA_VERSION'):\n " )
59
+ f .write (" cuda = _C.CUDA_VERSION\n " )
57
60
58
61
59
62
write_version_file ()
Original file line number Diff line number Diff line change @@ -33,3 +33,31 @@ def get_image_backend():
33
33
Gets the name of the package used to load images
34
34
"""
35
35
return _image_backend
36
+
37
+
38
+ def _check_cuda_matches ():
39
+ """
40
+ Make sure that CUDA versions match between the pytorch install and torchvision install
41
+ """
42
+ import torch
43
+ from torchvision import _C
44
+ if hasattr (_C , "CUDA_VERSION" ) and torch .version .cuda is not None :
45
+ tv_version = str (_C .CUDA_VERSION )
46
+ if int (tv_version ) < 10000 :
47
+ tv_major = int (tv_version [0 ])
48
+ tv_minor = int (tv_version [2 ])
49
+ else :
50
+ tv_major = int (tv_version [0 :2 ])
51
+ tv_minor = int (tv_version [3 ])
52
+ t_version = torch .version .cuda
53
+ t_version = t_version .split ('.' )
54
+ t_major = int (t_version [0 ])
55
+ t_minor = int (t_version [1 ])
56
+ if t_major != tv_major or t_minor != tv_minor :
57
+ raise RuntimeError ("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
58
+ "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
59
+ "Please reinstall the torchvision that matches your PyTorch install."
60
+ .format (t_major , t_minor , tv_major , tv_minor ))
61
+
62
+
63
+ _check_cuda_matches ()
Original file line number Diff line number Diff line change 2
2
#include " ROIPool.h"
3
3
#include " nms.h"
4
4
5
+ #ifdef WITH_CUDA
6
+ #include < cuda.h>
7
+ #endif
8
+
5
9
PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
6
10
m.def (" nms" , &nms, " non-maximum suppression" );
7
11
m.def (" roi_align_forward" , &ROIAlign_forward, " ROIAlign_forward" );
8
12
m.def (" roi_align_backward" , &ROIAlign_backward, " ROIAlign_backward" );
9
13
m.def (" roi_pool_forward" , &ROIPool_forward, " ROIPool_forward" );
10
14
m.def (" roi_pool_backward" , &ROIPool_backward, " ROIPool_backward" );
15
+ #ifdef WITH_CUDA
16
+ m.attr (" CUDA_VERSION" ) = CUDA_VERSION;
17
+ #endif
11
18
}
You can’t perform that action at this time.
0 commit comments