Skip to content

Commit b467a0e

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Use TNT's ManifoldPathHandler for listing checkpoints internally
Reviewed By: JKSenthil Differential Revision: D65370757 fbshipit-source-id: 4a6d523d9870a63a3b42f9748917af2d8300675a
1 parent 8150bcc commit b467a0e

File tree

3 files changed

+57
-23
lines changed

3 files changed

+57
-23
lines changed

tests/utils/test_checkpoint.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,18 +1422,10 @@ def test_does_checkpoint_metadata_exist(self) -> None:
14221422
dirpath = os.path.join(temp_dir, "checkpoint")
14231423
Snapshot.take(dirpath, app_state=app_state)
14241424

1425-
self.assertTrue(
1426-
CheckpointManager.does_checkpoint_metadata_exist(
1427-
dirpath, SNAPSHOT_METADATA_FNAME
1428-
)
1429-
)
1425+
self.assertTrue(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))
14301426

14311427
os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
1432-
self.assertFalse(
1433-
CheckpointManager.does_checkpoint_metadata_exist(
1434-
dirpath, SNAPSHOT_METADATA_FNAME
1435-
)
1436-
)
1428+
self.assertFalse(does_checkpoint_exist(dirpath, SNAPSHOT_METADATA_FNAME))
14371429

14381430
def test_does_checkpoint_exist(self) -> None:
14391431
with tempfile.TemporaryDirectory() as temp_dir:

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from datetime import timedelta
1313
from typing import Any, cast, Iterable, List, Literal, Optional, Union
1414

