 #include "hip/vision_cuda.h"
 #endif
 
-// Interface for Python
-at::Tensor ROIAlign_forward(
+// TODO: put this stuff in torchvision namespace
+
+at::Tensor roi_align(
     const at::Tensor& input, // Input feature map.
     const at::Tensor& rois, // List of ROIs to pool over.
     const double spatial_scale, // The scale of the image features. ROIs will be
@@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward(
     const bool aligned) // The flag for pixel shift
                         // along each axis.
 {
-  if (input.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_forward_cuda(
-        input,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_forward_cpu(
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::roi_align", "")
+                       .typed<decltype(roi_align)>();
+  return op.call(
       input,
       rois,
       spatial_scale,
@@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward(
       aligned);
 }
 
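Context for reviewers (not shown in this hunk): findSchemaOrThrow("torchvision::roi_align", "") only resolves if the schema and backend kernels have been registered elsewhere in the library. A minimal sketch of what that registration could look like with the TORCH_LIBRARY macros, assuming the CPU kernels have been moved to the boxed-friendly double/int64_t signatures used above (in the schema language, "float" maps to double and "int" to int64_t):

#include <torch/library.h>

TORCH_LIBRARY(torchvision, m) {
  // Schemas only; kernels attach per dispatch key below.
  m.def(
      "roi_align(Tensor input, Tensor rois, float spatial_scale, "
      "int pooled_height, int pooled_width, int sampling_ratio, "
      "bool aligned) -> Tensor");
  m.def(
      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, "
      "int pooled_height, int pooled_width, int batch_size, int channels, "
      "int height, int width, int sampling_ratio, bool aligned) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("roi_align", ROIAlign_forward_cpu);
  m.impl("_roi_align_backward", ROIAlign_backward_cpu);
}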
-at::Tensor ROIAlign_backward(
+at::Tensor _roi_align_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width,
-    const int sampling_ratio,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
     const bool aligned) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return ROIAlign_backward_cuda(
-        grad,
-        rois,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width,
-        sampling_ratio,
-        aligned);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  return ROIAlign_backward_cpu(
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
+          .typed<decltype(_roi_align_backward)>();
+  return op.call(
       grad,
       rois,
       spatial_scale,
@@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     ctx->saved_data["aligned"] = aligned;
     ctx->saved_data["input_shape"] = input.sizes();
     ctx->save_for_backward({rois});
-    auto result = ROIAlign_forward(
+    at::AutoNonVariableTypeMode g;
+    auto result = roi_align(
         input,
         rois,
         spatial_scale,
@@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = ROIAlign_backward(
+    auto grad_in = _roi_align_backward(
         grad_output[0],
         rois,
         ctx->saved_data["spatial_scale"].toDouble(),
@@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
   }
 };
 
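A note on the at::AutoNonVariableTypeMode guard above: it keeps the nested roi_align call from dispatching back into this autograd wrapper, so the call goes straight to the backend kernel while ROIAlignFunction records the graph. A hypothetical usage sketch (tensor shapes and values invented for illustration) of how gradients then flow:

#include <torch/torch.h>

void roi_align_grad_example() {
  auto input = torch::rand({1, 3, 16, 16}, torch::requires_grad());
  // One ROI in (batch_index, x1, y1, x2, y2) format covering the whole map.
  auto rois = torch::tensor({{0.0f, 0.0f, 0.0f, 15.0f, 15.0f}});
  auto out = ROIAlignFunction::apply(
      input, rois, /*spatial_scale=*/1.0, /*pooled_height=*/4,
      /*pooled_width=*/4, /*sampling_ratio=*/2, /*aligned=*/true)[0];
  // backward() routes through ROIAlignFunction::backward and, from there,
  // the torchvision::_roi_align_backward op.
  out.sum().backward();
}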
-at::Tensor roi_align(
+// TODO: There should be an easier way to do this
+class ROIAlignBackwardFunction
+    : public torch::autograd::Function<ROIAlignBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::Variable grad,
+      torch::autograd::Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t batch_size,
+      const int64_t channels,
+      const int64_t height,
+      const int64_t width,
+      const int64_t sampling_ratio,
+      const bool aligned) {
+    at::AutoNonVariableTypeMode g;
+    auto result = _roi_align_backward(
+        grad,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width,
+        sampling_ratio,
+        aligned);
+    return {result};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output) {
+    TORCH_CHECK(0, "double backwards on roi_align not supported");
+  }
+};
+
+at::Tensor ROIAlign_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
     const double spatial_scale,
@@ -164,3 +181,29 @@ at::Tensor roi_align(
       sampling_ratio,
       aligned)[0];
 }
+
+at::Tensor ROIAlign_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width,
+    const int64_t sampling_ratio,
+    const bool aligned) {
+  return ROIAlignBackwardFunction::apply(
+      grad,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width,
+      sampling_ratio,
+      aligned)[0];
+}
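For completeness (again not part of this hunk): the two *_autograd wrappers presumably get wired to the Autograd dispatch key, so that calls to the torchvision ops pick up ROIAlignFunction when gradients are required. A sketch under that assumption:

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("roi_align", ROIAlign_autograd);
  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
}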