@@ -44,7 +44,7 @@ def get_extensions():
4444 if sys .platform == 'win32' :
4545 define_macros += [('torchscatter_EXPORTS' , None )]
4646
47- extra_compile_args = {'cxx' : ['-O2 ' ]}
47+ extra_compile_args = {'cxx' : ['-O3 ' ]}
4848 if not os .name == 'nt' : # Not on Windows:
4949 extra_compile_args ['cxx' ] += ['-Wno-sign-compare' ]
5050 extra_link_args = [] if WITH_SYMBOLS else ['-s' ]
@@ -69,14 +69,14 @@ def get_extensions():
6969 define_macros += [('WITH_CUDA' , None )]
7070 nvcc_flags = os .getenv ('NVCC_FLAGS' , '' )
7171 nvcc_flags = [] if nvcc_flags == '' else nvcc_flags .split (' ' )
72+ nvcc_flags += ['-O3' ]
7273 if torch .version .hip :
73- nvcc_flags += ['-O3' ]
74- # USE_ROCM was added to later versons of rocm pytorch
75- # define here to support older pytorch versions
74+ # USE_ROCM was added to later versions of PyTorch.
75+ # Define here to support older PyTorch versions as well:
7676 define_macros += [('USE_ROCM' , None )]
7777 undef_macros += ['__HIP_NO_HALF_CONVERSIONS__' ]
7878 else :
79- nvcc_flags += ['--expt-relaxed-constexpr' , '-O2' ]
79+ nvcc_flags += ['--expt-relaxed-constexpr' ]
8080 extra_compile_args ['nvcc' ] = nvcc_flags
8181
8282 name = main .split (os .sep )[- 1 ][:- 4 ]
0 commit comments