1+ #include < cudaTypedefs.h>
2+
13#include < c10/cuda/CUDAGuard.h>
2- #include < cuda_runtime.h>
34#include < torch/extension.h>
45
56void cutlass_scaled_mm_dq_sm75 (torch::Tensor& c, torch::Tensor const & a,
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
1718 torch::Tensor const & a_scales,
1819 torch::Tensor const & b_scales);
1920
// Forward declaration of the sm90 scaled-mm entry point. It is only declared
// when CUDA_VERSION is defined and is at least 12000; the dispatcher in this
// file applies the same preprocessor condition before calling it and calls
// cutlass_scaled_mm_dq_sm80 instead when the condition does not hold.
21+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
2022void cutlass_scaled_mm_dq_sm90 (torch::Tensor& c, torch::Tensor const & a,
2123 torch::Tensor const & b,
2224 torch::Tensor const & a_scales,
2325 torch::Tensor const & b_scales);
26+ #endif
2427
2528void cutlass_scaled_mm_dq (torch::Tensor& c, torch::Tensor const & a,
2629 torch::Tensor const & b, torch::Tensor const & a_scales,
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
5154
5255 if (version_num >= 90 ) {
5356 // Hopper
57+
58+ // The sm90 kernel is only declared when CUDA_VERSION >= 12000; on older toolkits, fall back to the sm80 kernel.
59+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
5460 cutlass_scaled_mm_dq_sm90 (c, a, b, a_scales, b_scales);
61+ #else
62+ cutlass_scaled_mm_dq_sm80 (c, a, b, a_scales, b_scales);
63+ #endif
5564 } else if (version_num == 89 ) {
5665 // Ada Lovelace
5766 cutlass_scaled_mm_dq_sm89 (c, a, b, a_scales, b_scales);
0 commit comments