File tree Expand file tree Collapse file tree 1 file changed +13
-0
lines changed Expand file tree Collapse file tree 1 file changed +13
-0
lines changed Original file line number Diff line number Diff line change 12
12
from enum import auto , Enum
13
13
from typing import Generic , Iterable , Optional , TypeVar
14
14
15
+ from pyre_extensions import none_throws
16
+
15
17
from torchtnt .utils .timer import BoundedTimer , TimerProtocol
16
18
17
19
_logger : logging .Logger = logging .getLogger (__name__ )
@@ -199,3 +201,14 @@ def stop(self) -> None:
199
201
"""Signal to the loop to end after the current step completes."""
200
202
_logger .warning ("Received signal to stop" )
201
203
self ._should_stop = True
204
+
205
+ def active_phase_state (self ) -> TPhaseState :
206
+ """Returns the current active phase state."""
207
+ if self ._active_phase == ActivePhase .TRAIN :
208
+ return none_throws (self ._train_state )
209
+ elif self ._active_phase == ActivePhase .EVALUATE :
210
+ return none_throws (self ._eval_state )
211
+ elif self ._active_phase == ActivePhase .PREDICT :
212
+ return none_throws (self ._predict_state )
213
+ else :
214
+ raise ValueError (f"Invalid active phase: { self ._active_phase } " )
You can’t perform that action at this time.
0 commit comments