@@ -56,7 +56,11 @@ def decoder_channels(self):
5656 return None
5757
5858 @lru_cache
59- def _get_sample (self , batch_size = 1 , num_channels = 3 , height = 32 , width = 32 ):
59+ def _get_sample (self , batch_size = None , num_channels = None , height = None , width = None ):
60+ batch_size = batch_size or self .default_batch_size
61+ num_channels = num_channels or self .default_num_channels
62+ height = height or self .default_height
63+ width = width or self .default_width
6064 return torch .rand (batch_size , num_channels , height , width )
6165
6266 @lru_cache
@@ -66,12 +70,7 @@ def get_default_model(self):
6670 return model
6771
6872 def test_forward_backward (self ):
69- sample = self ._get_sample (
70- batch_size = self .default_batch_size ,
71- num_channels = self .default_num_channels ,
72- height = self .default_height ,
73- width = self .default_width ,
74- ).to (default_device )
73+ sample = self ._get_sample ().to (default_device )
7574
7675 model = self .get_default_model ()
7776
@@ -111,12 +110,7 @@ def test_in_channels_and_depth_and_out_classes(
111110 .eval ()
112111 )
113112
114- sample = self ._get_sample (
115- batch_size = self .default_batch_size ,
116- num_channels = in_channels ,
117- height = self .default_height ,
118- width = self .default_width ,
119- ).to (default_device )
113+ sample = self ._get_sample (num_channels = in_channels ).to (default_device )
120114
121115 # check in channels correctly set
122116 with torch .inference_mode ():
@@ -145,12 +139,7 @@ def test_classification_head(self):
145139 self .assertIsInstance (model .classification_head [3 ], torch .nn .Linear )
146140 self .assertIsInstance (model .classification_head [4 ].activation , torch .nn .Sigmoid )
147141
148- sample = self ._get_sample (
149- batch_size = self .default_batch_size ,
150- num_channels = self .default_num_channels ,
151- height = self .default_height ,
152- width = self .default_width ,
153- ).to (default_device )
142+ sample = self ._get_sample ().to (default_device )
154143
155144 with torch .inference_mode ():
156145 _ , cls_probs = model (sample )
@@ -163,8 +152,6 @@ def test_any_resolution(self):
163152 self .skipTest ("Model requires divisible input shape" )
164153
165154 sample = self ._get_sample (
166- batch_size = self .default_batch_size ,
167- num_channels = self .default_num_channels ,
168155 height = self .default_height + 3 ,
169156 width = self .default_width + 7 ,
170157 ).to (default_device )
@@ -193,12 +180,7 @@ def test_save_load_with_hub_mixin(self):
193180 readme = f .read ()
194181
195182 # check inference is correct
196- sample = self ._get_sample (
197- batch_size = self .default_batch_size ,
198- num_channels = self .default_num_channels ,
199- height = self .default_height ,
200- width = self .default_width ,
201- ).to (default_device )
183+ sample = self ._get_sample ().to (default_device )
202184
203185 with torch .inference_mode ():
204186 output = model (sample )
@@ -242,12 +224,7 @@ def test_compile(self):
242224 if not check_run_test_on_diff_or_main (self .files_for_diff ):
243225 self .skipTest ("No diff and not on `main`." )
244226
245- sample = self ._get_sample (
246- batch_size = self .default_batch_size ,
247- num_channels = self .default_num_channels ,
248- height = self .default_height ,
249- width = self .default_width ,
250- ).to (default_device )
227+ sample = self ._get_sample ().to (default_device )
251228
252229 model = self .get_default_model ()
253230 compiled_model = torch .compile (model , fullgraph = True , dynamic = True )
@@ -260,13 +237,7 @@ def test_torch_export(self):
260237 if not check_run_test_on_diff_or_main (self .files_for_diff ):
261238 self .skipTest ("No diff and not on `main`." )
262239
263- sample = self ._get_sample (
264- batch_size = self .default_batch_size ,
265- num_channels = self .default_num_channels ,
266- height = self .default_height ,
267- width = self .default_width ,
268- ).to (default_device )
269-
240+ sample = self ._get_sample ().to (default_device )
270241 model = self .get_default_model ()
271242 model .eval ()
272243
@@ -282,3 +253,21 @@ def test_torch_export(self):
282253
283254 self .assertEqual (eager_output .shape , exported_output .shape )
284255 torch .testing .assert_close (eager_output , exported_output )
256+
257+ @pytest .mark .torch_script
258+ def test_torch_script (self ):
259+ if not check_run_test_on_diff_or_main (self .files_for_diff ):
260+ self .skipTest ("No diff and not on `main`." )
261+
262+ sample = self ._get_sample ().to (default_device )
263+ model = self .get_default_model ()
264+ model .eval ()
265+
266+ scripted_model = torch .jit .script (model )
267+
268+ with torch .inference_mode ():
269+ scripted_output = scripted_model (sample )
270+ eager_output = model (sample )
271+
272+ self .assertEqual (scripted_output .shape , eager_output .shape )
273+ torch .testing .assert_close (scripted_output , eager_output )
0 commit comments