
Commit 986d242

neginraoof and fmassa authored
ONNX export for variable input sizes (#1840)
* fixes and tests for variable input size
* transform test fix
* Fix comment
* Dynamic shape for keypoint_rcnn
* Update test_onnx.py
* Update rpn.py
* Fix for split on RPN
* Fixes for feedbacks
* flake8
* topk fix
* Fix build
* branch on tracing
* fix for scalar tensor
* Fixes for script type annotations
* Update rpn.py
* clean up
* clean up
* Update rpn.py
* Updated for feedback
* Fix for comments
* revert to use tensor
* Added test for box clip
* Fixes for feedback
* Fix for feedback
* ORT version revert
* Update ort
* Update .travis.yml
* Update test_onnx.py
* Update test_onnx.py
* Tensor sizes
* Fix for dynamic split
* Try disable tests
* pytest verbose
* revert one test
* enable tests
* Update .travis.yml
* Update .travis.yml
* Update .travis.yml
* Update test_onnx.py
* Update .travis.yml
* Passing device
* Fixes for test
* Fix for boxes datatype
* clean up

Co-authored-by: Francisco Massa <[email protected]>
1 parent 504d20c commit 986d242


5 files changed: 108 additions, 68 deletions


test/test_onnx.py

Lines changed: 81 additions & 43 deletions
@@ -28,14 +28,15 @@ class ONNXExporterTester(unittest.TestCase):
     def setUpClass(cls):
         torch.manual_seed(123)
 
-    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
+    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
+                  output_names=None, input_names=None):
         model.eval()
 
         onnx_io = io.BytesIO()
         # export to onnx with the first input
         torch.onnx.export(model, inputs_list[0], onnx_io,
-                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)
-
+                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
+                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
         # validate the exported model with onnx runtime
         for test_inputs in inputs_list:
             with torch.no_grad():

@@ -99,6 +100,21 @@ def forward(self, boxes, scores):
 
         self.run_model(Module(), [(boxes, scores)])
 
+    def test_clip_boxes_to_image(self):
+        boxes = torch.randn(5, 4) * 500
+        boxes[:, 2:] += boxes[:, :2]
+        size = torch.randn(200, 300)
+
+        size_2 = torch.randn(300, 400)
+
+        class Module(torch.nn.Module):
+            def forward(self, boxes, size):
+                return ops.boxes.clip_boxes_to_image(boxes, size.shape)
+
+        self.run_model(Module(), [(boxes, size), (boxes, size_2)],
+                       input_names=["boxes", "size"],
+                       dynamic_axes={"size": [0, 1]})
+
     def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
         single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)

@@ -123,9 +139,9 @@ def __init__(self_module):
             def forward(self_module, images):
                 return self_module.transform(images)[0].tensors
 
-        input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        self.run_model(TransformModule(), [input, input_test])
+        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        self.run_model(TransformModule(), [(input,), (input_test,)])
 
     def _init_test_generalized_rcnn_transform(self):
         min_size = 100

@@ -207,22 +223,28 @@ def get_features(self, images):
 
     def test_rpn(self):
         class RPNModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
                 super(RPNModule, self_module).__init__()
                 self_module.rpn = self._init_test_rpn()
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
 
-            def forward(self_module, features):
-                return self_module.rpn(self_module.images, features)
+            def forward(self_module, images, features):
+                images = ImageList(images, [i.shape[-2:] for i in images])
+                return self_module.rpn(images, features)
 
-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 150, 150)
         features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 80, 80)
+        test_features = self.get_features(images2)
 
-        model = RPNModule(images)
+        model = RPNModule()
         model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
+        model(images, features)
+
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
+                                     "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
+                                     "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
 
     def test_multi_scale_roi_align(self):
 

@@ -251,63 +273,73 @@ def forward(self, input, boxes):
 
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
                 super(RoiHeadsModule, self_module).__init__()
                 self_module.transform = self._init_test_generalized_rcnn_transform()
                 self_module.rpn = self._init_test_rpn()
                 self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
-                self_module.original_image_sizes = [img.shape[-2:] for img in images]
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])
 
-            def forward(self_module, features):
-                proposals, _ = self_module.rpn(self_module.images, features)
-                detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
+            def forward(self_module, images, features):
+                original_image_sizes = [img.shape[-2:] for img in images]
+                images = ImageList(images, [i.shape[-2:] for i in images])
+                proposals, _ = self_module.rpn(images, features)
+                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
                 detections = self_module.transform.postprocess(detections,
-                                                               self_module.images.image_sizes,
-                                                               self_module.original_image_sizes)
+                                                               images.image_sizes,
+                                                               original_image_sizes)
                 return detections
 
-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 100, 100)
         features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 150, 150)
+        test_features = self.get_features(images2)
 
-        model = RoiHeadsModule(images)
+        model = RoiHeadsModule()
         model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)])
+        model(images, features)
 
-    def get_image_from_url(self, url):
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
+                                     "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})
+
+    def get_image_from_url(self, url, size=None):
         import requests
-        import numpy
         from PIL import Image
         from io import BytesIO
         from torchvision import transforms
 
         data = requests.get(url)
         image = Image.open(BytesIO(data.content)).convert("RGB")
-        image = image.resize((300, 200), Image.BILINEAR)
+
+        if size is None:
+            size = (300, 200)
+        image = image.resize(size, Image.BILINEAR)
 
         to_tensor = transforms.ToTensor()
         return to_tensor(image)
 
     def get_test_images(self):
         image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
-        image = self.get_image_from_url(url=image_url)
+        image = self.get_image_from_url(url=image_url, size=(200, 300))
+
         image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
-        image2 = self.get_image_from_url(url=image_url2)
+        image2 = self.get_image_from_url(url=image_url2, size=(250, 200))
+
         images = [image]
         test_images = [image2]
         return images, test_images
 
     def test_faster_rcnn(self):
         images, test_images = self.get_test_images()
 
-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
-                                                                     min_size=200,
-                                                                     max_size=300)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)
 
     # Verify that paste_mask_in_image beahves the same in tracing.
     # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image

@@ -350,7 +382,11 @@ def test_mask_rcnn(self):
         model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)],
+                       input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)
 
     # Verify that heatmaps_to_keypoints behaves the same in tracing.
     # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints

@@ -385,9 +421,7 @@ def test_keypoint_rcnn(self):
         class KeyPointRCNN(torch.nn.Module):
             def __init__(self):
                 super(KeyPointRCNN, self).__init__()
-                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
-                                                                                      min_size=200,
-                                                                                      max_size=300)
+                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
 
             def forward(self, images):
                 output = self.model(images)

@@ -399,8 +433,12 @@ def forward(self, images):
         images, test_images = self.get_test_images()
         model = KeyPointRCNN()
         model.eval()
-        model(test_images)
-        self.run_model(model, [(images,), (test_images,)])
+        model(images)
+        self.run_model(model, [(images,), (test_images,)],
+                       input_names=["images_tensors"],
+                       output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)
 
 
 if __name__ == '__main__':
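
Note on the helper change: the new dynamic_axes, input_names and output_names arguments of run_model are passed straight through to torch.onnx.export, which is what lets the same exported graph be re-run on inputs of a different spatial size. The following is a minimal sketch of that export-and-validate pattern outside the test suite; the TinyModel module, the axis names, and opset 11 are illustrative assumptions, not part of this commit.

# Illustrative sketch only: export once with dynamic axes, then validate
# against two inputs of different spatial size with ONNX Runtime.
import io

import numpy as np
import onnxruntime
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")


model = TinyModel().eval()
first_input = torch.rand(1, 3, 100, 200)
second_input = torch.rand(1, 3, 80, 150)  # different height/width than the export input

onnx_io = io.BytesIO()
# Marking every image dimension as dynamic lets the same graph accept the second input.
torch.onnx.export(model, (first_input,), onnx_io, opset_version=11,
                  input_names=["images_tensors"], output_names=["outputs"],
                  dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]})

session = onnxruntime.InferenceSession(onnx_io.getvalue())
for inp in (first_input, second_input):
    with torch.no_grad():
        expected = model(inp)
    (got,) = session.run(None, {"images_tensors": inp.numpy()})
    np.testing.assert_allclose(expected.numpy(), got, rtol=1e-3, atol=1e-5)
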

torchvision/models/detection/roi_heads.py

Lines changed: 3 additions & 10 deletions
@@ -678,20 +678,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
         device = class_logits.device
         num_classes = class_logits.shape[-1]
 
-        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
         pred_boxes = self.box_coder.decode(box_regression, proposals)
 
         pred_scores = F.softmax(class_logits, -1)
 
-        # split boxes and scores per image
-        if len(boxes_per_image) == 1:
-            # TODO : remove this when ONNX support dynamic split sizes
-            # and just assign to pred_boxes instead of pred_boxes_list
-            pred_boxes_list = [pred_boxes]
-            pred_scores_list = [pred_scores]
-        else:
-            pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
-            pred_scores_list = pred_scores.split(boxes_per_image, 0)
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)
 
         all_boxes = []
         all_scores = []
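
Note: Tensor.split accepts a Python list of per-image box counts, so once boxes_per_image is built from boxes_in_image.shape[0] the single-image special case can be dropped. A small hedged sketch (the shapes and counts are made up, not from the diff):

import torch

pred_boxes = torch.rand(300, 91, 4)           # hypothetical (proposals, classes, 4) layout
boxes_per_image = [300]                       # single image
single = pred_boxes.split(boxes_per_image, 0)
assert len(single) == 1 and single[0].shape[0] == 300

boxes_per_image = [120, 180]                  # two images
per_image = pred_boxes.split(boxes_per_image, 0)
assert [b.shape[0] for b in per_image] == [120, 180]
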

torchvision/models/detection/rpn.py

Lines changed: 7 additions & 10 deletions
@@ -114,7 +114,7 @@ def num_anchors_per_location(self):
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
     def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
+        # type: (List[List[int]], List[List[Tensor]])
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None

@@ -124,10 +124,6 @@ def grid_anchors(self, grid_sizes, strides):
         ):
             grid_height, grid_width = size
             stride_height, stride_width = stride
-            if torchvision._is_tracing():
-                # required in ONNX export for mult operation with float32
-                stride_width = torch.tensor(stride_width, dtype=torch.float32)
-                stride_height = torch.tensor(stride_height, dtype=torch.float32)
             device = base_anchors.device
 
             # For output anchor, compute [x_center, y_center, x_center, y_center]

@@ -151,8 +147,8 @@ def grid_anchors(self, grid_sizes, strides):
         return anchors
 
     def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
-        key = str(grid_sizes + strides)
+        # type: (List[List[int]], List[List[Tensor]])
+        key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
         anchors = self.grid_anchors(grid_sizes, strides)

@@ -163,9 +159,9 @@ def forward(self, image_list, feature_maps):
         # type: (ImageList, List[Tensor])
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
-
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
+                    torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])

@@ -480,7 +476,8 @@ def forward(self, images, features, targets=None):
         anchors = self.anchor_generator(images, features)
 
         num_images = len(anchors)
-        num_anchors_per_level = [o[0].numel() for o in objectness]
+        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
+        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
         objectness, pred_bbox_deltas = \
             concat_box_prediction_layers(objectness, pred_bbox_deltas)
         # apply pred_bbox_deltas to anchors to obtain the decoded proposals
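
Note: for an objectness map of shape (N, A, H, W), the product of the first three entries of o[0].shape equals o[0].numel(), so the change keeps the same per-level anchor counts while deriving them from shapes. A quick hedged check with made-up feature-map sizes (not from the diff):

import torch

objectness_per_level = [torch.rand(4, 3, 25, 25), torch.rand(4, 3, 13, 13)]  # illustrative levels
shape_tensors = [o[0].shape for o in objectness_per_level]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in shape_tensors]
# Same counts as the numel-based form this diff replaces.
assert num_anchors_per_level == [o[0].numel() for o in objectness_per_level]
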

torchvision/models/detection/transform.py

Lines changed: 6 additions & 3 deletions
@@ -88,7 +88,8 @@ def resize(self, image, target):
         if max_size * scale_factor > self.max_size:
             scale_factor = self.max_size / max_size
         image = torch.nn.functional.interpolate(
-            image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
+            image[None], scale_factor=scale_factor, mode='bilinear',
+            align_corners=False)[0]
 
         if target is None:
             return image, target

@@ -191,7 +192,8 @@ def __repr__(self):
 
 def resize_keypoints(keypoints, original_size, new_size):
     # type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+              for s, s_orig in zip(new_size, original_size)]
     ratio_h, ratio_w = ratios
     resized_data = keypoints.clone()
     if torch._C._get_tracing_state():

@@ -206,7 +208,8 @@ def resize_keypoints(keypoints, original_size, new_size):
 
 def resize_boxes(boxes, original_size, new_size):
     # type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+              for s, s_orig in zip(new_size, original_size)]
     ratio_height, ratio_width = ratios
     xmin, ymin, xmax, ymax = boxes.unbind(1)
 
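
Note: the ratios are now built as float32 tensors on the boxes'/keypoints' device rather than Python floats, presumably so the scaling stays a tensor operation on the correct device when the model is traced for export. An illustrative sketch of resize_boxes-style scaling with tensor ratios (the box values and sizes below are made up):

import torch

boxes = torch.tensor([[10., 20., 50., 80.]])
original_size = [200, 300]   # (height, width)
new_size = [400, 600]

ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) /
          torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
          for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios

xmin, ymin, xmax, ymax = boxes.unbind(1)
scaled = torch.stack((xmin * ratio_width, ymin * ratio_height,
                      xmax * ratio_width, ymax * ratio_height), dim=1)
print(scaled)  # tensor([[ 20.,  40., 100., 160.]])
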

torchvision/ops/boxes.py

Lines changed: 11 additions & 2 deletions
@@ -1,6 +1,7 @@
 import torch
 from torch.jit.annotations import Tuple
 from torch import Tensor
+import torchvision
 
 
 def nms(boxes, scores, iou_threshold):

@@ -110,8 +111,16 @@ def clip_boxes_to_image(boxes, size):
     boxes_x = boxes[..., 0::2]
     boxes_y = boxes[..., 1::2]
     height, width = size
-    boxes_x = boxes_x.clamp(min=0, max=width)
-    boxes_y = boxes_y.clamp(min=0, max=height)
+
+    if torchvision._is_tracing():
+        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
+    else:
+        boxes_x = boxes_x.clamp(min=0, max=width)
+        boxes_y = boxes_y.clamp(min=0, max=height)
+
     clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
     return clipped_boxes.reshape(boxes.shape)
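
Note: during tracing, the clipping is expressed with torch.max/torch.min against scalar tensors instead of Tensor.clamp, which appears to be what keeps width and height usable as dynamic values in the exported graph. Outside of tracing the two forms give identical results; a small hedged check (box values below are made up):

import torch

boxes = torch.tensor([[-10., -5., 350., 250.]])
height, width = 200, 300

boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2]

# Eager clamp-based clipping.
clamped_x = boxes_x.clamp(min=0, max=width)
clamped_y = boxes_y.clamp(min=0, max=height)

# Tracing-friendly min/max form from the branch above.
traced_style_x = torch.min(torch.max(boxes_x, torch.tensor(0.)), torch.tensor(float(width)))
traced_style_y = torch.min(torch.max(boxes_y, torch.tensor(0.)), torch.tensor(float(height)))

assert torch.equal(clamped_x, traced_style_x)
assert torch.equal(clamped_y, traced_style_y)
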
