|
14 | 14 | default_device, |
15 | 15 | slow_test, |
16 | 16 | requires_torch_greater_or_equal, |
| 17 | + check_run_test_on_diff_or_main, |
17 | 18 | ) |
18 | 19 |
|
19 | 20 |
|
20 | 21 | class BaseModelTester(unittest.TestCase): |
21 | 22 | test_encoder_name = ( |
22 | 23 | "tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18" |
23 | 24 | ) |
| 25 | + files_for_diff = [r".*"] |
24 | 26 |
|
25 | 27 | # should be overriden |
26 | 28 | test_model_type = None |
@@ -234,3 +236,21 @@ def test_preserve_forward_output(self): |
234 | 236 | is_close = torch.allclose(output, output_tensor, atol=5e-2) |
235 | 237 | max_diff = torch.max(torch.abs(output - output_tensor)) |
236 | 238 | self.assertTrue(is_close, f"Max diff: {max_diff}") |
| 239 | + |
| 240 | + @pytest.mark.compile |
| 241 | + def test_compile(self): |
| 242 | + if not check_run_test_on_diff_or_main(self.files_for_diff): |
| 243 | + self.skipTest("No diff and not on `main`.") |
| 244 | + |
| 245 | + sample = self._get_sample( |
| 246 | + batch_size=self.default_batch_size, |
| 247 | + num_channels=self.default_num_channels, |
| 248 | + height=self.default_height, |
| 249 | + width=self.default_width, |
| 250 | + ).to(default_device) |
| 251 | + |
| 252 | + model = self.get_default_model() |
| 253 | + compiled_model = torch.compile(model, fullgraph=True, dynamic=True) |
| 254 | + |
| 255 | + with torch.inference_mode(): |
| 256 | + compiled_model(sample) |
0 commit comments