Skip to content

Commit 0c2b0d7

Browse files
authored
Add torch2.6 support for ms_deform_attn_cuda (IDEA-Research#94)
1 parent 5cf6b2e commit 0c2b0d7

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

grounding_dino/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,24 @@
1515
#include <ATen/cuda/CUDAContext.h>
1616
#include <cuda.h>
1717
#include <cuda_runtime.h>
18+
#include <torch/extension.h>
19+
#include <torch/version.h>
20+
21+
// Check PyTorch version and define appropriate macros
22+
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
23+
// PyTorch 2.x and above
24+
#define GET_TENSOR_TYPE(x) x.scalar_type()
25+
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
26+
#else
27+
// PyTorch 1.x
28+
#define GET_TENSOR_TYPE(x) x.type()
29+
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
30+
#endif
1831

1932
namespace groundingdino {
2033

2134
at::Tensor ms_deform_attn_cuda_forward(
22-
const at::Tensor &value,
35+
const at::Tensor &value,
2336
const at::Tensor &spatial_shapes,
2437
const at::Tensor &level_start_index,
2538
const at::Tensor &sampling_loc,
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
3245
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
3346
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
3447

35-
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
36-
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
37-
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
38-
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
39-
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
48+
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
49+
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
50+
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
51+
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
52+
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
4053

4154
const int batch = value.size(0);
4255
const int spatial_size = value.size(1);
@@ -51,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
5164
const int im2col_step_ = std::min(batch, im2col_step);
5265

5366
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
54-
67+
5568
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
5669

5770
const int batch_n = im2col_step_;
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
6275
for (int n = 0; n < batch/im2col_step_; ++n)
6376
{
6477
auto columns = output_n.select(0, n);
65-
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
78+
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
6679
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
6780
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
6881
spatial_shapes.data<int64_t>(),
@@ -82,7 +95,7 @@ at::Tensor ms_deform_attn_cuda_forward(
8295

8396

8497
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
85-
const at::Tensor &value,
98+
const at::Tensor &value,
8699
const at::Tensor &spatial_shapes,
87100
const at::Tensor &level_start_index,
88101
const at::Tensor &sampling_loc,
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
98111
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
99112
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
100113

101-
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
102-
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
103-
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
104-
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
105-
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
106-
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
114+
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
115+
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
116+
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
117+
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
118+
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
119+
AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
107120

108121
const int batch = value.size(0);
109122
const int spatial_size = value.size(1);
@@ -128,11 +141,11 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
128141
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
129142
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
130143
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
131-
144+
132145
for (int n = 0; n < batch/im2col_step_; ++n)
133146
{
134147
auto grad_output_g = grad_output_n.select(0, n);
135-
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
148+
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
136149
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
137150
grad_output_g.data<scalar_t>(),
138151
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
@@ -153,4 +166,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
153166
};
154167
}
155168

156-
} // namespace groundingdino
169+
} // namespace groundingdino

0 commit comments

Comments (0)