File tree Expand file tree Collapse file tree 3 files changed +38
-0
lines changed Expand file tree Collapse file tree 3 files changed +38
-0
lines changed Original file line number Diff line number Diff line change @@ -54,6 +54,9 @@ def write_version_file():
5454 with open (version_path , 'w' ) as f :
5555 f .write ("__version__ = '{}'\n " .format (version ))
5656 f .write ("git_version = {}\n " .format (repr (sha )))
57+ f .write ("from torchvision import _C\n " )
58+ f .write ("if hasattr(_C, 'CUDA_VERSION'):\n " )
59+ f .write (" cuda = _C.CUDA_VERSION\n " )
5760
5861
5962write_version_file ()
Original file line number Diff line number Diff line change @@ -33,3 +33,31 @@ def get_image_backend():
3333 Gets the name of the package used to load images
3434 """
3535 return _image_backend
36+
37+
38+ def _check_cuda_matches ():
39+ """
40+ Make sure that CUDA versions match between the pytorch install and torchvision install
41+ """
42+ import torch
43+ from torchvision import _C
44+ if hasattr (_C , "CUDA_VERSION" ) and torch .version .cuda is not None :
45+ tv_version = str (_C .CUDA_VERSION )
46+ if int (tv_version ) < 10000 :
47+ tv_major = int (tv_version [0 ])
48+ tv_minor = int (tv_version [2 ])
49+ else :
50+ tv_major = int (tv_version [0 :2 ])
51+ tv_minor = int (tv_version [3 ])
52+ t_version = torch .version .cuda
53+ t_version = t_version .split ('.' )
54+ t_major = int (t_version [0 ])
55+ t_minor = int (t_version [1 ])
56+ if t_major != tv_major or t_minor != tv_minor :
57+ raise RuntimeError ("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
58+ "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
59+ "Please reinstall the torchvision that matches your PyTorch install."
60+ .format (t_major , t_minor , tv_major , tv_minor ))
61+
62+
63+ _check_cuda_matches ()
Original file line number Diff line number Diff line change 22#include " ROIPool.h"
33#include " nms.h"
44
5+ #ifdef WITH_CUDA
6+ #include < cuda.h>
7+ #endif
8+
59PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
610 m.def (" nms" , &nms, " non-maximum suppression" );
711 m.def (" roi_align_forward" , &ROIAlign_forward, " ROIAlign_forward" );
812 m.def (" roi_align_backward" , &ROIAlign_backward, " ROIAlign_backward" );
913 m.def (" roi_pool_forward" , &ROIPool_forward, " ROIPool_forward" );
1014 m.def (" roi_pool_backward" , &ROIPool_backward, " ROIPool_backward" );
15+ #ifdef WITH_CUDA
16+ m.attr (" CUDA_VERSION" ) = CUDA_VERSION;
17+ #endif
1118}
You can’t perform that action at this time.
0 commit comments