Skip to content

Commit 21049e9

Browse files
NicolasHugpmeier
andauthored
Use torch.testing.assert_close in test_models.py (#3879)
Co-authored-by: Philip Meier <[email protected]>
1 parent b96d381 commit 21049e9

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

test/test_models.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def check_out(out):
120120
# predictions match.
121121
expected_file = self._get_expected_file(name)
122122
expected = torch.load(expected_file)
123-
self.assertEqual(out.argmax(dim=1), expected.argmax(dim=1), prec=prec)
123+
torch.testing.assert_close(out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec)
124124
return False # Partial validation performed
125125

126126
return True # Full validation performed
@@ -205,7 +205,8 @@ def compute_mean_std(tensor):
205205
# scores.
206206
expected_file = self._get_expected_file(name)
207207
expected = torch.load(expected_file)
208-
self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)
208+
torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec,
209+
check_device=False, check_dtype=False)
209210

210211
# Note: Fmassa proposed turning off NMS by adapting the threshold
211212
# and then using the Hungarian algorithm as in DETR to find the
@@ -301,10 +302,8 @@ def test_memory_efficient_densenet(self):
301302
model2.eval()
302303
out2 = model2(x)
303304

304-
max_diff = (out1 - out2).abs().max()
305-
306305
self.assertTrue(num_params == num_grad)
307-
self.assertTrue(max_diff < 1e-5)
306+
torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
308307

309308
def test_resnet_dilation(self):
310309
# TODO improve tests to also check that each layer has the right dimensionality

0 commit comments

Comments
 (0)