 #include "hip/vision_cuda.h"
 #endif
 
-// Interface for Python
-at::Tensor ROIAlign_forward(
+// TODO: put this stuff in torchvision namespace
+
+at::Tensor roi_align(
     const at::Tensor& input, // Input feature map.
     const at::Tensor& rois, // List of ROIs to pool over.
     const double spatial_scale, // The scale of the image features. ROIs will be
@@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward(
     const bool aligned) // The flag for pixel shift
 // along each axis.
 {
-  if (input.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_forward_cuda(
-        input,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_forward_cpu(
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::roi_align", "")
+                       .typed<decltype(roi_align)>();
+  return op.call(
       input,
       rois,
       spatial_scale,
@@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward(
       aligned);
 }
 
-at::Tensor ROIAlign_backward(
+at::Tensor _roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width,
-    const int sampling_ratio,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
     const bool aligned) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_backward_cuda(
-        grad,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_backward_cpu(
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
+          .typed<decltype(_roi_align_backward)>();
+  return op.call(
       grad,
       rois,
       spatial_scale,
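
Both wrappers above resolve a schema by name, so somewhere outside this diff the `torchvision::roi_align` and `torchvision::_roi_align_backward` schemas must be registered with the dispatcher along with per-backend kernels. A minimal sketch of what that registration could look like, assuming the `TORCH_LIBRARY` macros from `torch/library.h` and the pre-existing `ROIAlign_forward_cpu` / `ROIAlign_forward_cuda` kernels (the placement and exact schema string are assumptions, not taken from this diff):

#include <torch/library.h>

// Hypothetical registration, e.g. in vision.cpp. m.def() declares the
// schema that findSchemaOrThrow() resolves; m.impl() binds one kernel
// per dispatch key, replacing the hand-written is_cuda() branching.
TORCH_LIBRARY(torchvision, m) {
  m.def(
      "roi_align(Tensor input, Tensor rois, float spatial_scale, "
      "int pooled_height, int pooled_width, int sampling_ratio, "
      "bool aligned) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("roi_align", ROIAlign_forward_cpu); // CPU kernel
}

TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  m.impl("roi_align", ROIAlign_forward_cuda); // CUDA kernel
}

With this in place, device routing happens inside the dispatcher, which is why the `#if defined(WITH_CUDA)` blocks and the `AT_ERROR("Not compiled with GPU support")` fallback can be deleted from the wrappers.
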
@@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     ctx->saved_data["aligned"] = aligned;
     ctx->saved_data["input_shape"] = input.sizes();
     ctx->save_for_backward({rois});
-    auto result = ROIAlign_forward(
+    at::AutoNonVariableTypeMode g;
+    auto result = roi_align(
         input,
         rois,
         spatial_scale,
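
The `at::AutoNonVariableTypeMode` guard added here is what prevents infinite recursion: `ROIAlignFunction` is evidently meant to be registered as the autograd kernel for `torchvision::roi_align`, so the nested `roi_align` call (which re-enters the dispatcher via `op.call`) must dispatch below the autograd key and land on the backend kernel. A tiny sketch of the pattern, with a hypothetical function name:

#include <ATen/ATen.h>

// Sketch of the guard's role. In the code above the nested call goes back
// through the dispatcher (roi_align -> op.call), and without the guard the
// Autograd key would route it straight back into ROIAlignFunction.
at::Tensor call_below_autograd(const at::Tensor& input) {
  at::AutoNonVariableTypeMode guard; // mask the autograd dispatch key
  // Any dispatcher call made while `guard` is in scope resolves to the
  // backend (CPU/CUDA) kernel rather than the autograd kernel.
  return input.relu(); // example of a dispatcher-routed op
}

(`at::AutoNonVariableTypeMode` was later deprecated in favor of `at::AutoDispatchBelowAutograd`.)
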
@@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = ROIAlign_backward(
+    auto grad_in = _roi_align_backward(
         grad_output[0],
         rois,
         ctx->saved_data["spatial_scale"].toDouble(),
@@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
   }
 };
 
-at::Tensor roi_align(
+// TODO: There should be an easier way to do this
+class ROIAlignBackwardFunction
+    : public torch::autograd::Function<ROIAlignBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t batch_size,
+      const int64_t channels,
+      const int64_t height,
+      const int64_t width,
+      const int64_t sampling_ratio,
+      const bool aligned) {
+    at::AutoNonVariableTypeMode g;
+    auto result = _roi_align_backward(
+        grad,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width,
+        sampling_ratio,
+        aligned);
+    return {result};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_align not supported");
+  }
+};
+
+at::Tensor ROIAlign_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
     const double spatial_scale,
@@ -164,3 +181,29 @@ at::Tensor roi_align(
       sampling_ratio,
       aligned)[0];
 }
+
+at::Tensor ROIAlign_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
+    const bool aligned) {
+  return ROIAlignBackwardFunction::apply(
+      grad,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width,
+      sampling_ratio,
+      aligned)[0];
+}
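
`ROIAlign_autograd` and `ROIAlign_backward_autograd` are plain functions wrapping the two `torch::autograd::Function`s, which is the shape the dispatcher expects for kernels registered under the Autograd key. A sketch of that wiring (assumed; the registration itself is not part of this diff):

#include <torch/library.h>

// Hypothetical Autograd-key registration: for inputs that require grad,
// the dispatcher routes torchvision::roi_align to ROIAlign_autograd,
// which records the backward graph and then re-dispatches (under the
// AutoNonVariableTypeMode guard) to the backend kernel.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("roi_align", ROIAlign_autograd);
  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
}

Registering `_roi_align_backward` through `ROIAlignBackwardFunction` keeps first-order gradients working while making double backward fail loudly via the `TORCH_CHECK` above instead of silently returning wrong results. Once registered, the op is also callable from Python as `torch.ops.torchvision.roi_align(...)`.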