@@ -346,7 +346,7 @@ def test_save_on_train_end(self) -> None:
             self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
 
             with self.assertLogs(level="WARNING") as log:
-                checkpoint_cb.metadata_fname = ".metadata"
+                checkpoint_cb._checkpoint_manager._metadata_fname = ".metadata"
                 # create metadata file
                 with open(os.path.join(temp_dir, expected_path, ".metadata"), "w"):
                     pass
@@ -454,99 +454,6 @@ def _test_process_group_plumbing_nccl(_: MagicMock) -> None:
         # check that a new process group was created
         tc.assertNotEqual(checkpoint_cb._process_group, dist.group.WORLD)
 
-    @patch(
-        "torchtnt.framework.callbacks.base_checkpointer.get_checkpoint_dirpaths",
-        return_value=["epoch_1_step_10", "epoch_2_step_20"],
-    )
-    def test_ckpt_dirpaths(self, _: MagicMock) -> None:
-        """
-        Tests that ckpt_dirpaths is populated correctly
-        based on if ``keep_last_n_checkpoints`` is set.
-        """
-        bc = BaseCheckpointSaver("foo")
-        self.assertEqual(bc._ckpt_dirpaths, [])
-
-        bc = BaseCheckpointSaver("foo", keep_last_n_checkpoints=10)
-        self.assertEqual(bc._ckpt_dirpaths, ["epoch_1_step_10", "epoch_2_step_20"])
-
-    def test_should_remove_checkpoint(self) -> None:
-        """
-        Tests the helper function that checks if checkpoint should be removed or not
-        """
-        bc = BaseCheckpointSaver("temp")
-
-        # keep_last_n_checkpoints is toggled off
-        self.assertFalse(bc._should_remove_checkpoint())
-
-        # not enough checkpoints are saved yet to be removed
-        bc._keep_last_n_checkpoints = 2
-        bc._ckpt_dirpaths = ["bar"]
-        self.assertFalse(bc._should_remove_checkpoint())
-
-        # enough checkpoints are there to remove
-        bc._keep_last_n_checkpoints = 2
-        bc._ckpt_dirpaths = ["foo", "bar"]
-        self.assertTrue(bc._should_remove_checkpoint())
-
-    @patch("torchtnt.framework.callbacks.base_checkpointer._delete_checkpoint")
-    def test_cleanup_surplus(self, mock_delete_checkpoint: MagicMock) -> None:
-        """
-        Tests surplus of checkpoints being cleaned up
-        """
-        state = get_dummy_train_state()
-        unit = DummyTrainUnit(input_dim=2)
-        warning_messages = []
-        with tempfile.TemporaryDirectory() as temp_dir:
-            bc = BaseCheckpointSaver(temp_dir, keep_last_n_checkpoints=1)
-            bc._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            expected_warning_msg = " ".join(
-                [
-                    f"3 checkpoints found in {temp_dir}.",
-                    f"Deleting {2} oldest",
-                    "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
-                ]
-            )
-
-            with patch(
-                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
-                warning_messages.append,
-            ):
-                bc.on_train_start(state, unit)
-            self.assertEqual(bc._ckpt_dirpaths, ["baz"])
-            self.assertEqual(warning_messages[0], expected_warning_msg)
-
-            bc = BaseCheckpointSaver(temp_dir)
-            bc._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            bc.on_train_start(state, unit)
-            self.assertEqual(bc._ckpt_dirpaths, ["foo", "bar", "baz"])
-
-    def test_keep_last_n_checkpoints(self) -> None:
-        """
-        Tests removing checkpoint directories
-        """
-        unit = DummyTrainUnit(input_dim=2)
-        state = get_dummy_train_state()
-        with tempfile.TemporaryDirectory() as temp_dir:
-            bc = BaseCheckpointSaver(
-                temp_dir,
-                save_every_n_train_steps=1,
-                keep_last_n_checkpoints=2,
-            )
-
-            # take 10 steps
-            for _ in range(10):
-                unit.train_progress.increment_step()
-                bc.on_train_step_end(state, unit)
-                # TODO remove time.sleep to avoid potential flaky test
-                time.sleep(0.1)  # sleep to ensure enough time to checkpoint
-
-            dirs = os.listdir(temp_dir)
-            self.assertEqual(len(dirs), 2)
-            self.assertIn("epoch_0_step_9", dirs)
-            self.assertIn("epoch_0_step_10", dirs)
-
     def test_keep_last_n_checkpoints_e2e(self) -> None:
         """
         Tests removing checkpoint directories e2e
@@ -581,66 +488,6 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
                 os.listdir(temp_dir),
             )
 
-    def test_does_checkpoint_exist(self) -> None:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            with open(os.path.join(temp_dir, ".metadata"), "w"):
-                pass
-            bc = BaseCheckpointSaver(
-                temp_dir,
-                save_every_n_train_steps=2,
-                keep_last_n_checkpoints=1,
-            )
-            # checkpointer doesn't have a metadata_fname
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertFalse(does_checkpoint_exist)
-
-            # checkpointer has metadata_fname and the file exists
-            bc.metadata_fname = ".metadata"
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertTrue(does_checkpoint_exist)
-
-            # checkpointer has metadata_fname but the file doesn't exist
-            os.remove(os.path.join(temp_dir, ".metadata"))
-            does_checkpoint_exist = bc._does_checkpoint_exist(temp_dir)
-            self.assertFalse(does_checkpoint_exist)
-
-    def test_should_save_checkpoint(self) -> None:
-        """
-        Tests basic functionality of should_save_checkpoint
-        """
-        bc = BaseCheckpointSaver("foo")
-
-        # test default behavior
-        self.assertTrue(bc._should_save_checkpoint())
-
-        bc._ckpt_dirpaths = ["foo/epoch_0_step_1"]
-        self.assertTrue(bc._should_save_checkpoint())
-        bc._keep_last_n_checkpoints = 1
-        self.assertTrue(bc._should_save_checkpoint())
-
-        bc._ckpt_dirpaths = ["foo/epoch_0_step_1_val_loss=0.01"]
-        bc._best_checkpoint_config = BestCheckpointConfig(
-            monitored_metric="val_loss",
-            mode="min",
-        )
-        bc._keep_last_n_checkpoints = None
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-        bc._keep_last_n_checkpoints = 1
-        self.assertFalse(bc._should_save_checkpoint(0.02))
-        self.assertTrue(bc._should_save_checkpoint(0.001))
-        bc._keep_last_n_checkpoints = 2
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-
-        bc._best_checkpoint_config = BestCheckpointConfig(
-            monitored_metric="val_loss",
-            mode="max",
-        )
-        bc._keep_last_n_checkpoints = 1
-        self.assertTrue(bc._should_save_checkpoint(0.02))
-        self.assertFalse(bc._should_save_checkpoint(0.001))
-        bc._keep_last_n_checkpoints = 2
-        self.assertTrue(bc._should_save_checkpoint(0.001))
-
     def test_best_checkpoint_attr_missing(self) -> None:
         bcs = BaseCheckpointSaver(
             "foo",
@@ -686,21 +533,21 @@ def test_best_checkpoint_no_top_k(self) -> None:
             my_train_unit.train_loss = None
             bcs.on_train_epoch_end(state, my_train_unit)
             # none metric-value will not be updated in checkpoint dirpaths
-            self.assertEqual(bcs._ckpt_dirpaths, [])
+            self.assertEqual(bcs._checkpoint_manager._ckpt_paths, [])
             self.assertEqual(os.listdir(temp_dir), ["epoch_0_step_0"])
 
             my_train_unit.train_loss = 0.01
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 (
                     [
                         os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
718
565
my_train_unit .train_progress .increment_epoch ()
719
566
bcs .on_train_epoch_end (state , my_train_unit )
720
567
self .assertEqual (
721
- bcs ._ckpt_dirpaths ,
568
+ [ str ( x ) for x in bcs ._checkpoint_manager . _ckpt_paths ] ,
722
569
(
723
570
[
724
571
os .path .join (temp_dir , "epoch_1_step_0_train_loss=0.02" ),
@@ -752,15 +599,15 @@ def test_best_checkpoint_top_k(self) -> None:
 
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
                 ],
@@ -770,7 +617,7 @@ def test_best_checkpoint_top_k(self) -> None:
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),
                 ],
@@ -793,15 +640,15 @@ def test_best_checkpoint_top_k(self) -> None:
 
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01")],
             )
 
             my_train_unit.train_loss = 0.02
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_1_step_0_train_loss=0.02"),
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
@@ -812,7 +659,7 @@ def test_best_checkpoint_top_k(self) -> None:
             my_train_unit.train_progress.increment_epoch()
             bcs.on_train_epoch_end(state, my_train_unit)
             self.assertEqual(
-                bcs._ckpt_dirpaths,
+                [str(x) for x in bcs._checkpoint_manager._ckpt_paths],
                 [
                     os.path.join(temp_dir, "epoch_0_step_0_train_loss=0.01"),
                     os.path.join(temp_dir, "epoch_2_step_0_train_loss=0.001"),