@@ -28,14 +28,15 @@ class ONNXExporterTester(unittest.TestCase):
28
28
def setUpClass (cls ):
29
29
torch .manual_seed (123 )
30
30
31
- def run_model (self , model , inputs_list , tolerate_small_mismatch = False , do_constant_folding = True ):
31
+ def run_model (self , model , inputs_list , tolerate_small_mismatch = False , do_constant_folding = True , dynamic_axes = None ,
32
+ output_names = None , input_names = None ):
32
33
model .eval ()
33
34
34
35
onnx_io = io .BytesIO ()
35
36
# export to onnx with the first input
36
37
torch .onnx .export (model , inputs_list [0 ], onnx_io ,
37
- do_constant_folding = do_constant_folding , opset_version = _onnx_opset_version )
38
-
38
+ do_constant_folding = do_constant_folding , opset_version = _onnx_opset_version ,
39
+ dynamic_axes = dynamic_axes , input_names = input_names , output_names = output_names )
39
40
# validate the exported model with onnx runtime
40
41
for test_inputs in inputs_list :
41
42
with torch .no_grad ():
@@ -99,6 +100,21 @@ def forward(self, boxes, scores):
99
100
100
101
self .run_model (Module (), [(boxes , scores )])
101
102
103
+ def test_clip_boxes_to_image (self ):
104
+ boxes = torch .randn (5 , 4 ) * 500
105
+ boxes [:, 2 :] += boxes [:, :2 ]
106
+ size = torch .randn (200 , 300 )
107
+
108
+ size_2 = torch .randn (300 , 400 )
109
+
110
+ class Module (torch .nn .Module ):
111
+ def forward (self , boxes , size ):
112
+ return ops .boxes .clip_boxes_to_image (boxes , size .shape )
113
+
114
+ self .run_model (Module (), [(boxes , size ), (boxes , size_2 )],
115
+ input_names = ["boxes" , "size" ],
116
+ dynamic_axes = {"size" : [0 , 1 ]})
117
+
102
118
def test_roi_align (self ):
103
119
x = torch .rand (1 , 1 , 10 , 10 , dtype = torch .float32 )
104
120
single_roi = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = torch .float32 )
@@ -123,9 +139,9 @@ def __init__(self_module):
123
139
def forward (self_module , images ):
124
140
return self_module .transform (images )[0 ].tensors
125
141
126
- input = [ torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )]
127
- input_test = [ torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )]
128
- self .run_model (TransformModule (), [input , input_test ])
142
+ input = torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )
143
+ input_test = torch .rand (3 , 100 , 200 ), torch .rand (3 , 200 , 200 )
144
+ self .run_model (TransformModule (), [( input ,), ( input_test ,) ])
129
145
130
146
def _init_test_generalized_rcnn_transform (self ):
131
147
min_size = 100
@@ -207,22 +223,28 @@ def get_features(self, images):
207
223
208
224
def test_rpn (self ):
209
225
class RPNModule (torch .nn .Module ):
210
- def __init__ (self_module , images ):
226
+ def __init__ (self_module ):
211
227
super (RPNModule , self_module ).__init__ ()
212
228
self_module .rpn = self ._init_test_rpn ()
213
- self_module .images = ImageList (images , [i .shape [- 2 :] for i in images ])
214
229
215
- def forward (self_module , features ):
216
- return self_module .rpn (self_module .images , features )
230
+ def forward (self_module , images , features ):
231
+ images = ImageList (images , [i .shape [- 2 :] for i in images ])
232
+ return self_module .rpn (images , features )
217
233
218
- images = torch .rand (2 , 3 , 600 , 600 )
234
+ images = torch .rand (2 , 3 , 150 , 150 )
219
235
features = self .get_features (images )
220
- test_features = self .get_features (images )
236
+ images2 = torch .rand (2 , 3 , 80 , 80 )
237
+ test_features = self .get_features (images2 )
221
238
222
- model = RPNModule (images )
239
+ model = RPNModule ()
223
240
model .eval ()
224
- model (features )
225
- self .run_model (model , [(features ,), (test_features ,)], tolerate_small_mismatch = True )
241
+ model (images , features )
242
+
243
+ self .run_model (model , [(images , features ), (images2 , test_features )], tolerate_small_mismatch = True ,
244
+ input_names = ["input1" , "input2" , "input3" , "input4" , "input5" , "input6" ],
245
+ dynamic_axes = {"input1" : [0 , 1 , 2 , 3 ], "input2" : [0 , 1 , 2 , 3 ],
246
+ "input3" : [0 , 1 , 2 , 3 ], "input4" : [0 , 1 , 2 , 3 ],
247
+ "input5" : [0 , 1 , 2 , 3 ], "input6" : [0 , 1 , 2 , 3 ]})
226
248
227
249
def test_multi_scale_roi_align (self ):
228
250
@@ -251,63 +273,73 @@ def forward(self, input, boxes):
251
273
252
274
def test_roi_heads (self ):
253
275
class RoiHeadsModule (torch .nn .Module ):
254
- def __init__ (self_module , images ):
276
+ def __init__ (self_module ):
255
277
super (RoiHeadsModule , self_module ).__init__ ()
256
278
self_module .transform = self ._init_test_generalized_rcnn_transform ()
257
279
self_module .rpn = self ._init_test_rpn ()
258
280
self_module .roi_heads = self ._init_test_roi_heads_faster_rcnn ()
259
- self_module .original_image_sizes = [img .shape [- 2 :] for img in images ]
260
- self_module .images = ImageList (images , [i .shape [- 2 :] for i in images ])
261
281
262
- def forward (self_module , features ):
263
- proposals , _ = self_module .rpn (self_module .images , features )
264
- detections , _ = self_module .roi_heads (features , proposals , self_module .images .image_sizes )
282
+ def forward (self_module , images , features ):
283
+ original_image_sizes = [img .shape [- 2 :] for img in images ]
284
+ images = ImageList (images , [i .shape [- 2 :] for i in images ])
285
+ proposals , _ = self_module .rpn (images , features )
286
+ detections , _ = self_module .roi_heads (features , proposals , images .image_sizes )
265
287
detections = self_module .transform .postprocess (detections ,
266
- self_module . images .image_sizes ,
267
- self_module . original_image_sizes )
288
+ images .image_sizes ,
289
+ original_image_sizes )
268
290
return detections
269
291
270
- images = torch .rand (2 , 3 , 600 , 600 )
292
+ images = torch .rand (2 , 3 , 100 , 100 )
271
293
features = self .get_features (images )
272
- test_features = self .get_features (images )
294
+ images2 = torch .rand (2 , 3 , 150 , 150 )
295
+ test_features = self .get_features (images2 )
273
296
274
- model = RoiHeadsModule (images )
297
+ model = RoiHeadsModule ()
275
298
model .eval ()
276
- model (features )
277
- self .run_model (model , [(features ,), (test_features ,)])
299
+ model (images , features )
278
300
279
- def get_image_from_url (self , url ):
301
+ self .run_model (model , [(images , features ), (images2 , test_features )], tolerate_small_mismatch = True ,
302
+ input_names = ["input1" , "input2" , "input3" , "input4" , "input5" , "input6" ],
303
+ dynamic_axes = {"input1" : [0 , 1 , 2 , 3 ], "input2" : [0 , 1 , 2 , 3 ], "input3" : [0 , 1 , 2 , 3 ],
304
+ "input4" : [0 , 1 , 2 , 3 ], "input5" : [0 , 1 , 2 , 3 ], "input6" : [0 , 1 , 2 , 3 ]})
305
+
306
+ def get_image_from_url (self , url , size = None ):
280
307
import requests
281
- import numpy
282
308
from PIL import Image
283
309
from io import BytesIO
284
310
from torchvision import transforms
285
311
286
312
data = requests .get (url )
287
313
image = Image .open (BytesIO (data .content )).convert ("RGB" )
288
- image = image .resize ((300 , 200 ), Image .BILINEAR )
314
+
315
+ if size is None :
316
+ size = (300 , 200 )
317
+ image = image .resize (size , Image .BILINEAR )
289
318
290
319
to_tensor = transforms .ToTensor ()
291
320
return to_tensor (image )
292
321
293
322
def get_test_images (self ):
294
323
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
295
- image = self .get_image_from_url (url = image_url )
324
+ image = self .get_image_from_url (url = image_url , size = (200 , 300 ))
325
+
296
326
image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
297
- image2 = self .get_image_from_url (url = image_url2 )
327
+ image2 = self .get_image_from_url (url = image_url2 , size = (250 , 200 ))
328
+
298
329
images = [image ]
299
330
test_images = [image2 ]
300
331
return images , test_images
301
332
302
333
def test_faster_rcnn (self ):
303
334
images , test_images = self .get_test_images ()
304
335
305
- model = models .detection .faster_rcnn .fasterrcnn_resnet50_fpn (pretrained = True ,
306
- min_size = 200 ,
307
- max_size = 300 )
336
+ model = models .detection .faster_rcnn .fasterrcnn_resnet50_fpn (pretrained = True , min_size = 200 , max_size = 300 )
308
337
model .eval ()
309
338
model (images )
310
- self .run_model (model , [(images ,), (test_images ,)])
339
+ self .run_model (model , [(images ,), (test_images ,)], input_names = ["images_tensors" ],
340
+ output_names = ["outputs" ],
341
+ dynamic_axes = {"images_tensors" : [0 , 1 , 2 , 3 ], "outputs" : [0 , 1 , 2 , 3 ]},
342
+ tolerate_small_mismatch = True )
311
343
312
344
# Verify that paste_mask_in_image beahves the same in tracing.
313
345
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
@@ -350,7 +382,11 @@ def test_mask_rcnn(self):
350
382
model = models .detection .mask_rcnn .maskrcnn_resnet50_fpn (pretrained = True , min_size = 200 , max_size = 300 )
351
383
model .eval ()
352
384
model (images )
353
- self .run_model (model , [(images ,), (test_images ,)])
385
+ self .run_model (model , [(images ,), (test_images ,)],
386
+ input_names = ["images_tensors" ],
387
+ output_names = ["outputs" ],
388
+ dynamic_axes = {"images_tensors" : [0 , 1 , 2 , 3 ], "outputs" : [0 , 1 , 2 , 3 ]},
389
+ tolerate_small_mismatch = True )
354
390
355
391
# Verify that heatmaps_to_keypoints behaves the same in tracing.
356
392
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -385,9 +421,7 @@ def test_keypoint_rcnn(self):
385
421
class KeyPointRCNN (torch .nn .Module ):
386
422
def __init__ (self ):
387
423
super (KeyPointRCNN , self ).__init__ ()
388
- self .model = models .detection .keypoint_rcnn .keypointrcnn_resnet50_fpn (pretrained = True ,
389
- min_size = 200 ,
390
- max_size = 300 )
424
+ self .model = models .detection .keypoint_rcnn .keypointrcnn_resnet50_fpn (pretrained = True , min_size = 200 , max_size = 300 )
391
425
392
426
def forward (self , images ):
393
427
output = self .model (images )
@@ -399,8 +433,12 @@ def forward(self, images):
399
433
images , test_images = self .get_test_images ()
400
434
model = KeyPointRCNN ()
401
435
model .eval ()
402
- model (test_images )
403
- self .run_model (model , [(images ,), (test_images ,)])
436
+ model (images )
437
+ self .run_model (model , [(images ,), (test_images ,)],
438
+ input_names = ["images_tensors" ],
439
+ output_names = ["outputs1" , "outputs2" , "outputs3" , "outputs4" ],
440
+ dynamic_axes = {"images_tensors" : [0 , 1 , 2 , 3 ]},
441
+ tolerate_small_mismatch = True )
404
442
405
443
406
444
if __name__ == '__main__' :
0 commit comments