Skip to content

Commit 1545b34

Browse files
diego-urgell authored and facebook-github-bot committed
Remove unnecessary encapsulation of DCP APIs (#891)
Summary: Pull Request resolved: #891. Reviewed By: JKSenthil. Differential Revision: D61951699. fbshipit-source-id: f5ccb46b4eaf251d64699c5b7a81e717009322b0
1 parent d3e85dc commit 1545b34

File tree

2 files changed

+32
-64
lines changed

2 files changed

+32
-64
lines changed

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
DummyAutoUnit,
3131
DummyTrainUnit,
3232
generate_random_dataloader,
33+
get_dummy_train_state,
3334
)
3435
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
3536
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
@@ -306,6 +307,7 @@ def test_save_default_planner_storage_components(
306307
save_every_n_train_steps = 1
307308

308309
my_unit = DummyTrainUnit(input_dim=input_dim)
310+
state = get_dummy_train_state()
309311

310312
with tempfile.TemporaryDirectory() as temp_dir:
311313
dcp_cb = DistributedCheckpointSaver(
@@ -314,9 +316,11 @@ def test_save_default_planner_storage_components(
314316
knob_options=KnobOptions(1),
315317
)
316318

317-
dcp_cb._save(
319+
dcp_cb._checkpoint_impl(
320+
state=state,
321+
unit=my_unit,
318322
checkpoint_id=temp_dir,
319-
app_state=my_unit.module.state_dict(),
323+
hook="on_train_epoch_end",
320324
)
321325

322326
planner = mock_dist_cp.save.call_args_list[0][1]["planner"]
@@ -331,6 +335,7 @@ def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
331335
save_every_n_train_steps = 1
332336

333337
my_unit = DummyTrainUnit(input_dim=input_dim)
338+
state = get_dummy_train_state()
334339

335340
with tempfile.TemporaryDirectory() as temp_dir:
336341
dcp_cb = DistributedCheckpointSaver(
@@ -339,9 +344,11 @@ def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
339344
knob_options=KnobOptions(1),
340345
)
341346

342-
dcp_cb._save(
347+
dcp_cb._checkpoint_impl(
348+
state=state,
349+
unit=my_unit,
343350
checkpoint_id=temp_dir,
344-
app_state=my_unit.module.state_dict(),
351+
hook="on_train_epoch_end",
345352
planner=DummySavePlanner(),
346353
storage_writer=DummyStorageWriter(path=temp_dir),
347354
)

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 21 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,38 @@ def _checkpoint_impl(
137137
intra_epoch = hook == "on_train_step_end"
138138
curr_snapshot_wait = hook == "on_train_end"
139139

140+
if planner is None:
141+
planner = DefaultSavePlanner()
142+
143+
if storage_writer is None:
144+
storage_writer = Writer(checkpoint_id, **self.default_writer_options)
145+
140146
app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
141147
# TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
142148
if self._async_checkpoint:
143149
with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
144-
# TODO checkpoint is not truly successful
145-
# since this is async checkpointed, so in
146-
# future, add logic to set successful flag
147-
# only when checkpoint is fully written
148-
checkpoint_success = self._async_save(
149-
checkpoint_id, app_state, planner, storage_writer
150+
# Redundant check for safety
151+
self._wait(log_warning=True)
152+
self._prev_snapshot = dcp.async_save(
153+
state_dict={"app_state": MultiStateful(app_state)},
154+
checkpoint_id=checkpoint_id,
155+
process_group=self._process_group,
156+
storage_writer=storage_writer,
157+
planner=planner,
150158
)
151159
if curr_snapshot_wait:
152160
self._wait(log_warning=False)
153161
else:
154162
with get_timing_context(state, f"{self.__class__.__name__}.save"):
155-
checkpoint_success = self._save(
156-
checkpoint_id, app_state, planner, storage_writer
163+
dcp.save(
164+
state_dict={"app_state": MultiStateful(app_state)},
165+
checkpoint_id=checkpoint_id,
166+
process_group=self._process_group,
167+
storage_writer=storage_writer,
168+
planner=planner,
157169
)
158170

159-
return checkpoint_success
171+
return True
160172

161173
def _wait(self, log_warning: bool = True) -> None:
162174
"""
@@ -195,57 +207,6 @@ def _wait(self, log_warning: bool = True) -> None:
195207
logger=logger,
196208
)
197209

198-
def _async_save(
199-
self,
200-
checkpoint_id: str,
201-
app_state: Dict[str, Stateful],
202-
planner: Optional[SavePlanner] = None,
203-
storage_writer: Optional[StorageWriter] = None,
204-
) -> bool:
205-
206-
if planner is None:
207-
planner = DefaultSavePlanner()
208-
209-
if storage_writer is None:
210-
storage_writer = Writer(checkpoint_id, **self.default_writer_options)
211-
212-
# Redundant check for safety
213-
self._wait(log_warning=True)
214-
215-
self._prev_snapshot = dcp.async_save(
216-
state_dict={"app_state": MultiStateful(app_state)},
217-
checkpoint_id=checkpoint_id,
218-
process_group=self._process_group,
219-
storage_writer=storage_writer,
220-
planner=planner,
221-
)
222-
223-
return True
224-
225-
def _save(
226-
self,
227-
checkpoint_id: str,
228-
app_state: Dict[str, Stateful],
229-
planner: Optional[SavePlanner] = None,
230-
storage_writer: Optional[StorageWriter] = None,
231-
) -> bool:
232-
# Initialize DefaultSavePlanner and FsspecWriter if not provided
233-
if planner is None:
234-
planner = DefaultSavePlanner()
235-
236-
if storage_writer is None:
237-
storage_writer = Writer(checkpoint_id, **self.default_writer_options)
238-
239-
dcp.save(
240-
state_dict={"app_state": MultiStateful(app_state)},
241-
checkpoint_id=checkpoint_id,
242-
process_group=self._process_group,
243-
storage_writer=storage_writer,
244-
planner=planner,
245-
)
246-
247-
return True
248-
249210
def on_exception(
250211
self,
251212
state: State,

0 commit comments

Comments (0)