@@ -141,9 +141,13 @@ void ROIAlignForward(
141141 T roi_end_w = offset_rois[3 ] * spatial_scale - offset;
142142 T roi_end_h = offset_rois[4 ] * spatial_scale - offset;
143143
144- // Force malformed ROIs to be 1x1
145- T roi_width = std::max (roi_end_w - roi_start_w, (T)1 .);
146- T roi_height = std::max (roi_end_h - roi_start_h, (T)1 .);
144+ T roi_width = roi_end_w - roi_start_w;
145+ T roi_height = roi_end_h - roi_start_h;
146+ if (!aligned) {
147+ // Force malformed ROIs to be 1x1
148+ roi_width = std::max (roi_width, (T)1 .);
149+ roi_height = std::max (roi_height, (T)1 .);
150+ }
147151
148152 T bin_size_h = static_cast <T>(roi_height) / static_cast <T>(pooled_height);
149153 T bin_size_w = static_cast <T>(roi_width) / static_cast <T>(pooled_width);
@@ -309,9 +313,13 @@ void ROIAlignBackward(
309313 T roi_end_w = offset_rois[3 ] * spatial_scale - offset;
310314 T roi_end_h = offset_rois[4 ] * spatial_scale - offset;
311315
312- // Force malformed ROIs to be 1x1
313- T roi_width = std::max (roi_end_w - roi_start_w, (T)1 .);
314- T roi_height = std::max (roi_end_h - roi_start_h, (T)1 .);
316+ T roi_width = roi_end_w - roi_start_w;
317+ T roi_height = roi_end_h - roi_start_h;
318+ if (!aligned) {
319+ // Force malformed ROIs to be 1x1
320+ roi_width = std::max (roi_width, (T)1 .);
321+ roi_height = std::max (roi_height, (T)1 .);
322+ }
315323
316324 T bin_size_h = static_cast <T>(roi_height) / static_cast <T>(pooled_height);
317325 T bin_size_w = static_cast <T>(roi_width) / static_cast <T>(pooled_width);
0 commit comments