Skip to content

Commit d4d4cf6

Browse files
committed
Add test for torch script
1 parent 257da0b commit d4d4cf6

File tree

1 file changed

+29
-40
lines changed

1 file changed

+29
-40
lines changed

tests/models/base.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ def decoder_channels(self):
5656
return None
5757

5858
@lru_cache
59-
def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
59+
def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
60+
batch_size = batch_size or self.default_batch_size
61+
num_channels = num_channels or self.default_num_channels
62+
height = height or self.default_height
63+
width = width or self.default_width
6064
return torch.rand(batch_size, num_channels, height, width)
6165

6266
@lru_cache
@@ -66,12 +70,7 @@ def get_default_model(self):
6670
return model
6771

6872
def test_forward_backward(self):
69-
sample = self._get_sample(
70-
batch_size=self.default_batch_size,
71-
num_channels=self.default_num_channels,
72-
height=self.default_height,
73-
width=self.default_width,
74-
).to(default_device)
73+
sample = self._get_sample().to(default_device)
7574

7675
model = self.get_default_model()
7776

@@ -111,12 +110,7 @@ def test_in_channels_and_depth_and_out_classes(
111110
.eval()
112111
)
113112

114-
sample = self._get_sample(
115-
batch_size=self.default_batch_size,
116-
num_channels=in_channels,
117-
height=self.default_height,
118-
width=self.default_width,
119-
).to(default_device)
113+
sample = self._get_sample(num_channels=in_channels).to(default_device)
120114

121115
# check in channels correctly set
122116
with torch.inference_mode():
@@ -145,12 +139,7 @@ def test_classification_head(self):
145139
self.assertIsInstance(model.classification_head[3], torch.nn.Linear)
146140
self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid)
147141

148-
sample = self._get_sample(
149-
batch_size=self.default_batch_size,
150-
num_channels=self.default_num_channels,
151-
height=self.default_height,
152-
width=self.default_width,
153-
).to(default_device)
142+
sample = self._get_sample().to(default_device)
154143

155144
with torch.inference_mode():
156145
_, cls_probs = model(sample)
@@ -163,8 +152,6 @@ def test_any_resolution(self):
163152
self.skipTest("Model requires divisible input shape")
164153

165154
sample = self._get_sample(
166-
batch_size=self.default_batch_size,
167-
num_channels=self.default_num_channels,
168155
height=self.default_height + 3,
169156
width=self.default_width + 7,
170157
).to(default_device)
@@ -193,12 +180,7 @@ def test_save_load_with_hub_mixin(self):
193180
readme = f.read()
194181

195182
# check inference is correct
196-
sample = self._get_sample(
197-
batch_size=self.default_batch_size,
198-
num_channels=self.default_num_channels,
199-
height=self.default_height,
200-
width=self.default_width,
201-
).to(default_device)
183+
sample = self._get_sample().to(default_device)
202184

203185
with torch.inference_mode():
204186
output = model(sample)
@@ -242,12 +224,7 @@ def test_compile(self):
242224
if not check_run_test_on_diff_or_main(self.files_for_diff):
243225
self.skipTest("No diff and not on `main`.")
244226

245-
sample = self._get_sample(
246-
batch_size=self.default_batch_size,
247-
num_channels=self.default_num_channels,
248-
height=self.default_height,
249-
width=self.default_width,
250-
).to(default_device)
227+
sample = self._get_sample().to(default_device)
251228

252229
model = self.get_default_model()
253230
compiled_model = torch.compile(model, fullgraph=True, dynamic=True)
@@ -260,13 +237,7 @@ def test_torch_export(self):
260237
if not check_run_test_on_diff_or_main(self.files_for_diff):
261238
self.skipTest("No diff and not on `main`.")
262239

263-
sample = self._get_sample(
264-
batch_size=self.default_batch_size,
265-
num_channels=self.default_num_channels,
266-
height=self.default_height,
267-
width=self.default_width,
268-
).to(default_device)
269-
240+
sample = self._get_sample().to(default_device)
270241
model = self.get_default_model()
271242
model.eval()
272243

@@ -282,3 +253,21 @@ def test_torch_export(self):
282253

283254
self.assertEqual(eager_output.shape, exported_output.shape)
284255
torch.testing.assert_close(eager_output, exported_output)
256+
257+
@pytest.mark.torch_script
258+
def test_torch_script(self):
259+
if not check_run_test_on_diff_or_main(self.files_for_diff):
260+
self.skipTest("No diff and not on `main`.")
261+
262+
sample = self._get_sample().to(default_device)
263+
model = self.get_default_model()
264+
model.eval()
265+
266+
scripted_model = torch.jit.script(model)
267+
268+
with torch.inference_mode():
269+
scripted_output = scripted_model(sample)
270+
eager_output = model(sample)
271+
272+
self.assertEqual(scripted_output.shape, eager_output.shape)
273+
torch.testing.assert_close(scripted_output, eager_output)

0 commit comments

Comments
 (0)