Skip to content

Commit 714ae04

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Add str method to ActivePhase (#975)
Summary: Pull Request resolved: #975 Reviewed By: JKSenthil Differential Revision: D70127457 fbshipit-source-id: ba083545d8567f8963984c8883ac5e44a9f882a0
1 parent 32a4d82 commit 714ae04

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/framework/test_state.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ def test_active_phase_into_phase(self) -> None:
5353
predict_phase = ActivePhase.PREDICT
5454
self.assertEqual(predict_phase.into_phase(), Phase.PREDICT)
5555

56+
def test_active_phase_str(self) -> None:
57+
active_phase = ActivePhase.TRAIN
58+
self.assertEqual(str(active_phase), "train")
59+
60+
eval_phase = ActivePhase.EVALUATE
61+
self.assertEqual(str(eval_phase), "eval")
62+
63+
predict_phase = ActivePhase.PREDICT
64+
self.assertEqual(str(predict_phase), "predict")
65+
5666
def test_set_evaluate_every_n_steps_or_epochs(self) -> None:
5767
state = PhaseState(dataloader=[], evaluate_every_n_steps=2)
5868
state.evaluate_every_n_steps = None

torchtnt/framework/state.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,16 @@ def into_phase(self) -> Phase:
7474
else:
7575
raise AssertionError("Should match an ActivePhase")
7676

77+
def __str__(self) -> str:
78+
if self == ActivePhase.TRAIN:
79+
return "train"
80+
elif self == ActivePhase.EVALUATE:
81+
return "eval"
82+
elif self == ActivePhase.PREDICT:
83+
return "predict"
84+
else:
85+
raise AssertionError("Should match an ActivePhase")
86+
7787

7888
class PhaseState(Generic[TData, TStepOutput]):
7989
"""State for each phase (train, eval, predict).

0 commit comments

Comments
 (0)