@@ -74,72 +74,114 @@ def get_available_video_models():
 }


+# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
+# This may be caused by the harness environment (e.g. num classes, input initialization
+# via torch.rand), and does not prove autocast is unsuitable when training with real data
+# (autocast has been used successfully with real data for some of these models).
+# TODO: investigate why autocast numerics are flaky in the harnesses.
+#
+# For the following models, _test_*_model harnesses skip numerical checks on outputs when
+# trying autocast. However, they still try an autocasted forward pass, so they still ensure
+# autocast coverage suffices to prevent dtype errors in each model.
+autocast_flaky_numerics = (
+    "fasterrcnn_resnet50_fpn",
+    "inception_v3",
+    "keypointrcnn_resnet50_fpn",
+    "maskrcnn_resnet50_fpn",
+    "resnet101",
+    "resnet152",
+    "wide_resnet101_2",
+)
+
+
 class ModelTester(TestCase):
     def checkModule(self, model, name, args):
         if name not in script_test_models:
             return
         unwrapper = script_test_models[name].get('unwrapper', None)
         return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)

-    def _test_classification_model(self, name, input_shape):
+    def _test_classification_model(self, name, input_shape, dev):
         set_rng_seed(0)
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.__dict__[name](num_classes=50)
-        model.eval()
-        x = torch.rand(input_shape)
+        model.eval().to(device=dev)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.assertExpected(out, prec=0.1)
+        self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
         self.assertEqual(out.shape[-1], 50)
         self.checkModule(model, name, (x,))

-    def _test_segmentation_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                # See autocast_flaky_numerics comment at top of file.
+                if name not in autocast_flaky_numerics:
+                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
+                self.assertEqual(out.shape[-1], 50)
+
+    def _test_segmentation_model(self, name, dev):
         # passing num_class equal to a number other than 1000 helps in making the test
         # more enforcing in nature
         model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (1, 3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
         self.checkModule(model, name, (x,))

-    def _test_detection_model(self, name):
+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
+
+    def _test_detection_model(self, name, dev):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
-        model.eval()
+        model.eval().to(device=dev)
         input_shape = (3, 300, 300)
-        x = torch.rand(input_shape)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         model_input = [x]
         out = model(model_input)
         self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)

-        def subsample_tensor(tensor):
-            num_elems = tensor.numel()
-            num_samples = 20
-            if num_elems <= num_samples:
-                return tensor
-
-            flat_tensor = tensor.flatten()
-            ith_index = num_elems // num_samples
-            return flat_tensor[ith_index - 1::ith_index]
-
-        def compute_mean_std(tensor):
-            # can't compute mean of integral tensor
-            tensor = tensor.to(torch.double)
-            mean = torch.mean(tensor)
-            std = torch.std(tensor)
-            return {"mean": mean, "std": std}
-
-        # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
-        # compare results with mean and std
-        if name == "maskrcnn_resnet50_fpn":
-            test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
-            # mean values are small, use large prec
-            self.assertExpected(test_value, prec=.01)
-        else:
-            self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01)
+        def check_out(out):
+            self.assertEqual(len(out), 1)
+
+            def subsample_tensor(tensor):
+                num_elems = tensor.numel()
+                num_samples = 20
+                if num_elems <= num_samples:
+                    return tensor
+
+                flat_tensor = tensor.flatten()
+                ith_index = num_elems // num_samples
+                return flat_tensor[ith_index - 1::ith_index]
+
+            def compute_mean_std(tensor):
+                # can't compute mean of integral tensor
+                tensor = tensor.to(torch.double)
+                mean = torch.mean(tensor)
+                std = torch.std(tensor)
+                return {"mean": mean, "std": std}
+
+            # maskrcnn_resnet_50_fpn numerically unstable across platforms, so for now
+            # compare results with mean and std
+            if name == "maskrcnn_resnet50_fpn":
+                test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
+                # mean values are small, use large prec
+                self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
+            else:
+                self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
+                                    prec=0.01,
+                                    strip_suffix="_" + dev)
+
+        check_out(out)

         scripted_model = torch.jit.script(model)
         scripted_model.eval()
@@ -156,6 +198,13 @@ def compute_mean_std(tensor):
         # self.check_script(model, name)
         self.checkModule(model, name, ([x],))

+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(model_input)
+                # See autocast_flaky_numerics comment at top of file.
+                if name not in autocast_flaky_numerics:
+                    check_out(out)
+
     def _test_detection_model_validation(self, name):
         set_rng_seed(0)
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
@@ -179,18 +228,24 @@ def _test_detection_model_validation(self, name):
         targets = [{'boxes': boxes}]
         self.assertRaises(ValueError, model, x, targets=targets)

-    def _test_video_model(self, name):
+    def _test_video_model(self, name, dev):
         # the default input shape is
         # bs * num_channels * clip_len * h *w
         input_shape = (1, 3, 4, 112, 112)
         # test both basicblock and Bottleneck
         model = models.video.__dict__[name](num_classes=50)
-        model.eval()
-        x = torch.rand(input_shape)
+        model.eval().to(device=dev)
+        # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
+        x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.checkModule(model, name, (x,))
         self.assertEqual(out.shape[-1], 50)

+        if dev == "cuda":
+            with torch.cuda.amp.autocast():
+                out = model(x)
+                self.assertEqual(out.shape[-1], 50)
+
     def _make_sliced_model(self, model, stop_layer):
         layers = OrderedDict()
         for name, layer in model.named_children():
@@ -272,6 +327,12 @@ def test_googlenet_eval(self):

     @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
     def test_fasterrcnn_switch_devices(self):
+        def checkOut(out):
+            self.assertEqual(len(out), 1)
+            self.assertTrue("boxes" in out[0])
+            self.assertTrue("scores" in out[0])
+            self.assertTrue("labels" in out[0])
+
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
         model.cuda()
         model.eval()
@@ -280,17 +341,20 @@ def test_fasterrcnn_switch_devices(self):
         model_input = [x]
         out = model(model_input)
         self.assertIs(model_input[0], x)
-        self.assertEqual(len(out), 1)
-        self.assertTrue("boxes" in out[0])
-        self.assertTrue("scores" in out[0])
-        self.assertTrue("labels" in out[0])
+
+        checkOut(out)
+
+        with torch.cuda.amp.autocast():
+            out = model(model_input)
+
+        checkOut(out)
+
         # now switch to cpu and make sure it works
         model.cpu()
         x = x.cpu()
         out_cpu = model([x])
-        self.assertTrue("boxes" in out_cpu[0])
-        self.assertTrue("scores" in out_cpu[0])
-        self.assertTrue("labels" in out_cpu[0])
+
+        checkOut(out_cpu)

     def test_generalizedrcnn_transform_repr(self):

@@ -312,34 +376,40 @@ def test_generalizedrcnn_transform_repr(self):
         self.assertEqual(t.__repr__(), expected_string)


+_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+
+
 for model_name in get_available_classification_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        input_shape = (1, 3, 224, 224)
-        if model_name in ['inception_v3']:
-            input_shape = (1, 3, 299, 299)
-        self._test_classification_model(model_name, input_shape)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            input_shape = (1, 3, 224, 224)
+            if model_name in ['inception_v3']:
+                input_shape = (1, 3, 299, 299)
+            self._test_classification_model(model_name, input_shape, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


 for model_name in get_available_segmentation_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        self._test_segmentation_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_segmentation_model(model_name, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


 for model_name in get_available_detection_models():
-    # for-loop bodies don't define scopes, so we have to save the variables
-    # we want to close over in some way
-    def do_test(self, model_name=model_name):
-        self._test_detection_model(model_name)
+    for dev in _devs:
+        # for-loop bodies don't define scopes, so we have to save the variables
+        # we want to close over in some way
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_detection_model(model_name, dev)

-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

     def do_validation_test(self, model_name=model_name):
         self._test_detection_model_validation(model_name)
@@ -348,11 +418,11 @@ def do_validation_test(self, model_name=model_name):


 for model_name in get_available_video_models():
+    for dev in _devs:
+        def do_test(self, model_name=model_name, dev=dev):
+            self._test_video_model(model_name, dev)

-    def do_test(self, model_name=model_name):
-        self._test_video_model(model_name)
-
-    setattr(ModelTester, "test_" + model_name, do_test)
+        setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

 if __name__ == '__main__':
     unittest.main()
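
As a standalone illustration of the two patterns this diff relies on (binding loop variables via default arguments so each generated test closes over its own model_name/dev, and running an autocasted forward pass without numerical assertions for models with flaky autocast numerics), here is a minimal self-contained sketch. It is not part of torchvision's test suite: TINY_MODELS, AUTOCAST_FLAKY, TinyModelTester and _run are hypothetical stand-ins chosen only to keep the example small.

import unittest

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real model registries in test_models.py.
TINY_MODELS = {"tiny_linear": lambda: nn.Linear(8, 4)}
# Pretend this model has flaky numerics under autocast, like autocast_flaky_numerics.
AUTOCAST_FLAKY = {"tiny_linear"}

_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


class TinyModelTester(unittest.TestCase):
    def _run(self, name, dev):
        model = TINY_MODELS[name]().eval().to(device=dev)
        # RNG stays on CPU so the input is bitwise identical across devices.
        x = torch.rand(2, 8).to(device=dev)
        out = model(x)
        self.assertEqual(out.shape[-1], 4)

        if dev == "cuda":
            with torch.cuda.amp.autocast():
                out = model(x)
                # Flaky models still get an autocasted forward pass (dtype coverage),
                # but their outputs are not checked numerically.
                if name not in AUTOCAST_FLAKY:
                    self.assertEqual(out.shape[-1], 4)


for model_name in TINY_MODELS:
    for dev in _devs:
        # Default arguments capture the current loop values; without them every
        # generated test would see only the final model_name/dev of the loops.
        def do_test(self, model_name=model_name, dev=dev):
            self._run(model_name, dev)

        setattr(TinyModelTester, "test_" + model_name + "_" + dev, do_test)


if __name__ == "__main__":
    unittest.main()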