Skip to content

Commit 7bbbb8c

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Move early_stop_checker GPU test to dedicated file (#763)
Summary: Pull Request resolved: #763 Reviewed By: anshulverma Differential Revision: D55388225 fbshipit-source-id: 69a37acdb2098ff1b3922e1b98792df8e4ec82a3
1 parent 8bad7b6 commit 7bbbb8c

File tree

2 files changed

+48
-33
lines changed

2 files changed

+48
-33
lines changed

tests/utils/test_early_stop_checker.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import torch
1313
from torchtnt.utils.early_stop_checker import EarlyStopChecker
14-
from torchtnt.utils.test_utils import skip_if_not_gpu
1514

1615

1716
class EarlyStopCheckerTest(unittest.TestCase):
@@ -86,38 +85,6 @@ def test_early_stop_min_delta(self) -> None:
8685
should_stop = es2.check(torch.tensor(0.26))
8786
self.assertTrue(should_stop)
8887

89-
@skip_if_not_gpu
90-
def test_early_stop_min_delta_on_gpu(self) -> None:
91-
device = torch.device("cuda:0")
92-
93-
# Loss decreases beyond 0.25 but not more than min_delta
94-
losses = [
95-
torch.tensor([0.4], device=device),
96-
torch.tensor([0.38], device=device),
97-
torch.tensor([0.31], device=device),
98-
torch.tensor([0.25], device=device),
99-
torch.tensor([0.27], device=device),
100-
torch.tensor([0.24], device=device),
101-
]
102-
es1 = EarlyStopChecker("min", 3, min_delta=0.05)
103-
es2 = EarlyStopChecker("min", 4, min_delta=0.05)
104-
105-
for loss in losses:
106-
should_stop = es1.check(torch.tensor(loss))
107-
self.assertFalse(should_stop)
108-
should_stop = es2.check(torch.tensor(loss))
109-
self.assertFalse(should_stop)
110-
111-
# Patience should run out
112-
should_stop = es1.check(torch.tensor(0.25))
113-
self.assertTrue(should_stop)
114-
115-
# es2 has more patience than es1
116-
should_stop = es2.check(torch.tensor(0.25))
117-
self.assertFalse(should_stop)
118-
should_stop = es2.check(torch.tensor(0.26))
119-
self.assertTrue(should_stop)
120-
12188
def test_early_stop_max_mode(self) -> None:
12289

12390
# Loss increases beyond 0.38 but not more than min_delta
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
12+
import torch
13+
from torchtnt.utils.early_stop_checker import EarlyStopChecker
14+
from torchtnt.utils.test_utils import skip_if_not_gpu
15+
16+
17+
class EarlyStopCheckerGPUTest(unittest.TestCase):
18+
@skip_if_not_gpu
19+
def test_early_stop_min_delta_on_gpu(self) -> None:
20+
device = torch.device("cuda:0")
21+
22+
# Loss decreases beyond 0.25 but not more than min_delta
23+
losses = [
24+
torch.tensor([0.4], device=device),
25+
torch.tensor([0.38], device=device),
26+
torch.tensor([0.31], device=device),
27+
torch.tensor([0.25], device=device),
28+
torch.tensor([0.27], device=device),
29+
torch.tensor([0.24], device=device),
30+
]
31+
es1 = EarlyStopChecker("min", 3, min_delta=0.05)
32+
es2 = EarlyStopChecker("min", 4, min_delta=0.05)
33+
34+
for loss in losses:
35+
should_stop = es1.check(torch.tensor(loss))
36+
self.assertFalse(should_stop)
37+
should_stop = es2.check(torch.tensor(loss))
38+
self.assertFalse(should_stop)
39+
40+
# Patience should run out
41+
should_stop = es1.check(torch.tensor(0.25))
42+
self.assertTrue(should_stop)
43+
44+
# es2 has more patience than es1
45+
should_stop = es2.check(torch.tensor(0.25))
46+
self.assertFalse(should_stop)
47+
should_stop = es2.check(torch.tensor(0.26))
48+
self.assertTrue(should_stop)

0 commit comments

Comments
 (0)