diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index bc782c477d..11797d77a0 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -234,7 +234,8 @@ def test_latest_checkpoint_path(self) -> None: path_4 = os.path.join(temp_dir, "epoch_700") os.mkdir(path_4) self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + path_2, ) @skip_if_not_distributed @@ -284,7 +285,8 @@ def _latest_checkpoint_path_distributed() -> None: expected_path = path_container[0] tc.assertIsNotNone(expected_path) tc.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + expected_path, ) if is_rank0: @@ -368,7 +370,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None: # compares set equality since order of returned dirpaths is not guaranteed # in _retrieve_checkpoint_dirpaths self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + { + str(x) + for x in _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=None + ) + }, {os.path.join(temp_dir, path) for path in paths[:-1]}, ) self.assertEqual( @@ -382,9 +389,12 @@ def test_retrieve_checkpoint_dirpaths(self) -> None: pass self.assertEqual( - set( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata") - ), + { + str(x) + for x in _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=".metadata" + ) + }, {os.path.join(temp_dir, paths[2])}, ) @@ -394,30 +404,36 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: """ with tempfile.TemporaryDirectory() as temp_dir: paths = [ - "epoch_0_step_10_val_loss=10", - "epoch_1_step_10_val_loss=5", + "epoch_0_step_10_val_loss=10.0", + "epoch_1_step_10_val_loss=5.0", "epoch_2_step_10", "epoch_0_step_5", - "epoch_0_step_6_train_loss=13", + "epoch_0_step_6_train_loss=13.0", ] for path in paths: os.mkdir(os.path.join(temp_dir, path)) # make last path a file instead of a directory - with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"): + with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3.0"), "w"): pass # compares set equality since order of returned dirpaths is not guaranteed # in _retrieve_checkpoint_dirpaths self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + { + str(x) + for x in _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=None + ) + }, {os.path.join(temp_dir, path) for path in paths}, ) self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( + { + str(x) + for x in _retrieve_checkpoint_dirpaths( temp_dir, metadata_fname=None, metric_name="val_loss" ) - ), + }, { os.path.join(temp_dir, path) for path in paths[:2] }, # since last path is a file @@ -433,11 +449,12 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: pass self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( + { + str(x) + for x in _retrieve_checkpoint_dirpaths( temp_dir, metadata_fname=".metadata", metric_name="val_loss" ) - ), + }, {os.path.join(temp_dir, paths[1])}, ) @@ -467,7 +484,7 @@ def create_tmp_dir() -> str: os.mkdir(path2) torch.distributed.barrier() - ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir) + ckpt_dirpaths = [str(x) for x in get_checkpoint_dirpaths(temp_dir)] tc = unittest.TestCase() tc.assertEqual(set(ckpt_dirpaths), {path1, path2}) @@ -492,7 +509,7 @@ def test_get_checkpoint_dirpaths(self) -> None: os.mkdir(path3) self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir)), + {str(x) for x in get_checkpoint_dirpaths(temp_dir)}, {path1, path2, path3}, ) @@ -505,7 +522,10 @@ def test_get_checkpoint_dirpaths(self) -> None: os.mkdir(path3) self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")), + { + str(x) + for x in get_checkpoint_dirpaths(temp_dir, metric_name="val_loss") + }, {path1, path2, path3}, ) @@ -519,20 +539,27 @@ def test_checkpoint_sorting_utils(self) -> None: """ Tests the sort utilities """ - paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"] - self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]]) + paths = [ + "foo/epoch_1_step_20", + "foo/epoch_4_step_130", + "foo/epoch_0_step_10_val_loss=10.0", + ] + ckpts = [CheckpointPath.from_str(x) for x in paths] + sorted_paths = [str(x) for x in _sort_by_recency(ckpts)] + self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]]) paths = [ - "epoch_1_step_20_val_loss=0.09", - "epoch_4_step_130_val_loss=29", - "epoch_0_step_10_val_loss=10", + "foo/epoch_1_step_20_val_loss=0.09", + "foo/epoch_4_step_130_val_loss=29.0", + "foo/epoch_0_step_10_val_loss=10.0", ] - self.assertEqual( - _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]] - ) - self.assertEqual( - _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]] - ) + ckpts = [CheckpointPath.from_str(x) for x in paths] + + sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="min")] + self.assertEqual(sorted_paths, [paths[1], paths[2], paths[0]]) + + sorted_paths = [str(x) for x in _sort_by_metric_value(ckpts, mode="max")] + self.assertEqual(sorted_paths, [paths[0], paths[2], paths[1]]) def test_delete_checkpoint(self) -> None: """ diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 3cf4cd1746..ee6b7f7991 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -124,11 +124,14 @@ def __init__( # sort by metric value if doing best checkpoint, else by recency if best_checkpoint_config: - self._ckpt_dirpaths = _sort_by_metric_value( + ckpt_dirpaths = _sort_by_metric_value( ckpt_dirpaths, mode=best_checkpoint_config.mode ) else: - self._ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths) + ckpt_dirpaths = _sort_by_recency(ckpt_dirpaths) + + # TODO Remove this when using CheckpointManager + self._ckpt_dirpaths = [str(x) for x in ckpt_dirpaths] self._process_group: Optional[dist.ProcessGroup] = None self._setup_gloo_pg(process_group) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 5438afc830..5ef013cde0 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -10,7 +10,7 @@ import re from dataclasses import dataclass from functools import total_ordering -from typing import List, Literal, Optional, Pattern, Tuple +from typing import List, Literal, Optional, Pattern import fsspec import torch.distributed as dist @@ -234,49 +234,21 @@ def get_latest_checkpoint_path( Raises: AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. - """ - - return _latest_checkpoint_path(dirpath, metadata_fname) + Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. + """ -def _latest_checkpoint_path( - dirpath: str, metadata_fname: Optional[str] -) -> Optional[str]: candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) - - # Initialize variables to store the largest epoch and step numbers - largest_subdirectory = None - largest_epoch = -1 - largest_step = -1 + if not candidate_dirpaths: + return None # Iterate through all files and directories in the specified directory - for candidate in candidate_dirpaths: - # Extract the epoch and step numbers from the directory name - dirname = os.path.basename(candidate) - - # dirname will be of the format epoch_N_step_M - # where N is the epoch number and M is the step number as integers - split = dirname.split("_") - if len(split) < 4: - raise AssertionError( - f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})" - ) - - epoch_num, step_num = int(split[1]), int(split[3]) - # Check if the current epoch and step numbers are larger than the largest ones found so far - if epoch_num > largest_epoch: - largest_epoch = epoch_num - largest_step = step_num - largest_subdirectory = dirname - elif largest_epoch == epoch_num and step_num > largest_step: - largest_step = step_num - largest_subdirectory = dirname - - if largest_subdirectory is None: - return None + latest_checkpoint = candidate_dirpaths[0] + for candidate in candidate_dirpaths[1:]: + if candidate.newer_than(latest_checkpoint): + latest_checkpoint = candidate - # Rejoin with the parent directory path and return the largest subdirectory - return os.path.join(dirpath, none_throws(largest_subdirectory)) + return latest_checkpoint.path @rank_zero_read_and_broadcast @@ -296,29 +268,21 @@ def get_best_checkpoint_path( mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + + Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. """ dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) - if len(dirpaths) == 0: + if not dirpaths: # no checkpoints found return None - best_checkpoint_path = None - best_metric_value = float("inf") if mode == "min" else float("-inf") - for dirpath in dirpaths: - dirname = os.path.basename(dirpath) - metric_value = float(dirname.split("=")[-1]) - - if mode == "min": - if metric_value < best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath - else: - if metric_value > best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath + best_checkpoint = dirpaths[0] + for checkpoint in dirpaths[1:]: + if checkpoint.more_optimal_than(best_checkpoint, mode): + best_checkpoint = checkpoint - return best_checkpoint_path + return best_checkpoint.path @rank_zero_read_and_broadcast @@ -327,7 +291,7 @@ def get_checkpoint_dirpaths( metadata_fname: Optional[str] = None, metric_name: Optional[str] = None, process_group: Optional[dist.ProcessGroup] = None, -) -> List[str]: +) -> List[CheckpointPath]: """ Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. The order of the checkpoints is not guarenteed. @@ -337,12 +301,14 @@ def get_checkpoint_dirpaths( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. metric_name: fetches all the checkpoint directories containing the metric name only. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + + Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. """ return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) -def _sort_by_recency(dirpaths: List[str]) -> List[str]: +def _sort_by_recency(dirpaths: List[CheckpointPath]) -> List[CheckpointPath]: """ Sorts the given list of directories by oldest to newest. @@ -353,16 +319,12 @@ def _sort_by_recency(dirpaths: List[str]) -> List[str]: A sorted list of directory paths, sorted by recency. """ - def sort_fn(path: str) -> Tuple[int, int]: - x = os.path.basename(path) - return (int(x.split("_")[1]), int(x.split("_")[3])) - - return sorted(dirpaths, key=sort_fn) + return sorted(dirpaths) # CheckpointPath is well ordered by recency def _sort_by_metric_value( - dirpaths: List[str], mode: Literal["min", "max"] -) -> List[str]: + dirpaths: List[CheckpointPath], mode: Literal["min", "max"] +) -> List[CheckpointPath]: """ Sorts the given list of directories by the metric values. @@ -373,15 +335,9 @@ def _sort_by_metric_value( Returns: A sorted list of directory paths, sorted by the metric values. """ - - def sort_metric_fn(path: str) -> float: - x = os.path.basename(path) - metric_val = float(x.split("=")[-1]) - return metric_val - return sorted( dirpaths, - key=sort_metric_fn, + key=lambda x: x.metric_data.value, # sort descending if min, placing worst metric at top of list reverse=(mode == "min"), ) @@ -391,21 +347,15 @@ def _retrieve_checkpoint_dirpaths( dirpath: str, metadata_fname: Optional[str], metric_name: Optional[str] = None, -) -> List[str]: +) -> List[CheckpointPath]: """ Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories Args: dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - metric_name: Name of the metric that must exist in checkpoint name. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. metric_name: Name of the metric that must exist in checkpoint name. """ - if dirpath[-1] == "/": - # removes trailing forward slash if present - # required for regex search to work - dirpath = dirpath[:-1] - fs, _ = url_to_fs(dirpath) if not fs.exists(dirpath): @@ -418,23 +368,31 @@ def _retrieve_checkpoint_dirpaths( logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}") return [] - # Define the regex pattern to match the directory names - pattern = rf"^{dirpath}/epoch_\d+_step_\d+" - if metric_name: - # inject metric name in regex search - pattern += rf"_{metric_name}=" - snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern) - candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents)) + # Parse the valid checkpoint directories + candidate_checkpoints: List[CheckpointPath] = [] + for candidate_dirpath in contents: + try: + ckpt = CheckpointPath.from_str(candidate_dirpath) + except ValueError: + continue + + # If a metric was provided, keep only the checkpoints tracking it + if metric_name and not ( + ckpt.metric_data and ckpt.metric_data.name == metric_name + ): + continue + + candidate_checkpoints.append(ckpt) if not metadata_fname: # return early as we don't need to filter out any paths - return candidate_dirpaths + return candidate_checkpoints # Iterate through all files and directories in the specified directory # and check if metedata is present or not - valid_ckpt_dirpaths = [] - for candidate in candidate_dirpaths: - if not _metadata_exists(fs, candidate, metadata_fname): + valid_ckpt_dirpaths: List[CheckpointPath] = [] + for candidate in candidate_checkpoints: + if not _metadata_exists(fs, candidate.path, metadata_fname): logger.warning( f"Snapshot metadata is missing from {candidate}! Skipping this path" )