@@ -56,7 +56,11 @@ def decoder_channels(self):
         return None
 
     @lru_cache
-    def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
+    def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
+        batch_size = batch_size or self.default_batch_size
+        num_channels = num_channels or self.default_num_channels
+        height = height or self.default_height
+        width = width or self.default_width
         return torch.rand(batch_size, num_channels, height, width)
 
     @lru_cache
@@ -66,12 +70,7 @@ def get_default_model(self):
         return model
 
     def test_forward_backward(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         model = self.get_default_model()
 
@@ -111,12 +110,7 @@ def test_in_channels_and_depth_and_out_classes(
             .eval()
         )
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=in_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample(num_channels=in_channels).to(default_device)
 
         # check in channels correctly set
         with torch.inference_mode():
@@ -145,12 +139,7 @@ def test_classification_head(self):
         self.assertIsInstance(model.classification_head[3], torch.nn.Linear)
         self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid)
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         with torch.inference_mode():
             _, cls_probs = model(sample)
@@ -163,8 +152,6 @@ def test_any_resolution(self):
             self.skipTest("Model requires divisible input shape")
 
         sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
             height=self.default_height + 3,
             width=self.default_width + 7,
         ).to(default_device)
@@ -193,12 +180,7 @@ def test_save_load_with_hub_mixin(self):
             readme = f.read()
 
         # check inference is correct
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         with torch.inference_mode():
             output = model(sample)
@@ -242,12 +224,7 @@ def test_compile(self):
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         model = self.get_default_model()
         compiled_model = torch.compile(model, fullgraph=True, dynamic=True)
@@ -260,13 +237,7 @@ def test_torch_export(self):
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
-
+        sample = self._get_sample().to(default_device)
 
         model = self.get_default_model()
         model.eval()
@@ -282,3 +253,21 @@ def test_torch_export(self):
 
         self.assertEqual(eager_output.shape, exported_output.shape)
         torch.testing.assert_close(eager_output, exported_output)
+
+    @pytest.mark.torch_script
+    def test_torch_script(self):
+        if not check_run_test_on_diff_or_main(self.files_for_diff):
+            self.skipTest("No diff and not on `main`.")
+
+        sample = self._get_sample().to(default_device)
+        model = self.get_default_model()
+        model.eval()
+
+        scripted_model = torch.jit.script(model)
+
+        with torch.inference_mode():
+            scripted_output = scripted_model(sample)
+            eager_output = model(sample)
+
+        self.assertEqual(scripted_output.shape, eager_output.shape)
+        torch.testing.assert_close(scripted_output, eager_output)
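For reference, a minimal self-contained sketch of the cached-sample helper pattern this diff introduces. The DummyTest class and its default_* attribute values are assumptions standing in for the real test base class; only the _get_sample body mirrors the diff.

# Hypothetical stand-in for the test base class; the default_* values below
# are assumptions for illustration, not taken from the repository.
from functools import lru_cache

import torch


class DummyTest:
    default_batch_size = 1
    default_num_channels = 3
    default_height = 32
    default_width = 32

    @lru_cache
    def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
        # Any dimension left as None falls back to the class-level default,
        # so most tests can simply call self._get_sample() with no arguments.
        batch_size = batch_size or self.default_batch_size
        num_channels = num_channels or self.default_num_channels
        height = height or self.default_height
        width = width or self.default_width
        return torch.rand(batch_size, num_channels, height, width)


t = DummyTest()
a = t._get_sample()                # all defaults -> shape (1, 3, 32, 32)
b = t._get_sample(num_channels=5)  # override a single dimension -> (1, 5, 32, 32)
c = t._get_sample()                # lru_cache returns the same cached tensor as `a`
assert a is c and b.shape == (1, 5, 32, 32)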