
Commit c25aebf

Merge branch 'main' of https://github.com/rstudio/keras into main
2 parents 43c92ba + 9934b16 commit c25aebf

7 files changed: +138 -10 lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ export(array_reshape)
 export(as_tensor)
 export(backend)
 export(bidirectional)
+export(callback_backup_and_restore)
 export(callback_csv_logger)
 export(callback_early_stopping)
 export(callback_lambda)

NEWS.md

Lines changed: 11 additions & 9 deletions
@@ -1,5 +1,7 @@
 # keras (development version)

+- New `callback_backup_and_restore()`, for resuming an interrupted `fit()` call.
+
 # keras 2.9.0

 - New functions for constructing custom keras subclasses:
@@ -14,7 +16,7 @@
 should be an active binding (i.e., decorated with Python's `@property`).
 `mark_active()` can be used in the `new_*_class` family of class constructors
 as well as `%py_class%`.
-
+
 - `r_to_py()` method for R6 classes and `%py_class%` gain support for
 `private` fields and methods. Any R objects stored in `private` will only be
 available to methods, and will not be converted to Python.
@@ -32,7 +34,7 @@

 - New L2 unit normalization layer: `layer_unit_normalization()`.

-- New `regularizer_orthogonal`, a regularizer that encourages
+- New `regularizer_orthogonal`, a regularizer that encourages
 orthogonality between the rows (or columns) of a weight matrix.

 - New `zip_lists()` function for transposing lists, optionally matching by name.
@@ -90,16 +92,16 @@

 - KerasTensor objects (e.g., returned by `layer_input()`) now inherit S3 methods
 for `"tensorflow.tensor"`.
-
-- `plot.keras_training_history()` no longer issues message
+
+- `plot.keras_training_history()` no longer issues message
 ``` `geom_smooth()` using formula 'y ~ x' ``` when `method = "ggplot2"`.
-
-- `print` and related methods for models (`format`, `summary`) now accept
+
+- `print` and related methods for models (`format`, `summary`) now accept
 a `width` argument.

-- `evaluate()`, `fit()`, and `predict()` methods for keras Models now default
-to `verbose = "auto"`, with verbosity adjusted appropriately based on calls to
-`keras$utils$disable_interactive_logging()`, and contexts like
+- `evaluate()`, `fit()`, and `predict()` methods for keras Models now default
+to `verbose = "auto"`, with verbosity adjusted appropriately based on calls to
+`keras$utils$disable_interactive_logging()`, and contexts like
 `ParameterServerStrategy`.

 - `install_keras()` now accepts `version = "release-cpu"` as a valid specification.
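
The new NEWS entry describes `callback_backup_and_restore()` as a way to resume an interrupted `fit()` call. A minimal usage sketch follows, assuming the keras R package is installed with TensorFlow 2.8 or later; the toy data, the model, and the `backup_dir` location are made up for illustration and are not part of this commit.

```r
library(keras)

# Toy data and model, made up for illustration.
x <- matrix(rnorm(400), ncol = 4)
y <- rnorm(100)

model <- keras_model_sequential() %>%
  layer_dense(units = 1, input_shape = 4)
model %>% compile(optimizer = "sgd", loss = "mse")

# Directory reserved for this callback's temporary checkpoint
# (hypothetical location; it should not be shared with other callbacks).
backup_dir <- file.path(tempdir(), "backup")

# If an earlier run of this same call was interrupted, training is expected
# to resume from the last completed epoch rather than from epoch 1.
history <- model %>% fit(
  x, y,
  epochs = 5,
  callbacks = list(callback_backup_and_restore(backup_dir = backup_dir)),
  verbose = 0
)
```

Re-running the same script after a crash is what triggers the restore; the callback itself does not restart the job.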

R/callbacks.R

Lines changed: 59 additions & 0 deletions
@@ -119,6 +119,63 @@ callback_model_checkpoint <- function(filepath, monitor = "val_loss", verbose =
 }


+#' Callback to back up and restore the training state
+#'
+#' @details
+#' The `BackupAndRestore` callback is intended to recover training from an
+#' interruption that has happened in the middle of a `fit(model)` execution, by
+#' backing up the training states in a temporary checkpoint file (with the help
+#' of a `tf.train.CheckpointManager`), at the end of each epoch. Each backup
+#' overwrites the previously written checkpoint file, so at any given time there
+#' is at most one such checkpoint file for backup/restore purposes.
+#'
+#' If training restarts before completion, the training state (which includes the
+#' `Model` weights and epoch number) is restored to the most recently saved state
+#' at the beginning of a new `fit()` run. At the completion of a `fit()`
+#' run, the temporary checkpoint file is deleted.
+#'
+#' Note that the user is responsible for bringing jobs back after the interruption.
+#' This callback is important for the backup and restore mechanism for fault
+#' tolerance purposes, and the model to be restored from a previous checkpoint is
+#' expected to be the same as the one used to back up. If the user changes arguments
+#' passed to `compile()` or `fit()`, the checkpoint saved for fault tolerance can become
+#' invalid.
+#'
+#' Note:
+#'
+#' 1. This callback does not work when eager execution is disabled.
+#'
+#' 2. A checkpoint is saved at the end of each epoch. After restoring,
+#' `fit()` redoes any partial work during the unfinished epoch in which the
+#' training got restarted (so the work done before the interruption doesn't
+#' affect the final model state).
+#'
+#' 3. This works for both single worker and multi-worker modes. When `fit()`
+#' is used with `tf.distribute`, it supports `tf.distribute.MirroredStrategy`,
+#' `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.TPUStrategy`, and
+#' `tf.distribute.experimental.ParameterServerStrategy`.
+#'
+#' @param backup_dir String, path to store the checkpoint.
+#' e.g. `backup_dir = normalizePath('./backup')`
+#' This is the directory in which the system stores temporary files to
+#' recover the model from jobs terminated unexpectedly. The directory
+#' cannot be reused elsewhere to store other files, e.g. by
+#' `BackupAndRestore` callback of another training, or by another callback
+#' (`ModelCheckpoint`) of the same training.
+#' @param ... For backwards and forwards compatibility
+#'
+#' @seealso
+#' + <https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/BackupAndRestore>
+#'
+#' @export
+callback_backup_and_restore <-
+function(backup_dir) {
+  args <- capture_args(match.call(), NULL)
+  require_tf_version("2.8", "callback_backup_and_restore")
+  do.call(keras$callbacks$BackupAndRestore, args)
+}
+
+
 #' Stop training when a monitored quantity has stopped improving.
 #'
 #' @inheritParams callback_model_checkpoint
@@ -750,3 +807,5 @@ normalize_callbacks <- function(callbacks) {
 }

 empty_fun <- function(batch, logs = NULL) {}
+
+
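
The roxygen text above describes a cycle: save state at the end of each epoch, overwrite the previous backup, restore on the next run, and delete the backup on completion. The following base-R sketch mimics that cycle for illustration only. It is not how Keras implements `BackupAndRestore`; `train_with_backup()`, its `fail_at` argument, and the `state.rds` file name are hypothetical.

```r
# Conceptual sketch only: mimics the backup/restore cycle in base R.
# `train_with_backup()` and `fail_at` are hypothetical, not part of keras.
train_with_backup <- function(backup_dir, epochs, fail_at = NULL) {
  dir.create(backup_dir, showWarnings = FALSE, recursive = TRUE)
  ckpt <- file.path(backup_dir, "state.rds")

  # Restore the most recent state if a previous run was interrupted.
  state <- if (file.exists(ckpt)) readRDS(ckpt) else list(epoch = 0L, weight = 0)

  while (state$epoch < epochs) {
    next_epoch <- state$epoch + 1L
    # Simulate an interruption before the epoch finishes.
    if (!is.null(fail_at) && next_epoch == fail_at) {
      stop("interrupted during epoch ", next_epoch)
    }
    # Stand-in for one epoch of training.
    state <- list(epoch = next_epoch, weight = state$weight + 1)
    # Each backup overwrites the previous one, so at most one file exists.
    saveRDS(state, ckpt)
  }

  # Training completed: delete the temporary checkpoint.
  unlink(ckpt)
  state
}

backup_dir <- tempfile("backup")
# First run is interrupted during epoch 4; epochs 1 to 3 are backed up.
first <- try(train_with_backup(backup_dir, epochs = 5, fail_at = 4), silent = TRUE)
# Second run restores epoch 3 and finishes epochs 4 and 5.
resumed <- train_with_backup(backup_dir, epochs = 5)
```

As in the documented behavior, the interrupted epoch is redone from the last completed one, and no checkpoint file remains once the run finishes.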

keras.Rproj

Lines changed: 1 addition & 0 deletions
@@ -17,5 +17,6 @@ StripTrailingWhitespace: Yes

 BuildType: Package
 PackageUseDevtools: Yes
+PackageCleanBeforeInstall: Yes
 PackageInstallArgs: --no-multiarch --with-keep.source
 PackageRoxygenize: rd,collate,namespace

man/callback_backup_and_restore.Rd

Lines changed: 60 additions & 0 deletions
Some generated files are not rendered by default.

tests/testthat/test-callbacks.R

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ if (tensorflow::tf_version() <= "2.1")


 test_callback("model_checkpoint", callback_model_checkpoint(tempfile(fileext = ".h5")), h5py = TRUE)
+
+if(tf_version() >= "2.8")
+  test_callback("backup_and_restore", callback_backup_and_restore(tempfile()))
+
 test_callback("learning_rate_scheduler", callback_learning_rate_scheduler(schedule = function (index, ...) {
   0.1
 }))
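
The guard `tf_version() >= "2.8"` in the test above compares a version against a string. A short base-R sketch of why that comparison is safe, assuming `tf_version()` returns a version object of the kind `package_version()` creates; the version `"2.10"` is made up for illustration.

```r
# Stand-in for the value returned by tf_version(); the version is made up.
installed <- package_version("2.10")

# Version-aware comparison: the string on the right is coerced to a version,
# so minor version 10 is correctly treated as newer than minor version 8.
version_ok <- installed >= "2.8"

# Plain character comparison is lexical: "1" sorts before "8".
string_ok <- "2.10" >= "2.8"

print(c(version_ok = version_ok, string_ok = string_ok))
```

Comparing two bare strings would wrongly skip the test on a 2.10 installation, which is why the left-hand side needs to be a version object.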

tools/make-wrapper.R

Lines changed: 2 additions & 1 deletion
@@ -175,4 +175,5 @@ print.r_py_wrapper2 <- function(x, ...) {
 #
 # new_wrapper("learning_rate_schedule", keras$optimizers$schedules$PolynomialDecay)

-new_wrapper("regularizer", keras$regularizers$OrthogonalRegularizer)
+# new_wrapper("regularizer", keras$regularizers$OrthogonalRegularizer)
+new_wrapper("callback", keras$callbacks$BackupAndRestore)
