Skip to content

Commit d6a7305

Browse files
diego-urgell authored and facebook-github-bot committed
Use CheckpointManager in BaseCheckpointer (#810)
Summary: Pull Request resolved: #810 Reviewed By: JKSenthil Differential Revision: D56556426 fbshipit-source-id: 20fd8f6498320d847b1bcc041783007f0c6a9275
1 parent 27e16cd commit d6a7305

File tree

5 files changed

+73
-433
lines changed

5 files changed

+73
-433
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 11 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def test_save_on_train_end(self) -> None:
346346
self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
347347

348348
with self.assertLogs(level="WARNING") as log:
349-
checkpoint_cb.metadata_fname = ".metadata"
349+
checkpoint_cb._checkpoint_manager._metadata_fname = ".metadata"
350350
# create metadata file
351351
with open(os.path.join(temp_dir, expected_path, ".metadata"), "w"):
352352
pass
@@ -454,99 +454,6 @@ def _test_process_group_plumbing_nccl(_: MagicMock) -> None:
454454
# check that a new process group was created
455455
tc.assertNotEqual(checkpoint_cb._process_group, dist.group.WORLD)
456456

457-
@patch(
458-
"torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths",
459-
return_value=["epoch_1_step_10", "epoch_2_step_20"],
460-
)
461-
def test_ckpt_dirpaths(self, _: MagicMock) -> None:
462-
"""
463-
Tests that ckpt_dirpaths is populated correctly
464-
based on if ``keep_last_n_checkpoints`` is set.
465-
"""
466-
bc = BaseCheckpointSaver("foo")
467-
self.assertEqual(bc._ckpt_dirpaths, [])
468-
469-
bc = BaseCheckpointSaver("foo", keep_last_n_checkpoints=10)
470-
self.assertEqual(bc._ckpt_dirpaths, ["epoch_1_step_10", "epoch_2_step_20"])
471-
472-
def test_should_remove_checkpoint(self) -> None:
473-
"""
474-
Tests the helper function that checks if checkpoint should be removed or not
475-
"""
476-
bc = BaseCheckpointSaver("temp")
477-
478-
# keep_last_n_checkpoints is toggled off
479-
self.assertFalse(bc._should_remove_checkpoint())
480-
481-
# not enough checkpoints are saved yet to be removed
482-
bc._keep_last_n_checkpoints = 2
483-
bc._ckpt_dirpaths = ["bar"]
484-
self.assertFalse(bc._should_remove_checkpoint())
485-
486-
# enough checkpoints are there to remove
487-
bc._keep_last_n_checkpoints = 2
488-
bc._ckpt_dirpaths = ["foo", "bar"]
489-
self.assertTrue(bc._should_remove_checkpoint())
490-
491-
@patch("torchtnt.framework.callbacks.base_checkpointer._delete_checkpoint")
492-
def test_cleanup_surplus(self, mock_delete_checkpoint: MagicMock) -> None:
493-
"""
494-
Tests surplus of checkpoints being cleaned up
495-
"""
496-
state = get_dummy_train_state()
497-
unit = DummyTrainUnit(input_dim=2)
498-
warning_messages = []
499-
with tempfile.TemporaryDirectory() as temp_dir:
500-
bc = BaseCheckpointSaver(temp_dir, keep_last_n_checkpoints=1)
501-
bc._ckpt_dirpaths = ["foo", "bar", "baz"]
502-
503-
expected_warning_msg = " ".join(
504-
[
505-
f"3 checkpoints found in {temp_dir}.",
506-
f"Deleting {2} oldest",
507-
"checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
508-
]
509-
)
510-
511-
with patch(
512-
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
513-
warning_messages.append,
514-
):
515-
bc.on_train_start(state, unit)
516-
self.assertEqual(bc._ckpt_dirpaths, ["baz"])
517-
self.assertEqual(warning_messages[0], expected_warning_msg)
518-
519-
bc = BaseCheckpointSaver(temp_dir)
520-
bc._ckpt_dirpaths = ["foo", "bar", "baz"]
521-
522-
bc.on_train_start(state, unit)
523-
self.assertEqual(bc._ckpt_dirpaths, ["foo", "bar", "baz"])
524-
525-
def test_keep_last_n_checkpoints(self) -> None:
526-
"""
527-
Tests removing checkpoint directories
528-
"""
529-
unit = DummyTrainUnit(input_dim=2)
530-
state = get_dummy_train_state()
531-
with tempfile.TemporaryDirectory() as temp_dir:
532-
bc = BaseCheckpointSaver(
533-
temp_dir,
534-
save_every_n_train_steps=1,
535-
keep_last_n_checkpoints=2,
536-
)
537-
538-
# take 10 steps
539-
for _ in range(10):
540-
unit.train_progress.increment_step()
541-
bc.on_train_step_end(state, unit)
542-
# TODO remove time.sleep to avoid potential flaky test
543-
time.sleep(0.1) # sleep to ensure enough time to checkpoint
544-
545-
dirs = os.listdir(temp_dir)
546-
self.assertEqual(len(dirs), 2)
547-
self.assertIn("epoch_0_step_9", dirs)
548-
self.assertIn("epoch_0_step_10", dirs)
549-
550457
def test_keep_last_n_checkpoints_e2e(self) -> None:
551458
"""
552459
Tests removing checkpoint directories e2e
@@ -581,66 +488,6 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
581488
os.listdir(temp_dir),
582489
)
583490

584-
def test_does_checkpoint_exist(self) -> None:
585-
with tempfile.TemporaryDirectory() as temp_dir:
586-
with open(os.path.join(temp_dir, ".metadata"), "w"):
587-
pass
588-
bc = BaseCheckpointSaver(
589-
temp_dir,
590-
save_every_n_train_steps=2,
591-
keep_last_n_checkpoints=1,
592-
)
593-
# checkpointer doesn't have a metadata_fname
594-
does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
595-
self.assertFalse(does_checkpoint_exist)
596-
597-
# checkpointer has metadata_fname and the file exists
598-
bc.metadata_fname = ".metadata"
599-
does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
600-
self.assertTrue(does_checkpoint_exist)
601-
602-
# checkpointer has metadata_fname but the file doesn't exist
603-
os.remove(os.path.join(temp_dir, ".metadata"))
604-
does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
605-
self.assertFalse(does_checkpoint_exist)
606-
607-
def test_should_save_checkpoint(self) -> None:
608-
"""
609-
Tests basic functionality of should_save_checkpoint
610-
"""
611-
bc = BaseCheckpointSaver("foo")
612-
613-
# test default behavior
614-
self.assertTrue(bc._should_save_checkpoint())
615-
616-
bc._ckpt_dirpaths = ["foo/epoch_0_step_1"]
617-
self.assertTrue(bc._should_save_checkpoint())
618-
bc._keep_last_n_checkpoints = 1
619-
self.assertTrue(bc._should_save_checkpoint())
620-
621-
bc._ckpt_dirpaths = ["foo/epoch_0_step_1_val_loss=0.01"]
622-
bc._best_checkpoint_config = BestCheckpointConfig(
623-
monitored_metric="val_loss",
624-
mode="min",
625-
)
626-
bc._keep_last_n_checkpoints = None
627-
self.assertTrue(bc._should_save_checkpoint(0.02))
628-
bc._keep_last_n_checkpoints = 1
629-
self.assertFalse(bc._should_save_checkpoint(0.02))
630-
self.assertTrue(bc._should_save_checkpoint(0.001))
631-
bc._keep_last_n_checkpoints = 2
632-
self.assertTrue(bc._should_save_checkpoint(0.02))
633-
634-
bc._best_checkpoint_config = BestCheckpointConfig(
635-
monitored_metric="val_loss",
636-
mode="max",
637-
)
638-
bc._keep_last_n_checkpoints = 1
639-
self.assertTrue(bc._should_save_checkpoint(0.02))
640-
self.assertFalse(bc._should_save_checkpoint(0.001))
641-
bc._keep_last_n_checkpoints = 2
642-
self.assertTrue(bc._should_save_checkpoint(0.001))
643-
644491
def test_best_checkpoint_attr_missing(self) -> None:
645492
bcs = BaseCheckpointSaver(
646493
"foo",
@@ -686,21 +533,21 @@ def test_best_checkpoint_no_top_k(self) -> None:
686533
my_train_unit.train_loss = None
687534
bcs.on_train_epoch_end(state, my_train_unit)
688535
# none metric-value will not be updated in checkpoint dirpaths
689-
self.assertEqual(bcs._ckpt_dirpaths, [])
536+
self.assertEqual(bcs._checkpoint_manager._ckpt_paths, [])
690537
self.assertEqual(os.listdir(temp_dir), ["epoch_0_step_0"])
691538

692539
my_train_unit.train_loss = 0.01
693540
bcs.on_train_epoch_end(state, my_train_unit)
694541
self.assertEqual(
695-
bcs._ckpt_dirpaths,
542+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
696543
[os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
697544
)
698545

699546
my_train_unit.train_loss = 0.02
700547
my_train_unit.train_progress.increment_epoch()
701548
bcs.on_train_epoch_end(state, my_train_unit)
702549
self.assertEqual(
703-
bcs._ckpt_dirpaths,
550+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
704551
(
705552
[
706553
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
@@ -718,7 +565,7 @@ def test_best_checkpoint_no_top_k(self) -> None:
718565
my_train_unit.train_progress.increment_epoch()
719566
bcs.on_train_epoch_end(state, my_train_unit)
720567
self.assertEqual(
721-
bcs._ckpt_dirpaths,
568+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
722569
(
723570
[
724571
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
@@ -752,15 +599,15 @@ def test_best_checkpoint_top_k(self) -> None:
752599

753600
bcs.on_train_epoch_end(state, my_train_unit)
754601
self.assertEqual(
755-
bcs._ckpt_dirpaths,
602+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
756603
[os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
757604
)
758605

759606
my_train_unit.train_loss = 0.02
760607
my_train_unit.train_progress.increment_epoch()
761608
bcs.on_train_epoch_end(state, my_train_unit)
762609
self.assertEqual(
763-
bcs._ckpt_dirpaths,
610+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
764611
[
765612
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
766613
],
@@ -770,7 +617,7 @@ def test_best_checkpoint_top_k(self) -> None:
770617
my_train_unit.train_progress.increment_epoch()
771618
bcs.on_train_epoch_end(state, my_train_unit)
772619
self.assertEqual(
773-
bcs._ckpt_dirpaths,
620+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
774621
[
775622
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),
776623
],
@@ -793,15 +640,15 @@ def test_best_checkpoint_top_k(self) -> None:
793640

794641
bcs.on_train_epoch_end(state, my_train_unit)
795642
self.assertEqual(
796-
bcs._ckpt_dirpaths,
643+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
797644
[os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
798645
)
799646

800647
my_train_unit.train_loss = 0.02
801648
my_train_unit.train_progress.increment_epoch()
802649
bcs.on_train_epoch_end(state, my_train_unit)
803650
self.assertEqual(
804-
bcs._ckpt_dirpaths,
651+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
805652
[
806653
os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
807654
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
@@ -812,7 +659,7 @@ def test_best_checkpoint_top_k(self) -> None:
812659
my_train_unit.train_progress.increment_epoch()
813660
bcs.on_train_epoch_end(state, my_train_unit)
814661
self.assertEqual(
815-
bcs._ckpt_dirpaths,
662+
[str(x) for x in bcs._checkpoint_manager._ckpt_paths],
816663
[
817664
os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
818665
os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),

tests/utils/test_checkpoint.py

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@
2424
from torchtnt.utils import get_global_rank, init_from_env
2525

2626
from torchtnt.utils.checkpoint import (
27-
_delete_checkpoint,
2827
_metadata_exists,
2928
_retrieve_checkpoint_dirpaths,
30-
_sort_by_metric_value,
31-
_sort_by_recency,
3229
BestCheckpointConfig,
3330
CheckpointManager,
3431
CheckpointPath,
@@ -195,6 +192,19 @@ def test_pickling(self) -> None:
195192
unpickled = pickle.loads(pickled)
196193
self.assertEqual(unpickled, ckpt)
197194

195+
def test_checkpoint_ordering(self) -> None:
196+
"""
197+
Tests the sort utilities
198+
"""
199+
paths = [
200+
"foo/epoch_1_step_20",
201+
"foo/epoch_4_step_130",
202+
"foo/epoch_0_step_10_val_loss=10.0",
203+
]
204+
ckpts = [CheckpointPath.from_str(x) for x in paths]
205+
sorted_paths = [str(x) for x in sorted(ckpts)]
206+
self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]])
207+
198208

199209
class CheckpointManagerTest(unittest.TestCase):
200210
def test_create_checkpoint_manager(self) -> None:
@@ -884,50 +894,6 @@ def test_get_checkpoint_dirpaths(self) -> None:
884894
[],
885895
)
886896

887-
def test_checkpoint_sorting_utils(self) -> None:
888-
"""
889-
Tests the sort utilities
890-
"""
891-
paths = [
892-
"foo/epoch_1_step_20",
893-
"foo/epoch_4_step_130",
894-
"foo/epoch_0_step_10_val_loss=10.0",
895-
]
896-
ckpts = [CheckpointPath.from_str(x) for x in paths]
897-
sorted_paths = [str(x) for x in _sort_by_recency(ckpts)]
898-
self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]])
899-
900-
paths = [
901-
"foo/epoch_1_step_20_val_loss=0.09",
902-
"foo/epoch_4_step_130_val_loss=29.0",
903-
"foo/epoch_0_step_10_val_loss=10.0",
904-
]
905-
ckpts = [CheckpointPath.from_str(x) for x in paths]
906-
907-
sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="min")]
908-
self.assertEqual(sorted_paths, [paths[1], paths[2], paths[0]])
909-
910-
sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="max")]
911-
self.assertEqual(sorted_paths, [paths[0], paths[2], paths[1]])
912-
913-
def test_delete_checkpoint(self) -> None:
914-
"""
915-
Tests removing checkpoint directories
916-
"""
917-
app_state = {"module": nn.Linear(2, 2)}
918-
with tempfile.TemporaryDirectory() as temp_dir:
919-
dirpath = os.path.join(temp_dir, "checkpoint")
920-
Snapshot.take(dirpath, app_state=app_state)
921-
self.assertTrue(os.path.exists(dirpath))
922-
# check that error is thrown if .snapshot_metadata is not found in the directory when deleting
923-
os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
924-
with self.assertRaisesRegex(
925-
RuntimeError, f"{temp_dir} does not contain .snapshot_metadata"
926-
):
927-
_delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME)
928-
_delete_checkpoint(dirpath)
929-
self.assertFalse(os.path.exists(dirpath))
930-
931897
def test_metadata_exists(self) -> None:
932898
app_state = {"module": nn.Linear(2, 2)}
933899
with tempfile.TemporaryDirectory() as temp_dir:

0 commit comments

Comments
 (0)