Skip to content

Commit 8a0dca5

Browse files
committed
Update export test
1 parent 8d63787 commit 8a0dca5

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tests/models/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,11 @@ def test_compile(self):
252252
compiled_model(sample)
253253

254254
@pytest.mark.torch_export
255-
def test_torch_export(self):
255+
def test_torch_export(self, eps=1e-5):
256256
if not check_run_test_on_diff_or_main(self.files_for_diff):
257257
self.skipTest("No diff and not on `main`.")
258258

259+
torch.manual_seed(42)
259260
sample = self._get_sample().to(default_device)
260261
model = self.get_default_model()
261262
model.eval()
@@ -271,7 +272,7 @@ def test_torch_export(self):
271272
exported_output = exported_model.module().forward(sample)
272273

273274
self.assertEqual(eager_output.shape, exported_output.shape)
274-
torch.testing.assert_close(eager_output, exported_output)
275+
torch.testing.assert_close(eager_output, exported_output, rtol=eps, atol=eps)
275276

276277
@pytest.mark.torch_script
277278
def test_torch_script(self):

tests/models/test_upernet.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from tests.models import base
24

35

@@ -6,3 +8,7 @@ class TestUnetModel(base.BaseModelTester):
68
files_for_diff = [r"decoders/upernet/", r"base/"]
79

810
default_batch_size = 2
11+
12+
@pytest.mark.torch_export
13+
def test_torch_export(self):
14+
super().test_torch_export(eps=1e-3)

0 commit comments

Comments
 (0)