@@ -539,21 +539,18 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
539
539
# No metric tracked, most recent goes last
540
540
self ._ckpt_paths .append (ckpt )
541
541
542
- @rank_zero_read_and_broadcast
543
542
def does_checkpoint_exist (
544
- self , ckpt : CheckpointPath , process_group : Optional [dist .ProcessGroup ] = None
543
+ self ,
544
+ ckpt : CheckpointPath ,
545
+ process_group : Optional [dist .ProcessGroup ] = None ,
545
546
) -> bool :
546
547
"""
547
548
Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
548
549
If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but
549
550
result is broadcast to all ranks.
550
551
"""
551
- if not self ._metadata_fnames :
552
- return False
553
-
554
- fs , _ = url_to_fs (self .dirpath )
555
- return any (
556
- _metadata_exists (fs , ckpt .path , fname ) for fname in self ._metadata_fnames
552
+ return does_checkpoint_exist (
553
+ ckpt .path , self ._metadata_fnames , process_group = process_group
557
554
)
558
555
559
556
@staticmethod
@@ -596,6 +593,33 @@ def remove_checkpoint(self) -> None:
596
593
)
597
594
598
595
596
+ @rank_zero_read_and_broadcast
597
+ def does_checkpoint_exist (
598
+ ckpt_path : str ,
599
+ metadata_fname : Union [str , List [str ]],
600
+ process_group : Optional [dist .ProcessGroup ] = None ,
601
+ ) -> bool :
602
+ """
603
+ Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory.
604
+ Will return False if metadata_fname is None or empty. Check is executed in rank 0, but
605
+ result is broadcast to all ranks.
606
+
607
+ Args:
608
+ ckpt_path: Path of the checkpoint to check.
609
+ metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
610
+ process_group: Optional process group on which the ranks will communicate. By default, the entire world is used.
611
+ """
612
+ if not metadata_fname :
613
+ return False
614
+ else :
615
+ metadata_fnames = (
616
+ [metadata_fname ] if isinstance (metadata_fname , str ) else metadata_fname
617
+ )
618
+
619
+ fs , _ = url_to_fs (ckpt_path )
620
+ return any (_metadata_exists (fs , ckpt_path , fname ) for fname in metadata_fnames )
621
+
622
+
599
623
@rank_zero_read_and_broadcast
600
624
def get_latest_checkpoint_path (
601
625
dirpath : str ,
0 commit comments