You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: torchtnt/framework/callbacks/base_checkpointer.py
+9-7Lines changed: 9 additions & 7 deletions
Original file line number
Diff line number
Diff line change
@@ -9,7 +9,7 @@
9
9
importabc
10
10
importlogging
11
11
fromdatetimeimporttimedelta
12
-
fromtypingimportAny, cast, Dict, Iterable, Literal, Optional, Union
12
+
fromtypingimportAny, cast, Dict, Iterable, List, Literal, Optional, Union
13
13
14
14
importtorch.distributedasdist
15
15
frompyre_extensionsimportnone_throws
@@ -43,7 +43,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
43
43
2) ``restore`` which implements restoring the checkpoint given the relevant checkpoint path.
44
44
45
45
The subclass may override the ``metadata_fname`` attribute to specify the filename of the metadata file that will be written within the checkpoint directory.
46
-
This will be used by this base class to ensure the integrity of the checkpoint.
46
+
This will be used by this base class to ensure the integrity of the checkpoint. This is a list because some checkpointers may allow more than one valid
47
+
``metadata_fnames``, depending on storage or optimization configurations.
47
48
48
49
Args:
49
50
dirpath: Parent directory to save checkpoints to.
@@ -67,7 +68,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
67
68
checkpoint will be saved, without the metric value in the checkpoint name
68
69
"""
69
70
70
-
metadata_fname: Optional[str] =None
71
+
# No metadata file is checked by default. This can be overridden by subclasses.
Copy file name to clipboardExpand all lines: torchtnt/framework/callbacks/dcp_saver.py
+2-2Lines changed: 2 additions & 2 deletions
Original file line number
Diff line number
Diff line change
@@ -9,7 +9,7 @@
9
9
importlogging
10
10
importtime
11
11
fromconcurrent.futuresimportFuture
12
-
fromtypingimportAny, Dict, Iterable, Optional, Union
12
+
fromtypingimportAny, Dict, Iterable, List, Optional, Union
13
13
14
14
importtorch
15
15
importtorch.distributedasdist
@@ -102,7 +102,7 @@ class DistributedCheckpointSaver(BaseCheckpointer):
102
102
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
Copy file name to clipboardExpand all lines: torchtnt/framework/callbacks/torchsnapshot_saver.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -96,7 +96,7 @@ class TorchSnapshotSaver(BaseCheckpointer):
96
96
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
0 commit comments