
Commit d2d8181

diego-urgell authored and facebook-github-bot committed
Generate predict/evaluate ckpts in DCP Saver (#915)
Summary: Pull Request resolved: #915

Reviewed By: JKSenthil

Differential Revision: D63712524

fbshipit-source-id: 8652d3fd00c3a963bf46e99db647a7bd6882832d
1 parent 86e11f3 commit d2d8181
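
For context, the new save_every_n_predict_steps and save_every_n_eval_steps knobs let DistributedCheckpointSaver generate checkpoints from the predict and evaluate entrypoints, not just train/fit. A minimal usage sketch based on the new tests in this commit (the dummy unit and dataloader helpers come from torchtnt's test utilities; in real code you would pass your own unit and dataloader):

import os
import tempfile

from torchtnt.framework._test_utils import (
    DummyPredictUnit,
    generate_dummy_stateful_dataloader,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.predict import predict

unit = DummyPredictUnit(input_dim=2)
# 10 samples of dim 2, batch size 2 -> 5 predict steps
dataloader = generate_dummy_stateful_dataloader(10, 2, 2)

with tempfile.TemporaryDirectory() as ckpt_dir:
    dcp_cb = DistributedCheckpointSaver(
        ckpt_dir,
        knob_options=KnobOptions(1),  # mirrors the new tests
        save_every_n_predict_steps=2,  # checkpoint every 2 predict steps
    )
    predict(unit, dataloader, callbacks=[dcp_cb])

    # Expect intra-epoch checkpoints such as "epoch_0_predict_step_2"
    # and "epoch_0_predict_step_4", as asserted in the tests below.
    print(sorted(os.listdir(ckpt_dir)))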

File tree: 2 files changed (+249, −3 lines)


tests/framework/callbacks/test_dcp_saver.py

Lines changed: 239 additions & 1 deletion
@@ -12,13 +12,15 @@
 import shutil
 import tempfile
 import unittest
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 from unittest import mock
 from unittest.mock import MagicMock, patch

 import torch
+from pyre_extensions import none_throws
 from torch import nn
 from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
+from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
@@ -28,16 +30,24 @@
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
+    DummyEvalUnit,
+    DummyMeanMetric,
     DummyMultiOptimUnit,
+    DummyPredictUnit,
     DummyTrainUnit,
+    generate_dummy_stateful_dataloader,
     generate_random_dataloader,
     get_dummy_train_state,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.evaluate import evaluate
+from torchtnt.framework.fit import fit
+from torchtnt.framework.predict import predict

 from torchtnt.framework.state import State
 from torchtnt.framework.train import train
+from torchtnt.utils.checkpoint import get_latest_checkpoint_path
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
@@ -490,6 +500,234 @@ def test_save_restore_multi_optimizers(self) -> None:
         my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
         dcp_cb.restore_from_latest(temp_dir, my_unit_clone)

+    def test_save_predict(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyPredictUnit(input_dim=input_dim)
+
+        # pyre-ignore[16]: Add new attribute for testing
+        my_unit.output_mean = DummyMeanMetric()
+
+        # pyre-ignore[16]: Add at least one element to the metric
+        my_unit.output_mean.update(1.0)
+
+        dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_predict_steps=2,
+            )
+
+            predict(my_unit, dataloader, callbacks=[dcp_cb])
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_predict_step_2",
+                "epoch_0_predict_step_4",
+            ]
+
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
+            self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
+
+            storage_reader = FsspecReader(ckpt_path)
+            metadata = storage_reader.read_metadata()
+            self.assertCountEqual(
+                # Get base keys after the app_state wrapper
+                {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                [
+                    "predict_progress",
+                    "predict_dataloader",
+                    "output_mean",
+                ],
+            )
+
+    def test_save_evaluate(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyEvalUnit(input_dim=input_dim)
+
+        dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_eval_steps=2,
+            )
+
+            evaluate(my_unit, dataloader, callbacks=[dcp_cb])
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_eval_step_2",
+                "epoch_0_eval_step_4",
+            ]
+
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
+            self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
+
+            storage_reader = FsspecReader(ckpt_path)
+            metadata = storage_reader.read_metadata()
+            self.assertCountEqual(
+                # Get base keys after the app_state wrapper
+                {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                [
+                    "eval_progress",
+                    "eval_dataloader",
+                ],
+            )
+
+    def test_save_fit_eval_every_n_epochs(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+
+        my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+        my_unit.output_mean = DummyMeanMetric()
+
+        train_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        eval_dataloader = generate_dummy_stateful_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_train_steps=2,
+                save_every_n_eval_steps=2,
+            )
+
+            fit(
+                my_unit,
+                max_epochs=1,
+                train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                evaluate_every_n_epochs=1,
+                callbacks=[dcp_cb],
+            )
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts = [
+                "epoch_0_train_step_2_eval_step_0",
+                "epoch_0_train_step_4_eval_step_0",
+                "epoch_1_train_step_5_eval_step_2",
+                "epoch_1_train_step_5_eval_step_4",
+            ]
+            self.assertCountEqual(generated_ckpts, expected_ckpts)
+
+            expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2
+            for ckpt_path, dl_key in zip(expected_ckpts, expected_dataloader):
+                storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+                metadata = storage_reader.read_metadata()
+                self.assertCountEqual(
+                    # Get base keys after the app_state wrapper
+                    {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                    [
+                        "module",  # Both train and eval checkpoints save full app_state in fit
+                        "optimizer",
+                        "lr_scheduler",
+                        "train_progress",
+                        "eval_progress",
+                        "predict_progress",  # included because of AutoUnit
+                        dl_key,
+                        "output_mean",
+                    ],
+                )
+
+    def test_save_fit_eval_every_n_steps(self) -> None:
+        input_dim = 2
+
+        my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+        my_unit.output_mean = DummyMeanMetric()
+
+        train_dataloader = generate_dummy_stateful_dataloader(10, input_dim, 2)
+        eval_dataloader = generate_dummy_stateful_dataloader(8, input_dim, 2)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            dcp_cb = DistributedCheckpointSaver(
+                temp_dir,
+                knob_options=KnobOptions(1),
+                save_every_n_train_steps=2,
+                save_every_n_eval_steps=2,
+            )
+
+            fit(
+                my_unit,
+                max_epochs=1,
+                train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                evaluate_every_n_steps=2,
+                evaluate_every_n_epochs=None,
+                callbacks=[dcp_cb],
+            )
+
+            generated_ckpts = os.listdir(temp_dir)
+            expected_ckpts_to_dl_mapping: Dict[str, Tuple[str, ...]] = {
+                # First train 2 steps
+                "epoch_0_train_step_2_eval_step_0": ("train_dataloader",),
+                # Then do a whole evaluation (4 steps)
+                "epoch_0_train_step_2_eval_step_2": (
+                    "train_dataloader",
+                    "eval_dataloader",
+                ),
+                "epoch_0_train_step_2_eval_step_4": (
+                    "train_dataloader",
+                    "eval_dataloader",
+                ),
+                # Then train other two steps
+                "epoch_0_train_step_4_eval_step_4": ("train_dataloader",),
+                # Finally do a whole evaluation (4 steps)
+                "epoch_0_train_step_4_eval_step_6": (
+                    "train_dataloader",
+                    "eval_dataloader",
+                ),
+                "epoch_0_train_step_4_eval_step_8": (
+                    "train_dataloader",
+                    "eval_dataloader",
+                ),
+                # Last checkpoint (on_train_end)
+                "epoch_1_train_step_5_eval_step_8": (),
+            }
+            self.assertCountEqual(
+                generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()]
+            )
+
+            for ckpt_path, expected_dls in expected_ckpts_to_dl_mapping.items():
+                storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+                metadata = storage_reader.read_metadata()
+                self.assertCountEqual(
+                    # Get base keys after the app_state wrapper
+                    {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
+                    [
+                        "module",  # Both train and eval checkpoints save full app_state in fit
+                        "optimizer",
+                        "lr_scheduler",
+                        "train_progress",
+                        "eval_progress",
+                        "predict_progress",  # included because of AutoUnit
+                        "output_mean",
+                        *expected_dls,
+                    ],
+                )
+

 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None:
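
The tests above verify which app_state entries land in each checkpoint by reading the DCP metadata directly. A standalone sketch of that verification pattern, reusing only helpers imported in the test file (the "app_state wrapper" naming follows the test comments):

import tempfile

from pyre_extensions import none_throws
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader
from torchtnt.framework._test_utils import (
    DummyEvalUnit,
    generate_dummy_stateful_dataloader,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.evaluate import evaluate
from torchtnt.utils.checkpoint import get_latest_checkpoint_path

with tempfile.TemporaryDirectory() as temp_dir:
    dcp_cb = DistributedCheckpointSaver(
        temp_dir,
        knob_options=KnobOptions(1),
        save_every_n_eval_steps=2,
    )
    evaluate(
        DummyEvalUnit(input_dim=2),
        generate_dummy_stateful_dataloader(10, 2, 2),
        callbacks=[dcp_cb],
    )

    # Resolve the most recent checkpoint from the directory names
    # (e.g. "epoch_0_eval_step_4").
    ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))

    # Each metadata key is prefixed by the app_state wrapper; the second
    # path component names the stateful object, e.g. "eval_progress" or
    # "eval_dataloader".
    metadata = FsspecReader(ckpt_path).read_metadata()
    print({key.split(".")[1] for key in metadata.state_dict_metadata.keys()})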

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 10 additions & 2 deletions
@@ -65,8 +65,10 @@ class DistributedCheckpointSaver(BaseCheckpointer):
     Args:
         dirpath: Parent directory to save snapshots to.
         save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
-        save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
+        save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
         save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
+        save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints during training. Use this if wanting to save checkpoints during evaluate.
+        save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints. Use this if wanting to save checkpoints when using the predict entrypoint.
         keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
         best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
         process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
@@ -93,7 +95,9 @@ def __init__(
         *,
         save_every_n_train_steps: Optional[int] = None,
         save_every_n_epochs: Optional[int] = None,
+        save_every_n_eval_steps: Optional[int] = None,
         save_every_n_eval_epochs: Optional[int] = None,
+        save_every_n_predict_steps: Optional[int] = None,
         keep_last_n_checkpoints: Optional[int] = None,
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
@@ -104,7 +108,9 @@ def __init__(
             dirpath=dirpath,
             save_every_n_train_steps=save_every_n_train_steps,
             save_every_n_epochs=save_every_n_epochs,
+            save_every_n_eval_steps=save_every_n_eval_steps,
             save_every_n_eval_epochs=save_every_n_eval_epochs,
+            save_every_n_predict_steps=save_every_n_predict_steps,
             keep_last_n_checkpoints=keep_last_n_checkpoints,
             best_checkpoint_config=best_checkpoint_config,
             process_group=process_group,
@@ -129,10 +135,12 @@ def _checkpoint_impl(
             "on_train_epoch_end",
             "on_train_end",
             "on_eval_epoch_end",
+            "on_eval_step_end",
+            "on_predict_step_end",
         ]:
             raise RuntimeError(f"Unexpected hook encountered '{hook}'")

-        intra_epoch = hook == "on_train_step_end"
+        intra_epoch = "step_end" in hook
         curr_snapshot_wait = hook == "on_train_end"

         if planner is None: