Skip to content

Commit ed30bb6

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Enable comparison of iterator sampler strategy with str value (#954)
Summary: Pull Request resolved: #954 Reviewed By: mpahsu, JKSenthil Differential Revision: D67550088 fbshipit-source-id: 207b25fde316f9884abcb873fd3c27696efdede5
1 parent ba2fb54 commit ed30bb6

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

tests/utils/data/test_iterators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
from torchtnt.utils.data.iterators import StoppingMechanism
12+
13+
14+
class TestIterators(unittest.TestCase):
15+
16+
def test_stopping_mechanism_comparison(self) -> None:
17+
self.assertTrue(
18+
StoppingMechanism.ALL_DATASETS_EXHAUSTED == "ALL_DATASETS_EXHAUSTED"
19+
)
20+
self.assertTrue(
21+
StoppingMechanism.ALL_DATASETS_EXHAUSTED
22+
== StoppingMechanism.ALL_DATASETS_EXHAUSTED
23+
)
24+
self.assertFalse(
25+
StoppingMechanism.ALL_DATASETS_EXHAUSTED == "SMALLEST_DATASET_EXHAUSTED"
26+
)
27+
self.assertFalse(
28+
StoppingMechanism.ALL_DATASETS_EXHAUSTED
29+
== StoppingMechanism.SMALLEST_DATASET_EXHAUSTED
30+
)

torchtnt/utils/data/iterators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,15 @@ class StoppingMechanism(Enum):
9494
# used with RandomizedBatchSampler
9595
WRAP_AROUND_UNTIL_KILLED = "WRAP_AROUND_UNTIL_KILLED"
9696

97+
def __eq__(self, other: Union[str, StoppingMechanism]) -> bool:
98+
"""
99+
Enable comparison betwen string and instances of StoppingMechanism
100+
"""
101+
if isinstance(other, str):
102+
return self.value == other
103+
104+
return super().__eq__(other)
105+
97106

98107
@dataclass
99108
class RoundRobin(DataIterationStrategy):

0 commit comments

Comments
 (0)