diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3433480..c13c2a2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,6 +16,26 @@ if(CMAKE_CUDA_COMPILER)
     set(CMAKE_CUDA_STANDARD 17)
     set(CMAKE_CUDA_STANDARD_REQUIRED ON)
 
+    # NVTX compatibility shim for building against PyTorch with CUDA 12+.
+    # Applied only if CUDA::nvToolsExt is not defined and CUDA::nvtx3 exists; the
+    # status line below reports whether the shim was actually created.
+    # NOTE(review): CUDA::nvtx3 must already be defined at this point (e.g. by
+    # find_package(CUDAToolkit)) -- confirm that happens earlier in this file.
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12
+       AND NOT TARGET CUDA::nvToolsExt
+       AND TARGET CUDA::nvtx3)
+        add_library(CUDA::nvToolsExt INTERFACE IMPORTED)
+        # The TORCH_CUDA_USE_NVTX3 compile definition tells PyTorch to use the NVTX3 headers
+        # instead of the legacy NVTX headers. This is necessary when building with CUDA 12+
+        # and the CUDA::nvtx3 target is available, but CUDA::nvToolsExt is not defined.
+        # Without this definition, PyTorch may fail to find the correct NVTX headers.
+        target_compile_definitions(CUDA::nvToolsExt INTERFACE TORCH_CUDA_USE_NVTX3)
+        target_link_libraries(CUDA::nvToolsExt INTERFACE CUDA::nvtx3)
+        message(STATUS "PyTorch NVTX headers workaround: Yes")
+    else()
+        message(STATUS "PyTorch NVTX headers workaround: No")
+    endif()
+
     message(STATUS "INSTALLING EXTENSIONS WITH CUDA!")
     string(REGEX REPLACE ".[0-9][0-9]|\\." "" CUDA_V ${CMAKE_CUDA_COMPILER_VERSION})
     message(STATUS "CMAKE_CUDA_COMPILER = ${CMAKE_CUDA_COMPILER}")
diff --git a/pyproject.toml b/pyproject.toml
index 4a80e71..11e0694 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ license-files = ["LICENSE"]
 exclude = ["**/.mypy_cache/**", "**/build/**", "**/.vscode/**"]
 
 [build-system]
-requires = ["scikit-build-core>=0.10", "pybind11>=2.10", "cmake", "ninja"]
+requires = ["scikit-build-core>=0.10", "pybind11>=2.10", "cmake", "ninja", "torch"]
 build-backend = "scikit_build_core.build"
 
 [tool.isort]