24
24
Phase .PREDICT : "predict_dataloader" ,
25
25
}
26
26
_TRAIN_DL_STATE_KEY = "train_dataloader"
27
+
27
28
_TRAIN_PROGRESS_STATE_KEY = "train_progress"
28
29
_EVAL_PROGRESS_STATE_KEY = "eval_progress"
30
+ _PREDICT_PROGRESS_STATE_KEY = "predict_progress"
29
31
30
32
31
33
def _get_step_phase_mapping (
@@ -56,6 +58,30 @@ def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
56
58
return app_state
57
59
58
60
61
+ def _remove_app_state_keys (
62
+ unit : AppStateMixin ,
63
+ app_state : Dict [str , Any ],
64
+ * ,
65
+ remove_modules : bool = False ,
66
+ remove_optimizers : bool = False ,
67
+ remove_lr_schedulers : bool = False ,
68
+ ) -> None :
69
+ if remove_modules :
70
+ # remove all module keys from app_state
71
+ for module_keys in unit .tracked_modules ().keys ():
72
+ app_state .pop (module_keys , None )
73
+
74
+ if remove_optimizers :
75
+ # remove all optimizer keys from app_state
76
+ for optim_keys in unit .tracked_optimizers ().keys ():
77
+ app_state .pop (optim_keys , None )
78
+
79
+ if remove_lr_schedulers :
80
+ # remove all lr scheduler keys from app_state
81
+ for lr_scheduler_keys in unit .tracked_lr_schedulers ().keys ():
82
+ app_state .pop (lr_scheduler_keys , None )
83
+
84
+
59
85
def _prepare_app_state_for_checkpoint (
60
86
state : State , unit : AppStateMixin , intra_epoch : bool
61
87
) -> Dict [str , Stateful ]:
@@ -64,6 +90,16 @@ def _prepare_app_state_for_checkpoint(
64
90
"""
65
91
app_state = _prepare_app_state (unit )
66
92
93
+ if state .entry_point in [EntryPoint .EVALUATE , EntryPoint .PREDICT ]:
94
+ # Since model parameters are fixed, remove them from checkpoint.
95
+ _remove_app_state_keys (
96
+ unit ,
97
+ app_state ,
98
+ remove_modules = True ,
99
+ remove_optimizers = True ,
100
+ remove_lr_schedulers = True ,
101
+ )
102
+
67
103
# for intra-epoch checkpointing, include dataloader state of the current phase
68
104
phase_dl = state .active_phase_state ().dataloader
69
105
if intra_epoch and isinstance (phase_dl , Stateful ):
@@ -85,24 +121,21 @@ def _prepare_app_state_for_restore(
85
121
86
122
restore_options = restore_options or RestoreOptions ()
87
123
88
- if not restore_options .restore_modules :
89
- for module_keys in unit .tracked_modules ().keys ():
90
- app_state .pop (module_keys , None )
91
-
92
124
if not restore_options .restore_train_progress :
93
125
app_state .pop (_TRAIN_PROGRESS_STATE_KEY , None )
94
126
95
127
if not restore_options .restore_eval_progress :
96
128
app_state .pop (_EVAL_PROGRESS_STATE_KEY , None )
97
129
98
- if not restore_options .restore_optimizers :
99
- # remove all optimizer keys from app_state
100
- for optim_keys in unit .tracked_optimizers ().keys ():
101
- app_state .pop (optim_keys , None )
130
+ if not restore_options .restore_predict_progress :
131
+ app_state .pop (_PREDICT_PROGRESS_STATE_KEY , None )
102
132
103
- if not restore_options .restore_lr_schedulers :
104
- # remove all lr scheduler keys from app_state
105
- for lr_scheduler_keys in unit .tracked_lr_schedulers ().keys ():
106
- app_state .pop (lr_scheduler_keys , None )
133
+ _remove_app_state_keys (
134
+ unit ,
135
+ app_state ,
136
+ remove_modules = not restore_options .restore_modules ,
137
+ remove_optimizers = not restore_options .restore_optimizers ,
138
+ remove_lr_schedulers = not restore_options .restore_lr_schedulers ,
139
+ )
107
140
108
141
return app_state
0 commit comments