@@ -366,6 +366,7 @@ def __init__(
366
366
keep_last_n_checkpoints : Optional [int ] = None ,
367
367
metadata_fnames : Optional [List [str ]] = None ,
368
368
process_group : Optional [dist .ProcessGroup ] = None ,
369
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
369
370
) -> None :
370
371
"""
371
372
Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the
@@ -389,6 +390,11 @@ def __init__(
389
390
self ._keep_last_n_checkpoints = keep_last_n_checkpoints
390
391
self ._pg_wrapper = PGWrapper (process_group )
391
392
393
+ if file_system is None :
394
+ file_system , _ = url_to_fs (self .dirpath )
395
+
396
+ self ._file_system : fsspec .AbstractFileSystem = file_system
397
+
392
398
if metadata_fnames is None :
393
399
self ._metadata_fnames : List [str ] = []
394
400
else :
@@ -568,17 +574,16 @@ def does_checkpoint_exist(
568
574
ckpt .path , self ._metadata_fnames , process_group = process_group
569
575
)
570
576
571
- @staticmethod
572
577
def does_checkpoint_metadata_exist(
    self,
    checkpoint_path: str,
    metadata_fname: str,
) -> bool:
    """
    Check whether a checkpoint metadata file exists in the checkpoint directory.

    The lookup is delegated to ``_metadata_exists`` and is performed with the
    file system held by this instance (``self._file_system``).

    Args:
        checkpoint_path: path of the checkpoint to inspect.
        metadata_fname: name of the metadata file to look for.

    Returns:
        True if the checkpoint has that metadata file, False otherwise.
    """
    metadata_found = _metadata_exists(
        self._file_system,
        checkpoint_path,
        metadata_fname,
    )
    return metadata_found
582
587
583
588
@staticmethod
584
589
@rank_zero_read_and_broadcast
@@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None:
596
601
"""
597
602
worst_ckpt_path = self ._ckpt_paths .pop (0 )
598
603
if self ._pg_wrapper .get_rank () == 0 :
599
- fs , _ = url_to_fs (self .dirpath )
600
604
try :
601
- fs .rm (worst_ckpt_path .path , recursive = True )
605
+ self . _file_system .rm (worst_ckpt_path .path , recursive = True )
602
606
except Exception as exc :
603
607
logger .error (
604
608
(
@@ -612,6 +616,7 @@ def remove_checkpoint(self) -> None:
612
616
def does_checkpoint_exist (
613
617
ckpt_path : str ,
614
618
metadata_fname : Union [str , List [str ]],
619
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
615
620
process_group : Optional [dist .ProcessGroup ] = None ,
616
621
) -> bool :
617
622
"""
@@ -622,6 +627,8 @@ def does_checkpoint_exist(
622
627
Args:
623
628
ckpt_path: The path of the checkpoint to check.
624
629
metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
630
+ file_system: Optional custom file system used to look up the metadata file(s). If None, fsspec
631
+ infers the file system from `ckpt_path`.
625
632
process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
626
633
"""
627
634
if not metadata_fname :
@@ -631,14 +638,18 @@ def does_checkpoint_exist(
631
638
[metadata_fname ] if isinstance (metadata_fname , str ) else metadata_fname
632
639
)
633
640
634
- fs , _ = url_to_fs (ckpt_path )
641
+ fs = file_system
642
+ if fs is None :
643
+ fs , _ = url_to_fs (ckpt_path )
644
+
635
645
return any (_metadata_exists (fs , ckpt_path , fname ) for fname in metadata_fnames )
636
646
637
647
638
648
@rank_zero_read_and_broadcast
639
649
def get_latest_checkpoint_path (
640
650
dirpath : str ,
641
651
metadata_fname : Optional [Union [str , List [str ]]] = None ,
652
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
642
653
process_group : Optional [dist .ProcessGroup ] = None ,
643
654
) -> Optional [str ]:
644
655
"""
@@ -648,6 +659,8 @@ def get_latest_checkpoint_path(
648
659
dirpath: parent directory where checkpoints are saved.
649
660
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
650
661
If a list is provided, it will check that at least one of the files is present.
662
+ file_system: Optional custom file system used to fetch the checkpoint directories. If None, fsspec
663
+ infers the file system from `dirpath`.
651
664
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
652
665
653
666
Raises:
@@ -658,14 +671,17 @@ def get_latest_checkpoint_path(
658
671
gloo process groups are recommended over nccl.
659
672
"""
660
673
661
- return _get_latest_checkpoint_path (dirpath , metadata_fname )
674
+ return _get_latest_checkpoint_path (dirpath , metadata_fname , file_system )
662
675
663
676
664
677
def _get_latest_checkpoint_path (
665
678
dirpath : str ,
666
679
metadata_fname : Optional [Union [str , List [str ]]] = None ,
680
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
667
681
) -> Optional [str ]:
668
- candidate_dirpaths = _retrieve_checkpoint_dirpaths (dirpath , metadata_fname )
682
+ candidate_dirpaths = _retrieve_checkpoint_dirpaths (
683
+ dirpath , metadata_fname , file_system = file_system
684
+ )
669
685
if not candidate_dirpaths :
670
686
return None
671
687
@@ -683,6 +699,7 @@ def get_best_checkpoint_path(
683
699
metric_name : str ,
684
700
mode : Literal ["min" , "max" ],
685
701
metadata_fname : Optional [Union [str , List [str ]]] = None ,
702
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
686
703
process_group : Optional [dist .ProcessGroup ] = None ,
687
704
) -> Optional [str ]:
688
705
"""
@@ -697,14 +714,18 @@ def get_best_checkpoint_path(
697
714
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
698
715
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
699
716
If a list is provided, it will check that at least one of the files is present.
717
+ file_system: Optional custom file system used to fetch the checkpoint directories. If None, fsspec
718
+ infers the file system from `dirpath`.
700
719
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
701
720
702
721
Note:
703
722
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
704
723
gloo process groups are recommended over nccl.
705
724
"""
706
725
707
- dirpaths = _retrieve_checkpoint_dirpaths (dirpath , metadata_fname , metric_name )
726
+ dirpaths = _retrieve_checkpoint_dirpaths (
727
+ dirpath , metadata_fname , metric_name , file_system = file_system
728
+ )
708
729
if not dirpaths :
709
730
return None
710
731
@@ -721,6 +742,7 @@ def get_checkpoint_dirpaths(
721
742
dirpath : str ,
722
743
metadata_fname : Optional [Union [str , List [str ]]] = None ,
723
744
metric_name : Optional [str ] = None ,
745
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
724
746
process_group : Optional [dist .ProcessGroup ] = None ,
725
747
) -> List [CheckpointPath ]:
726
748
"""
@@ -736,20 +758,25 @@ def get_checkpoint_dirpaths(
736
758
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
737
759
If a list is provided, it will check that at least one of the files is present.
738
760
metric_name: fetches all the checkpoint directories containing the metric name only.
761
+ file_system: Optional custom file system used to fetch the checkpoint directories. If None, fsspec
762
+ infers the file system from `dirpath`.
739
763
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
740
764
741
765
Note:
742
766
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
743
767
gloo process groups are recommended over nccl.
744
768
"""
745
769
746
- return _retrieve_checkpoint_dirpaths (dirpath , metadata_fname , metric_name )
770
+ return _retrieve_checkpoint_dirpaths (
771
+ dirpath , metadata_fname , metric_name , file_system = file_system
772
+ )
747
773
748
774
749
775
def _retrieve_checkpoint_dirpaths (
750
776
dirpath : str ,
751
777
metadata_fname : Optional [Union [str , List [str ]]],
752
778
metric_name : Optional [str ] = None ,
779
+ file_system : Optional [fsspec .AbstractFileSystem ] = None ,
753
780
) -> List [CheckpointPath ]:
754
781
"""
755
782
Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
@@ -759,9 +786,12 @@ def _retrieve_checkpoint_dirpaths(
759
786
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
760
787
If a list is provided, it will check that at least one of the files is present.
761
788
metric_name: Name of the metric that must exist in checkpoint name.
789
+ file_system: Optional custom file system used to fetch the checkpoint directories. If None, fsspec
790
+ infers the file system from `dirpath`.
762
791
"""
763
-
764
- fs , _ = url_to_fs (dirpath )
792
+ fs = file_system
793
+ if fs is None :
794
+ fs , _ = url_to_fs (dirpath )
765
795
766
796
if not fs .exists (dirpath ):
767
797
logger .warning (f"Input dirpath doesn't exist: { dirpath } " )
0 commit comments