@@ -326,7 +326,7 @@ def __init__(
326
326
dirpath : str ,
327
327
best_checkpoint_config : Optional [BestCheckpointConfig ] = None ,
328
328
keep_last_n_checkpoints : Optional [int ] = None ,
329
- metadata_fname : Optional [str ] = None ,
329
+ metadata_fnames : Optional [List [ str ] ] = None ,
330
330
process_group : Optional [dist .ProcessGroup ] = None ,
331
331
) -> None :
332
332
"""
@@ -338,7 +338,8 @@ def __init__(
338
338
dirpath: The base directory path that checkpoints are saved in. This is synced from rank 0 to every other rank upon initialization.
339
339
best_checkpoint_config: Optional configuration for the best checkpoint.
340
340
keep_last_n_checkpoints: Optional number of checkpoints to keep.
341
- metadata_fname: Optional name of the metadata file.
341
+ metadata_fnames: Optional names of the metadata files. These are used to verify checkpoint integrity. If more than one is provided, a
342
+ checkpoint is considered valid if at least one of them exists.
342
343
process_group: Optional process group to use for distributed training. gloo process groups are known
343
344
to perform better.
344
345
"""
@@ -348,9 +349,13 @@ def __init__(
348
349
349
350
self ._best_checkpoint_config = best_checkpoint_config
350
351
self ._keep_last_n_checkpoints = keep_last_n_checkpoints
351
- self ._metadata_fname = metadata_fname
352
352
self ._pg_wrapper = PGWrapper (process_group )
353
353
354
+ if metadata_fnames is None :
355
+ self ._metadata_fnames : List [str ] = []
356
+ else :
357
+ self ._metadata_fnames = metadata_fnames
358
+
354
359
self ._ckpt_paths : List [CheckpointPath ] = []
355
360
if not self ._keep_last_n_checkpoints :
356
361
return
@@ -361,7 +366,7 @@ def __init__(
361
366
)
362
367
self ._ckpt_paths = get_checkpoint_dirpaths (
363
368
dirpath = dirpath ,
364
- metadata_fname = self ._metadata_fname ,
369
+ metadata_fname = self ._metadata_fnames ,
365
370
metric_name = metric_name ,
366
371
process_group = process_group ,
367
372
)
@@ -512,12 +517,13 @@ def does_checkpoint_exist(
512
517
If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
513
518
result is broadcasted to all ranks.
514
519
"""
515
- metadata_fname = self ._metadata_fname
516
- if not metadata_fname :
520
+ if not self ._metadata_fnames :
517
521
return False
518
522
519
523
fs , _ = url_to_fs (self .dirpath )
520
- return _metadata_exists (fs , ckpt .path , metadata_fname )
524
+ return any (
525
+ _metadata_exists (fs , ckpt .path , fname ) for fname in self ._metadata_fnames
526
+ )
521
527
522
528
@staticmethod
523
529
@rank_zero_read_and_broadcast
@@ -542,7 +548,7 @@ def remove_checkpoint(self) -> None:
542
548
@rank_zero_read_and_broadcast
543
549
def get_latest_checkpoint_path (
544
550
dirpath : str ,
545
- metadata_fname : Optional [str ] = None ,
551
+ metadata_fname : Optional [Union [ str , List [ str ]] ] = None ,
546
552
process_group : Optional [dist .ProcessGroup ] = None ,
547
553
) -> Optional [str ]:
548
554
"""
@@ -551,6 +557,7 @@ def get_latest_checkpoint_path(
551
557
Args:
552
558
dirpath: parent directory where checkpoints are saved.
553
559
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
560
+ If a list is provided, it will check that at least one of the files is present.
554
561
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
555
562
556
563
Raises:
@@ -578,7 +585,7 @@ def get_best_checkpoint_path(
578
585
dirpath : str ,
579
586
metric_name : str ,
580
587
mode : Literal ["min" , "max" ],
581
- metadata_fname : Optional [str ] = None ,
588
+ metadata_fname : Optional [Union [ str , List [ str ]] ] = None ,
582
589
process_group : Optional [dist .ProcessGroup ] = None ,
583
590
) -> Optional [str ]:
584
591
"""
@@ -592,6 +599,7 @@ def get_best_checkpoint_path(
592
599
metric_name: Name of the metric to use to find the best checkpoint.
593
600
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
594
601
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
602
+ If a list is provided, it will check that at least one of the files is present.
595
603
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
596
604
597
605
Note:
@@ -614,7 +622,7 @@ def get_best_checkpoint_path(
614
622
@rank_zero_read_and_broadcast
615
623
def get_checkpoint_dirpaths (
616
624
dirpath : str ,
617
- metadata_fname : Optional [str ] = None ,
625
+ metadata_fname : Optional [Union [ str , List [ str ]] ] = None ,
618
626
metric_name : Optional [str ] = None ,
619
627
process_group : Optional [dist .ProcessGroup ] = None ,
620
628
) -> List [CheckpointPath ]:
@@ -629,6 +637,7 @@ def get_checkpoint_dirpaths(
629
637
Args:
630
638
dirpath: parent directory where checkpoints are saved.
631
639
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
640
+ If a list is provided, it will check that at least one of the files is present.
632
641
metric_name: fetches all the checkpoint directories containing the metric name only.
633
642
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
634
643
@@ -642,7 +651,7 @@ def get_checkpoint_dirpaths(
642
651
643
652
def _retrieve_checkpoint_dirpaths (
644
653
dirpath : str ,
645
- metadata_fname : Optional [str ],
654
+ metadata_fname : Optional [Union [ str , List [ str ]] ],
646
655
metric_name : Optional [str ] = None ,
647
656
) -> List [CheckpointPath ]:
648
657
"""
@@ -651,6 +660,7 @@ def _retrieve_checkpoint_dirpaths(
651
660
Args:
652
661
dirpath: parent directory where checkpoints are saved.
653
662
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
663
+ If a list is provided, it will check that at least one of the files is present.
654
664
metric_name: Name of the metric that must exist in checkpoint name.
655
665
"""
656
666
@@ -687,16 +697,21 @@ def _retrieve_checkpoint_dirpaths(
687
697
return candidate_checkpoints
688
698
689
699
# Iterate through all files and directories in the specified directory
690
- # and check if metedata is present or not
700
+ # and check if metadata is present or not
701
+ metadata_fnames = (
702
+ [metadata_fname ] if isinstance (metadata_fname , str ) else metadata_fname
703
+ )
691
704
valid_ckpt_dirpaths : List [CheckpointPath ] = []
692
705
for candidate in candidate_checkpoints :
693
- if not _metadata_exists ( fs , candidate . path , metadata_fname ):
694
- logger . warning (
695
- f"Snapshot metadata is missing from { candidate } ! Skipping this path"
696
- )
706
+ if any (
707
+ _metadata_exists ( fs , candidate . path , fname ) for fname in metadata_fnames
708
+ ):
709
+ valid_ckpt_dirpaths . append ( candidate )
697
710
continue
698
711
699
- valid_ckpt_dirpaths .append (candidate )
712
+ logger .warning (
713
+ f"Snapshot metadata ({ metadata_fnames } ) missing from { candidate } ! Skipping this path"
714
+ )
700
715
701
716
return valid_ckpt_dirpaths
702
717
0 commit comments