File tree Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Expand file tree Collapse file tree 2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -53,6 +53,16 @@ def test_active_phase_into_phase(self) -> None:
53
53
predict_phase = ActivePhase .PREDICT
54
54
self .assertEqual (predict_phase .into_phase (), Phase .PREDICT )
55
55
56
+ def test_active_phase_str (self ) -> None :
57
+ active_phase = ActivePhase .TRAIN
58
+ self .assertEqual (str (active_phase ), "train" )
59
+
60
+ eval_phase = ActivePhase .EVALUATE
61
+ self .assertEqual (str (eval_phase ), "eval" )
62
+
63
+ predict_phase = ActivePhase .PREDICT
64
+ self .assertEqual (str (predict_phase ), "predict" )
65
+
56
66
def test_set_evaluate_every_n_steps_or_epochs (self ) -> None :
57
67
state = PhaseState (dataloader = [], evaluate_every_n_steps = 2 )
58
68
state .evaluate_every_n_steps = None
Original file line number Diff line number Diff line change @@ -74,6 +74,16 @@ def into_phase(self) -> Phase:
74
74
else :
75
75
raise AssertionError ("Should match an ActivePhase" )
76
76
77
+ def __str__ (self ) -> str :
78
+ if self == ActivePhase .TRAIN :
79
+ return "train"
80
+ elif self == ActivePhase .EVALUATE :
81
+ return "eval"
82
+ elif self == ActivePhase .PREDICT :
83
+ return "predict"
84
+ else :
85
+ raise AssertionError ("Should match an ActivePhase" )
86
+
77
87
78
88
class PhaseState (Generic [TData , TStepOutput ]):
79
89
"""State for each phase (train, eval, predict).
You can’t perform that action at this time.
0 commit comments