15+
import fsspec
16+
1517
import torch.distributed as dist
1618
from pyre_extensions import none_throws
1719
from torchtnt.framework.callback import Callback
@@ -449,6 +451,7 @@ def restore_from_latest(
449451
train_dataloader: Optional[Iterable[TTrainData]] = None,
450452
process_group: Optional[dist.ProcessGroup] = None,
451453
restore_options: Optional[RestoreOptions] = None,
454+
file_system: Optional[fsspec.AbstractFileSystem] = None,
452455
**kwargs: Any,
453456
) -> bool:
454457
"""
@@ -463,12 +466,17 @@ def restore_from_latest(
463466
train_dataloader: An optional train dataloader to restore.
464467
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
465468
restore_options: Controls what to filter when restoring the state.
469+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
470+
used to match the file system of the dirpath.
466471
467472
Returns:
468473
True if the latest checkpoint directory was found and successfully restored, otherwise False.
469474
"""
470475
path = get_latest_checkpoint_path(
471-
dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group
476+
dirpath,
477+
metadata_fname=cls.metadata_fnames,
478+
process_group=process_group,
479+
file_system=file_system,
472480
)
473481
if path is None:
474482
logger.info(
@@ -497,6 +505,7 @@ def restore_from_best(
497505
train_dataloader: Optional[Iterable[TTrainData]] = None,
498506
process_group: Optional[dist.ProcessGroup] = None,
499507
restore_options: Optional[RestoreOptions] = None,
508+
file_system: Optional[fsspec.AbstractFileSystem] = None,
500509
**kwargs: Any,
501510
) -> bool:
502511
"""
@@ -512,6 +521,8 @@ def restore_from_best(
512521
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
513522
train_dataloader: An optional train dataloader to restore.
514523
process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
524+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
525+
used to match the file system of the dirpath.
515526
restore_options: Controls what to filter when restoring the state.
516527
517528
Returns:
@@ -522,6 +533,7 @@ def restore_from_best(
522533
metric_name=metric_name,
523534
mode=mode,
524535
metadata_fname=cls.metadata_fnames,
536+
file_system=file_system,
525537
process_group=process_group,
526538
)
527539

torchtnt/utils/checkpoint.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def __init__(
366366
keep_last_n_checkpoints: Optional[int] = None,
367367
metadata_fnames: Optional[List[str]] = None,
368368
process_group: Optional[dist.ProcessGroup] = None,
369+
file_system: Optional[fsspec.AbstractFileSystem] = None,
369370
) -> None:
370371
"""
371372
Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the
@@ -389,6 +390,11 @@ def __init__(
389390
self._keep_last_n_checkpoints = keep_last_n_checkpoints
390391
self._pg_wrapper = PGWrapper(process_group)
391392

393+
if file_system is None:
394+
file_system, _ = url_to_fs(self.dirpath)
395+
396+
self._file_system: fsspec.AbstractFileSystem = file_system
397+
392398
if metadata_fnames is None:
393399
self._metadata_fnames: List[str] = []
394400
else:
@@ -568,17 +574,16 @@ def does_checkpoint_exist(
568574
ckpt.path, self._metadata_fnames, process_group=process_group
569575
)
570576

571-
@staticmethod
572577
def does_checkpoint_metadata_exist(
578+
self,
573579
checkpoint_path: str,
574580
metadata_fname: str,
575581
) -> bool:
576582
"""
577583
Checking whether a checkpoint metadata file exists in the directory.
578584
If the checkpointer has that metadata file, this function will returns True. Returns False otherwise.
579585
"""
580-
fs, _ = url_to_fs(checkpoint_path)
581-
return _metadata_exists(fs, checkpoint_path, metadata_fname)
586+
return _metadata_exists(self._file_system, checkpoint_path, metadata_fname)
582587

583588
@staticmethod
584589
@rank_zero_read_and_broadcast
@@ -596,9 +601,8 @@ def remove_checkpoint(self) -> None:
596601
"""
597602
worst_ckpt_path = self._ckpt_paths.pop(0)
598603
if self._pg_wrapper.get_rank() == 0:
599-
fs, _ = url_to_fs(self.dirpath)
600604
try:
601-
fs.rm(worst_ckpt_path.path, recursive=True)
605+
self._file_system.rm(worst_ckpt_path.path, recursive=True)
602606
except Exception as exc:
603607
logger.error(
604608
(
@@ -612,6 +616,7 @@ def remove_checkpoint(self) -> None:
612616
def does_checkpoint_exist(
613617
ckpt_path: str,
614618
metadata_fname: Union[str, List[str]],
619+
file_system: Optional[fsspec.AbstractFileSystem] = None,
615620
process_group: Optional[dist.ProcessGroup] = None,
616621
) -> bool:
617622
"""
@@ -622,6 +627,8 @@ def does_checkpoint_exist(
622627
Args:
623628
ckpt: The checkpoint to check.
624629
metadata_fname: File to check for existence. If a list is provided, it will check that at least one of the files is present.
630+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
631+
used to match the file system of the dirpath.
625632
process_group: Optional process group on which the ranks will communicate on. By default, the entire world is used.
626633
"""
627634
if not metadata_fname:
@@ -631,14 +638,18 @@ def does_checkpoint_exist(
631638
[metadata_fname] if isinstance(metadata_fname, str) else metadata_fname
632639
)
633640

634-
fs, _ = url_to_fs(ckpt_path)
641+
fs = file_system
642+
if fs is None:
643+
fs, _ = url_to_fs(ckpt_path)
644+
635645
return any(_metadata_exists(fs, ckpt_path, fname) for fname in metadata_fnames)
636646

637647

638648
@rank_zero_read_and_broadcast
639649
def get_latest_checkpoint_path(
640650
dirpath: str,
641651
metadata_fname: Optional[Union[str, List[str]]] = None,
652+
file_system: Optional[fsspec.AbstractFileSystem] = None,
642653
process_group: Optional[dist.ProcessGroup] = None,
643654
) -> Optional[str]:
644655
"""
@@ -648,6 +659,8 @@ def get_latest_checkpoint_path(
648659
dirpath: parent directory where checkpoints are saved.
649660
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
650661
If a list is provided, it will check that at least one of the files is present.
662+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
663+
used to match the file system of the dirpath.
651664
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
652665
653666
Raises:
@@ -658,14 +671,17 @@ def get_latest_checkpoint_path(
658671
gloo process groups are recommended over nccl.
659672
"""
660673

661-
return _get_latest_checkpoint_path(dirpath, metadata_fname)
674+
return _get_latest_checkpoint_path(dirpath, metadata_fname, file_system)
662675

663676

664677
def _get_latest_checkpoint_path(
665678
dirpath: str,
666679
metadata_fname: Optional[Union[str, List[str]]] = None,
680+
file_system: Optional[fsspec.AbstractFileSystem] = None,
667681
) -> Optional[str]:
668-
candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname)
682+
candidate_dirpaths = _retrieve_checkpoint_dirpaths(
683+
dirpath, metadata_fname, file_system=file_system
684+
)
669685
if not candidate_dirpaths:
670686
return None
671687

@@ -683,6 +699,7 @@ def get_best_checkpoint_path(
683699
metric_name: str,
684700
mode: Literal["min", "max"],
685701
metadata_fname: Optional[Union[str, List[str]]] = None,
702+
file_system: Optional[fsspec.AbstractFileSystem] = None,
686703
process_group: Optional[dist.ProcessGroup] = None,
687704
) -> Optional[str]:
688705
"""
@@ -697,14 +714,18 @@ def get_best_checkpoint_path(
697714
mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest.
698715
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
699716
If a list is provided, it will check that at least one of the files is present.
717+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
718+
used to match the file system of the dirpath.
700719
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
701720
702721
Note:
703722
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
704723
gloo process groups are recommended over nccl.
705724
"""
706725

707-
dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
726+
dirpaths = _retrieve_checkpoint_dirpaths(
727+
dirpath, metadata_fname, metric_name, file_system=file_system
728+
)
708729
if not dirpaths:
709730
return None
710731

@@ -721,6 +742,7 @@ def get_checkpoint_dirpaths(
721742
dirpath: str,
722743
metadata_fname: Optional[Union[str, List[str]]] = None,
723744
metric_name: Optional[str] = None,
745+
file_system: Optional[fsspec.AbstractFileSystem] = None,
724746
process_group: Optional[dist.ProcessGroup] = None,
725747
) -> List[CheckpointPath]:
726748
"""
@@ -736,20 +758,25 @@ def get_checkpoint_dirpaths(
736758
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
737759
If a list is provided, it will check that at least one of the files is present.
738760
metric_name: fetches all the checkpoint directories containing the metric name only.
761+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
762+
used to match the file system of the dirpath.
739763
process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world)
740764
741765
Note:
742766
When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks.
743767
gloo process groups are recommended over nccl.
744768
"""
745769

746-
return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name)
770+
return _retrieve_checkpoint_dirpaths(
771+
dirpath, metadata_fname, metric_name, file_system=file_system
772+
)
747773

748774

749775
def _retrieve_checkpoint_dirpaths(
750776
dirpath: str,
751777
metadata_fname: Optional[Union[str, List[str]]],
752778
metric_name: Optional[str] = None,
779+
file_system: Optional[fsspec.AbstractFileSystem] = None,
753780
) -> List[CheckpointPath]:
754781
"""
755782
Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories
@@ -759,9 +786,12 @@ def _retrieve_checkpoint_dirpaths(
759786
metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist.
760787
If a list is provided, it will check that at least one of the files is present.
761788
metric_name: Name of the metric that must exist in checkpoint name.
789+
file_system: If a custom file system should be used to fetch the checkpoint directories. Otherwise, fsspec will be
790+
used to match the file system of the dirpath.
762791
"""
763-
764-
fs, _ = url_to_fs(dirpath)
792+
fs = file_system
793+
if fs is None:
794+
fs, _ = url_to_fs(dirpath)
765795

766796
if not fs.exists(dirpath):
767797
logger.warning(f"Input dirpath doesn't exist: {dirpath}")

0 commit comments

Comments
 (0)