@@ -119,6 +119,63 @@ callback_model_checkpoint <- function(filepath, monitor = "val_loss", verbose =
 }
 
 
+#' Callback to back up and restore the training state
+#'
+#' @details
+#' `BackupAndRestore` callback is intended to recover training from an
+#' interruption that has happened in the middle of a `fit(model)` execution, by
+#' backing up the training states in a temporary checkpoint file (with the help
+#' of a `tf.train.CheckpointManager`), at the end of each epoch. Each backup
+#' overwrites the previously written checkpoint file, so at any given time there
+#' is at most one such checkpoint file for backup/restoring purposes.
+#'
+#' If training restarts before completion, the training state (which includes the
+#' `Model` weights and epoch number) is restored to the most recently saved state
+#' at the beginning of a new `fit()` run. At the completion of a `fit()`
+#' run, the temporary checkpoint file is deleted.
+#'
+#' Note that the user is responsible for bringing jobs back after the
+#' interruption. This callback is important for the backup and restore mechanism
+#' for fault tolerance purposes, and the model restored from a previous
+#' checkpoint is expected to be the same as the one used to back it up. If the
+#' user changes arguments passed to `compile()` or `fit()`, the checkpoint saved
+#' for fault tolerance can become invalid.
+#'
+#' Note:
+#'
+#' 1. This callback is not compatible with disabling eager execution.
+#'
+#' 2. A checkpoint is saved at the end of each epoch. After restoring,
+#' `fit()` redoes any partial work during the unfinished epoch in which the
+#' training got restarted (so the work done before the interruption doesn't
+#' affect the final model state).
+#'
+#' 3. This works for both single-worker and multi-worker modes. When `fit()`
+#' is used with `tf.distribute`, it supports `tf.distribute.MirroredStrategy`,
+#' `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.TPUStrategy`, and
+#' `tf.distribute.experimental.ParameterServerStrategy`.
+#'
+#' @param backup_dir String, path to store the checkpoint,
+#' e.g. `backup_dir = normalizePath('./backup')`.
+#' This is the directory in which the system stores temporary files to
+#' recover the model from jobs terminated unexpectedly. The directory
+#' cannot be reused elsewhere to store other files, e.g. by the
+#' `BackupAndRestore` callback of another training run, or by another callback
+#' (such as `ModelCheckpoint`) of the same training run.
+#' @param ... For backwards and forwards compatibility.
+#'
+#' @seealso
+#'   + <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/BackupAndRestore>
+#'
+#' @export
+callback_backup_and_restore <-
+function(backup_dir, ...) {
+  args <- capture_args(match.call(), NULL)
+  require_tf_version("2.8", "callback_backup_and_restore")
+  do.call(keras$callbacks$BackupAndRestore, args)
+}
+
+
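As a usage illustration (not part of this commit), the sketch below assumes a compiled Keras model `model` and training data `x_train`/`y_train` already exist. It shows the intended workflow: the callback writes its temporary checkpoint into `backup_dir` at the end of every epoch, and re-running the same `fit()` call after an interruption resumes from the last completed epoch.

```r
library(keras)

# Hypothetical objects: `model` is assumed to be a compiled Keras model,
# `x_train` / `y_train` an existing training set.
backup_cb <- callback_backup_and_restore(
  backup_dir = normalizePath("./backup", mustWork = FALSE)
)

# If this run is interrupted (e.g. the job is preempted) and the same script
# is launched again with the same backup_dir, fit() restores the model weights
# and epoch counter from the temporary checkpoint and continues training
# instead of starting over. The checkpoint is deleted once fit() completes.
model %>% fit(
  x_train, y_train,
  epochs = 10,
  callbacks = list(backup_cb)
)
```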
 #' Stop training when a monitored quantity has stopped improving.
 #'
 #' @inheritParams callback_model_checkpoint
@@ -750,3 +807,5 @@ normalize_callbacks <- function(callbacks) {
 }
 
 empty_fun <- function(batch, logs = NULL) {}
+
+