10
10
11
11
from torchtnt .framework .callback import Callback
12
12
from torchtnt .framework .state import State
13
- from torchtnt .framework .unit import AppStateMixin , TTrainUnit
13
+ from torchtnt .framework .unit import AppStateMixin , TEvalUnit , TTrainUnit
14
14
from torchtnt .utils .distributed import get_global_rank , sync_bool
15
15
from torchtnt .utils .early_stop_checker import EarlyStopChecker
16
16
@@ -23,6 +23,7 @@ class EarlyStopping(Callback):
23
23
monitored_attr: The attribute to monitor on the unit. Must be a float or tensor attribute on the unit.
24
24
early_stop_checker: a :class:`~torchtnt.utils.early_stop_checker.EarlyStopChecker` to use for checking whether to stop early.
25
25
interval: The interval to check the monitored attribute. Must be one of "step" or "epoch".
26
+ phase: The phase to check the monitored attribute. Must be one of "train" or "eval".
26
27
27
28
Note:
28
29
If doing distributed training, this callback checks the metric value only on rank 0
@@ -33,29 +34,49 @@ def __init__(
33
34
monitored_attr : str ,
34
35
early_stop_checker : EarlyStopChecker ,
35
36
interval : Literal ["step" , "epoch" ] = "epoch" ,
37
+ phase : Literal ["train" , "eval" ] = "train" ,
36
38
interval_freq : int = 1 ,
37
39
) -> None :
38
40
self ._monitored_attr = monitored_attr
39
41
self ._esc = early_stop_checker
40
42
self ._interval = interval
41
43
self ._interval_freq = interval_freq
44
+ self ._phase = phase
42
45
43
46
self ._rank : int = get_global_rank ()
44
47
45
48
def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
    """Run the early-stop check at the end of a train step, when configured.

    The check is delegated to ``_maybe_stop`` only if this callback was
    configured with ``phase="train"`` and ``interval="step"``, and the
    unit's completed train-step count is a multiple of the interval
    frequency.

    Args:
        state: passed through unchanged to ``_maybe_stop``.
        unit: the train unit; ``unit.train_progress.num_steps_completed``
            is read to decide whether this step is a check point.
    """
    # Phase and interval are tested first so the progress counter is only
    # read when this hook is the one the callback was configured for.
    if (
        self._phase == "train"
        and self._interval == "step"
        and unit.train_progress.num_steps_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
51
55
52
56
def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
    """Run the early-stop check at the end of a train epoch, when configured.

    The check is delegated to ``_maybe_stop`` only if this callback was
    configured with ``phase="train"`` and ``interval="epoch"``, and the
    unit's completed train-epoch count is a multiple of the interval
    frequency.

    Args:
        state: passed through unchanged to ``_maybe_stop``.
        unit: the train unit; ``unit.train_progress.num_epochs_completed``
            is read to decide whether this epoch is a check point.
    """
    # Phase and interval are tested first so the progress counter is only
    # read when this hook is the one the callback was configured for.
    if (
        self._phase == "train"
        and self._interval == "epoch"
        and unit.train_progress.num_epochs_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
58
63
64
def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
    """Run the early-stop check at the end of an eval step, when configured.

    The check is delegated to ``_maybe_stop`` only if this callback was
    configured with ``phase="eval"`` and ``interval="step"``, and the
    unit's completed eval-step count is a multiple of the interval
    frequency.

    Args:
        state: passed through unchanged to ``_maybe_stop``.
        unit: the eval unit; ``unit.eval_progress.num_steps_completed``
            is read to decide whether this step is a check point.
    """
    # Phase and interval are tested first so the progress counter is only
    # read when this hook is the one the callback was configured for.
    if (
        self._phase == "eval"
        and self._interval == "step"
        and unit.eval_progress.num_steps_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
71
+
72
def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
    """Run the early-stop check at the end of an eval epoch, when configured.

    The check is delegated to ``_maybe_stop`` only if this callback was
    configured with ``phase="eval"`` and ``interval="epoch"``, and the
    unit's completed eval-epoch count is a multiple of the interval
    frequency.

    Args:
        state: passed through unchanged to ``_maybe_stop``.
        unit: the eval unit; ``unit.eval_progress.num_epochs_completed``
            is read to decide whether this epoch is a check point.
    """
    # Phase and interval are tested first so the progress counter is only
    # read when this hook is the one the callback was configured for.
    if (
        self._phase == "eval"
        and self._interval == "epoch"
        and unit.eval_progress.num_epochs_completed % self._interval_freq == 0
    ):
        self._maybe_stop(state, unit)
79
+
59
80
def _maybe_stop (self , state : State , unit : AppStateMixin ) -> None :
60
81
"""
61
82
Checks whether to stop early based on the monitored attribute.
0 commit comments