File tree Expand file tree Collapse file tree 2 files changed +39
-0
lines changed Expand file tree Collapse file tree 2 files changed +39
-0
lines changed Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import unittest
10
+
11
+ from torchtnt .utils .data .iterators import StoppingMechanism
12
+
13
+
14
+ class TestIterators (unittest .TestCase ):
15
+
16
+ def test_stopping_mechanism_comparison (self ) -> None :
17
+ self .assertTrue (
18
+ StoppingMechanism .ALL_DATASETS_EXHAUSTED == "ALL_DATASETS_EXHAUSTED"
19
+ )
20
+ self .assertTrue (
21
+ StoppingMechanism .ALL_DATASETS_EXHAUSTED
22
+ == StoppingMechanism .ALL_DATASETS_EXHAUSTED
23
+ )
24
+ self .assertFalse (
25
+ StoppingMechanism .ALL_DATASETS_EXHAUSTED == "SMALLEST_DATASET_EXHAUSTED"
26
+ )
27
+ self .assertFalse (
28
+ StoppingMechanism .ALL_DATASETS_EXHAUSTED
29
+ == StoppingMechanism .SMALLEST_DATASET_EXHAUSTED
30
+ )
Original file line number Diff line number Diff line change @@ -94,6 +94,15 @@ class StoppingMechanism(Enum):
94
94
# used with RandomizedBatchSampler
95
95
WRAP_AROUND_UNTIL_KILLED = "WRAP_AROUND_UNTIL_KILLED"
96
96
97
+ def __eq__ (self , other : Union [str , StoppingMechanism ]) -> bool :
98
+ """
99
+ Enable comparison betwen string and instances of StoppingMechanism
100
+ """
101
+ if isinstance (other , str ):
102
+ return self .value == other
103
+
104
+ return super ().__eq__ (other )
105
+
97
106
98
107
@dataclass
99
108
class RoundRobin (DataIterationStrategy ):
You can’t perform that action at this time.
0 commit comments