Skip to content

Commit 8bf95f3

Browse files
diego-urgell authored and facebook-github-bot committed
Add eval/predict dataloader parameters to restore methods (#917)
Summary: Pull Request resolved: #917. Reviewed By: JKSenthil. Differential Revision: D63013005. fbshipit-source-id: a1d89e95786b74f0ce8ab09b57678dfafe125c00
1 parent d2d8181 commit 8bf95f3

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import tempfile
1515
import time
1616
import unittest
17-
from typing import cast, Iterable, List, Optional
17+
from typing import Any, cast, Iterable, List, Optional
1818
from unittest.mock import MagicMock, patch
1919

2020
import torch
@@ -42,7 +42,14 @@
4242
from torchtnt.framework.state import ActivePhase, State
4343

4444
from torchtnt.framework.train import train
45-
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData, TTrainUnit
45+
from torchtnt.framework.unit import (
46+
AppStateMixin,
47+
TEvalData,
48+
TPredictData,
49+
TrainUnit,
50+
TTrainData,
51+
TTrainUnit,
52+
)
4653
from torchtnt.utils.checkpoint import BestCheckpointConfig, get_latest_checkpoint_path
4754
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
4855
from torchtnt.utils.env import init_from_env
@@ -94,10 +101,13 @@ def restore(
94101
unit: AppStateMixin,
95102
*,
96103
train_dataloader: Optional[Iterable[TTrainData]] = None,
104+
eval_dataloader: Optional[Iterable[TEvalData]] = None,
105+
predict_dataloader: Optional[Iterable[TPredictData]] = None,
97106
process_group: Optional[dist.ProcessGroup] = None,
98107
restore_options: Optional[RestoreOptions] = None,
99108
msg: str = "",
100109
restored_checkpoint_path: Optional[List[str]] = None,
110+
**kwargs: Any,
101111
) -> None:
102112
if restored_checkpoint_path is not None:
103113
if len(restored_checkpoint_path):

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ def restore(
410410
train_dataloader: Optional[Iterable[TTrainData]] = None,
411411
process_group: Optional[dist.ProcessGroup] = None,
412412
restore_options: Optional[RestoreOptions] = None,
413+
**kwargs: Any,
413414
) -> None:
414415
"""Method to restore checkpoint state from a path.
415416
@@ -419,7 +420,7 @@ def restore(
419420
Args:
420421
path: Path of the checkpoint to restore.
421422
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
422-
train_dataloader: An optional train dataloader to restore.
423+
train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint.
423424
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
424425
restore_options: Controls what to filter when restoring the state.
425426
"""
@@ -538,6 +539,7 @@ def restore_with_id(
538539
train_dataloader: Optional[Iterable[TTrainData]] = None,
539540
process_group: Optional[dist.ProcessGroup] = None,
540541
restore_options: Optional[RestoreOptions] = None,
542+
**kwargs: Any,
541543
) -> None:
542544
"""Method to restore checkpoint state from a checkpoint id.
543545
@@ -561,4 +563,5 @@ def restore_with_id(
561563
train_dataloader=train_dataloader,
562564
process_group=process_group,
563565
restore_options=restore_options,
566+
**kwargs,
564567
)

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
from torchtnt.framework.state import State
3838
from torchtnt.framework.unit import (
3939
AppStateMixin,
40+
TEvalData,
4041
TEvalUnit,
42+
TPredictData,
4143
TPredictUnit,
4244
TTrainData,
4345
TTrainUnit,
@@ -228,11 +230,14 @@ def restore(
228230
unit: AppStateMixin,
229231
*,
230232
train_dataloader: Optional[Iterable[TTrainData]] = None,
233+
eval_dataloader: Optional[Iterable[TEvalData]] = None,
234+
predict_dataloader: Optional[Iterable[TPredictData]] = None,
231235
process_group: Optional[dist.ProcessGroup] = None,
232236
restore_options: Optional[RestoreOptions] = None,
233237
knob_options: Optional[KnobOptions] = None,
234238
planner: Optional[LoadPlanner] = None,
235239
storage_reader: Optional[StorageReader] = None,
240+
**kwargs: Any,
236241
) -> None:
237242
"""Utility method to restore dcp checkpoint from a path."""
238243

@@ -242,6 +247,8 @@ def restore(
242247
checkpoint_id,
243248
unit,
244249
train_dataloader=train_dataloader,
250+
eval_dataloader=eval_dataloader,
251+
predict_dataloader=predict_dataloader,
245252
process_group=process_group,
246253
restore_options=restore_options,
247254
knob_options=knob_options,
@@ -255,11 +262,14 @@ def restore_with_id(
255262
unit: AppStateMixin,
256263
*,
257264
train_dataloader: Optional[Iterable[TTrainData]] = None,
265+
eval_dataloader: Optional[Iterable[TEvalData]] = None,
266+
predict_dataloader: Optional[Iterable[TPredictData]] = None,
258267
process_group: Optional[dist.ProcessGroup] = None,
259268
restore_options: Optional[RestoreOptions] = None,
260269
knob_options: Optional[KnobOptions] = None,
261270
planner: Optional[LoadPlanner] = None,
262271
storage_reader: Optional[StorageReader] = None,
272+
**kwargs: Any,
263273
) -> None:
264274
"""Utility method to restore dcp checkpoint from a checkpoint_id.
265275
@@ -269,7 +279,9 @@ def restore_with_id(
269279
Args:
270280
checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore.
271281
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
272-
train_dataloader: An optional train dataloader to restore.
282+
train_dataloader: An optional train dataloader to restore. Can only be used when restoring from a train or fit checkpoint.
283+
eval_dataloader: An optional eval dataloader to restore. Can only be used when restoring from an eval or fit checkpoint.
284+
predict_dataloader: An optional predict dataloader to restore. Can only be used when restoring from a predict checkpoint.
273285
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
274286
If not Gloo, a Gloo process group is created.
275287
Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.

torchtnt/framework/callbacks/torchsnapshot_saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def restore(
270270
storage_options: Optional[Dict[str, Any]] = None,
271271
knob_options: Optional[KnobOptions] = None,
272272
strict: bool = True,
273+
**kwargs: Any,
273274
) -> None:
274275
"""Utility method to restore snapshot state from a path.
275276
@@ -279,7 +280,7 @@ def restore(
279280
Args:
280281
path: Path of the snapshot to restore.
281282
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
282-
train_dataloader: An optional train dataloader to restore.
283+
train_dataloader: An optional train dataloader to restore. Note that restoring from predict or evaluate dataloaders is not supported for TorchSnapshotSaver.
283284
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
284285
restore_options: Controls what to filter when restoring the state.
285286
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.

0 commit comments

Comments
 (0)