Skip to content

Commit ad3ff86

Browse files
galrotemfacebook-github-bot
authored andcommitted
state helper - active phase state
Reviewed By: diego-urgell Differential Revision: D56496429 fbshipit-source-id: ab6c3c69fc624a73cf3095f01c970450c107e02a
1 parent e7b9e64 commit ad3ff86

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torchtnt/framework/state.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from enum import auto, Enum
1313
from typing import Generic, Iterable, Optional, TypeVar
1414

15+
from pyre_extensions import none_throws
16+
1517
from torchtnt.utils.timer import BoundedTimer, TimerProtocol
1618

1719
_logger: logging.Logger = logging.getLogger(__name__)
@@ -199,3 +201,14 @@ def stop(self) -> None:
199201
"""Signal to the loop to end after the current step completes."""
200202
_logger.warning("Received signal to stop")
201203
self._should_stop = True
204+
205+
def active_phase_state(self) -> TPhaseState:
206+
"""Returns the current active phase state."""
207+
if self._active_phase == ActivePhase.TRAIN:
208+
return none_throws(self._train_state)
209+
elif self._active_phase == ActivePhase.EVALUATE:
210+
return none_throws(self._eval_state)
211+
elif self._active_phase == ActivePhase.PREDICT:
212+
return none_throws(self._predict_state)
213+
else:
214+
raise ValueError(f"Invalid active phase: {self._active_phase}")

0 commit comments

Comments
 (0)