
Commit 142ddce

Merge pull request #1020 from dfalbel/bugfix/metrics-callback
Fixes metric callbacks with TF 2.2
2 parents 0e90135 + e7805aa commit 142ddce

9 files changed: 188 additions & 23 deletions


.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ jobs:
     fail-fast: false
     matrix:
       os: ['windows-latest', 'macOS-latest', 'ubuntu-16.04']
-      tf: ['1.14.0', '1.15.2', '2.0.1', '2.1.0', 'nightly']
+      tf: ['1.14.0', '1.15.2', '2.0.1', '2.1.0', '2.2-rc3', 'nightly']
     include:
       - os: ubuntu-16.04
         cran: https://demo.rstudiopm.com/all/__linux__/xenial/latest

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 ## Development Version
 
 - Added `layer_attention` (#1000) by @atroiano.
+- Fixed issue regarding the KerasMetricsCallback with TF v2.2 (#1020)
 
 ## Keras 2.2.5.0 (CRAN)
 
R/applications.R

Lines changed: 3 additions & 1 deletion
@@ -314,7 +314,9 @@ imagenet_preprocess_input <- function(x, data_format = NULL, mode = "caffe") {
   )
   if (keras_version() >= "2.0.9") {
     args$data_format <- data_format
-    args$mode <- mode
+    # no longer exists in 2.2
+    if (tensorflow::tf_version() <= "2.1")
+      args$mode <- mode
   }
   do.call(preprocess_input, args)
 }
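
For context, a hedged usage sketch of the wrapper this hunk touches; the input array below is made up, and the point is simply that under TF >= 2.2 the `mode` argument is no longer forwarded to the underlying Keras preprocess function:

library(keras)

# hypothetical batch of one 224x224 RGB image with values in [0, 255]
x <- array(runif(1 * 224 * 224 * 3, 0, 255), dim = c(1, 224, 224, 3))

# with TF <= 2.1 the wrapper still passes mode = "caffe" through;
# with TF >= 2.2 the argument is dropped before the call
x_pre <- imagenet_preprocess_input(x)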

R/callbacks.R

Lines changed: 8 additions & 3 deletions
@@ -627,7 +627,12 @@ normalize_callbacks_with_metrics <- function(view_metrics, callbacks) {
     callbacks <- list(callbacks)
 
   # always include the metrics callback
-  callbacks <- append(callbacks, KerasMetricsCallback$new(view_metrics))
+  if (tensorflow::tf_version() >= "2.2.0")
+    metrics_callback <- KerasMetricsCallbackV2$new(view_metrics)
+  else
+    metrics_callback <- KerasMetricsCallback$new(view_metrics)
+
+  callbacks <- append(callbacks, metrics_callback)
 
   normalize_callbacks(callbacks)
 }
@@ -717,11 +722,11 @@ normalize_callbacks <- function(callbacks) {
     )
 
     # on_batch_* -> on_train_batch_*
-    if (!identical(callback$on_batch_begin, empty_fun)) {
+    if (!isTRUE(all.equal(callback$on_batch_begin, empty_fun))) {
      args$r_on_train_batch_begin <- callback$on_batch_begin
    }
 
-    if (!identical(callback$on_batch_end, empty_fun)) {
+    if (!isTRUE(all.equal(callback$on_batch_end, empty_fun))) {
      args$r_on_train_batch_end <- callback$on_batch_end
    }
 

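A side note on the version gate above: `tensorflow::tf_version()` returns a version object rather than a plain string (the diff itself relies on this), so the comparison is component-wise and still orders versions like "2.10" after "2.2". A minimal sketch of the dispatch:

library(tensorflow)

# numeric_version comparison, not a string comparison
if (tf_version() >= "2.2.0") {
  message("using KerasMetricsCallbackV2")
} else {
  message("using KerasMetricsCallback")
}
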
R/history.R

Lines changed: 21 additions & 8 deletions
@@ -79,11 +79,17 @@ plot.keras_training_history <- function(x, y, metrics = NULL, method = c("auto",
   # select the correct metrics
   df <- df[df$metric %in% metrics, ]
 
+  if (tensorflow::tf_version() < "2.2")
+    do_validation <- x$params$do_validation
+  else
+    do_validation <- any(grepl("^val_", names(x$metrics)))
+
+
   if (method == "ggplot2") {
     # helper function for correct breaks (integers only)
     int_breaks <- function(x) pretty(x)[pretty(x) %% 1 == 0]
 
-    if (x$params$do_validation) {
+    if (do_validation) {
       if (theme_bw)
         p <- ggplot2::ggplot(df, ggplot2::aes_(~epoch, ~value, color = ~data, fill = ~data, linetype = ~data, shape = ~data))
       else
@@ -151,7 +157,7 @@ plot.keras_training_history <- function(x, y, metrics = NULL, method = c("auto",
       legend_location <- ifelse(
         df2[df2$data == 'training', 'value'][1] > df2[df2$data == 'training', 'value'][x$params$epochs],
         "topright", "bottomright")
-      if (x$params$do_validation)
+      if (do_validation)
         graphics::legend(legend_location, legend = c(metric, paste0("val_", metric)), pch = c(1, 4))
       else
         graphics::legend(legend_location, legend = metric, pch = 1)
@@ -164,7 +170,8 @@ plot.keras_training_history <- function(x, y, metrics = NULL, method = c("auto",
 as.data.frame.keras_training_history <- function(x, ...) {
 
   # filter out metrics that were collected for callbacks (e.g. lr)
-  x$metrics <- x$metrics[x$params$metrics]
+  if (tensorflow::tf_version() < "2.2")
+    x$metrics <- x$metrics[x$params$metrics]
 
   # pad to epochs if necessary
   values <- x$metrics
@@ -193,15 +200,21 @@ as.data.frame.keras_training_history <- function(x, ...) {
 
 to_keras_training_history <- function(history) {
 
+
   # turn history into an R object so it can be persited and
   # and give it a class so we can write print/plot methods
   params <- history$params
-  if (params$do_validation) {
-    if (!is.null(params$validation_steps))
-      params$validation_samples <- params$validation_steps
-    else
-      params$validation_samples <- dim(history$validation_data[[1]])[[1]]
+
+  # we only see this info before TF 2.2
+  if (tensorflow::tf_version() < "2.2") {
+    if (params$do_validation) {
+      if (!is.null(params$validation_steps))
+        params$validation_samples <- params$validation_steps
+      else
+        params$validation_samples <- dim(history$validation_data[[1]])[[1]]
+    }
   }
+
   # normalize metrics
   metrics <- history$history
   metrics <- lapply(metrics, function(metric) {
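
To make the new validation check concrete: under TF 2.2 the history params no longer carry `do_validation` (see the "we only see this info before TF 2.2" comment above), so the plot method infers it from the metric names instead. A standalone sketch with made-up values:

# hypothetical metrics as stored on a keras_training_history object
metrics <- list(loss = c(0.92, 0.61, 0.48), val_loss = c(1.01, 0.74, 0.66))

do_validation <- any(grepl("^val_", names(metrics)))
do_validation  # TRUE, so validation curves are drawn alongside the training curves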

R/install.R

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ install_keras <- function(method = c("auto", "virtualenv", "conda"),
     paste0("keras", version),
     extra_packages,
     "h5py",
-    "pyyaml",
+    "pyyaml==3.12",
     "requests",
     "Pillow",
     "scipy"

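For reference, a hedged sketch of how the pinned dependency reaches users; the environment details are whatever install_keras() defaults to and are not part of this diff:

library(keras)

# installs Keras plus its pinned Python dependencies
# (h5py, pyyaml==3.12, requests, Pillow, scipy) into the default environment
install_keras()
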
R/metrics-callback.R

Lines changed: 143 additions & 6 deletions
@@ -1,7 +1,6 @@
-
-
-
-KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback",
-
+KerasMetricsCallback <- R6::R6Class(
+  "KerasMetricsCallback",
+
 
   inherit = KerasCallback,
 
   public = list(
@@ -87,7 +86,7 @@ KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback",
 
     # convert keras history to metrics data frame suitable for plotting
     as_metrics_df = function(history) {
-
+
      # create metrics data frame
      df <- as.data.frame(history$metrics)
 
@@ -128,8 +127,146 @@ KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback",
       }, error = function(e) {
         warning("Unable to log model info: ", e$message, call. = FALSE)
       })
-
+
+    }
+  )
+)
+
+KerasMetricsCallbackV2 <- R6::R6Class(
+  "KerasMetricsCallbackV2",
+
+  inherit = KerasCallback,
+
+  public = list(
+
+    # instance data
+    metrics = list(),
+    metrics_viewer = NULL,
+    view_metrics = FALSE,
+
+    initialize = function(view_metrics = FALSE) {
+      self$view_metrics <- view_metrics
+    },
+
+    on_train_begin = function(logs = NULL) {
+      if (tfruns::is_run_active()) {
+        self$write_params(self$params)
+        self$write_model_info(self$model)
+      }
+    },
+
+    on_epoch_end = function(epoch, logs = NULL) {
+
+      if (epoch == 0) {
+
+        metric_names <- names(logs)
+        for (metric in metric_names)
+          self$metrics[[metric]] <- numeric()
+
+        sleep <- 0.5
+      } else {
+
+        sleep <- 0.1
+
+      }
+
+      # handle metrics
+      self$on_metrics(logs, sleep)
+
+    },
+
+    on_metrics = function(logs, sleep) {
+
+      # record metrics
+      for (metric in names(self$metrics)) {
+        # guard against metrics not yet available by using NA
+        # when a named metrics isn't passed in 'logs'
+        value <- logs[[metric]]
+        if (is.null(value))
+          value <- NA
+        else
+          value <- mean(value)
+
+        self$metrics[[metric]] <- c(self$metrics[[metric]], value)
+      }
+
+      # create history object and convert to metrics data frame
+
+      history <- keras_training_history(self$params, self$metrics)
+      metrics <- self$as_metrics_df(history)
+
+      # view metrics if requested
+      if (self$view_metrics) {
+
+        # create the metrics_viewer or update if we already have one
+        if (is.null(self$metrics_viewer)) {
+          self$metrics_viewer <- tfruns::view_run_metrics(metrics)
+        } else {
+          tfruns::update_run_metrics(self$metrics_viewer, metrics)
+        }
+
+        # pump events
+        Sys.sleep(sleep)
+      }
+
+      # record metrics
+      tfruns::write_run_metadata("metrics", metrics)
+
+    },
+
+    # convert keras history to metrics data frame suitable for plotting
+    as_metrics_df = function(history) {
+      # create metrics data frame
+      df <- as.data.frame(history$metrics)
+
+      # pad to epochs if necessary
+      pad <- history$params$epochs - nrow(df)
+      pad_data <- list()
+
+      if (tensorflow::tf_version() < "2.2")
+        metric_names <- history$params$metrics
+      else
+        metric_names <- names(history$metrics)
+
+      for (metric in metric_names)
+        pad_data[[metric]] <- rep_len(NA, pad)
+
+      df <- rbind(df, pad_data)
+
+      # return df
+      df
+    },
+
+    write_params = function(params) {
+      properties <- list()
+      properties$samples <- params$samples
+      properties$validation_samples <- params$validation_samples
+      properties$epochs <- params$epochs
+      properties$batch_size <- params$batch_size
+      tfruns::write_run_metadata("properties", properties)
+    },
+
+    write_model_info = function(model) {
+      tryCatch({
+        model_info <- list()
+        model_info$model <- py_str(model, line_length = 80L)
+        if (is.character(model$loss))
+          model_info$loss_function <- model$loss
+        else if (inherits(model$loss, "python.builtin.function"))
+          model_info$loss_function <- model$loss$`__name__`
+        optimizer <- model$optimizer
+        if (!is.null(optimizer)) {
+          model_info$optimizer <- py_str(optimizer)
+          model_info$learning_rate <- k_eval(optimizer$lr)
+        }
+        tfruns::write_run_metadata("properties", model_info)
+      }, error = function(e) {
+        warning("Unable to log model info: ", e$message, call. = FALSE)
+      })
+
     }
   )
 )
 
+
+

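To illustrate the core of the new callback's on_metrics() logic outside the R6 class: a standalone sketch of how per-epoch values are appended, including the NA guard for metrics that have not been reported yet (metric names and values here are made up):

metrics <- list(loss = numeric(), val_loss = numeric())  # initialised on epoch 0
logs <- list(loss = 0.42)                                 # val_loss absent this epoch

for (metric in names(metrics)) {
  value <- logs[[metric]]
  if (is.null(value))
    value <- NA          # keep the series aligned when a metric is missing
  else
    value <- mean(value)
  metrics[[metric]] <- c(metrics[[metric]], value)
}
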
R/model.R

Lines changed: 5 additions & 2 deletions
@@ -447,6 +447,11 @@ fit.keras.engine.training.Model <-
   dataset <- resolve_tensorflow_dataset(x)
   if (inherits(dataset, "tensorflow.python.data.ops.dataset_ops.DatasetV2")) {
     args$x <- dataset
+
+    if (!is.null(batch_size))
+      stop("You should not specify a `batch_size` if using a tfdataset.",
+           call. = FALSE)
+
   } else if (!is.null(dataset)) {
     args$x <- dataset[[1]]
     args$y <- dataset[[2]]
@@ -1146,8 +1151,6 @@ py_str.keras.engine.training.Model <- function(object, line_length = getOption(
 # determine whether to view metrics or not
 resolve_view_metrics <- function(verbose, epochs, metrics) {
   (epochs > 1) &&                      # more than 1 epoch
-  !is.null(metrics) &&                 # have metrics
-  (length(metrics) > 0) &&             # capturing at least one metric
   (verbose > 0) &&                     # verbose mode is on
   !is.null(getOption("viewer")) &&     # have an internal viewer available
   nzchar(Sys.getenv("RSTUDIO"))        # running under RStudio
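
A hedged sketch of the call pattern the new batch_size check guards against; it assumes the tfdatasets package and made-up x_train / y_train objects, and is not taken from the PR itself:

library(keras)
library(tfdatasets)

ds <- tensor_slices_dataset(list(x_train, y_train)) %>%
  dataset_shuffle(1000) %>%
  dataset_batch(32)    # batching is configured on the dataset itself

model %>% fit(ds, epochs = 5)                     # OK: no batch_size argument
# model %>% fit(ds, epochs = 5, batch_size = 32)  # now errors with the message above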

tests/testthat/test-callbacks.R

Lines changed: 5 additions & 1 deletion
@@ -26,7 +26,11 @@ test_callback <- function(name, callback, h5py = FALSE, required_version = NULL)
   })
 }
 
-test_callback("progbar_logger", callback_progbar_logger())
+# disable progbar test as per: https://github.com/tensorflow/tensorflow/issues/38618#issuecomment-617907735
+if (tensorflow::tf_version() <= "2.1")
+  test_callback("progbar_logger", callback_progbar_logger())
+
+
 test_callback("model_checkpoint", callback_model_checkpoint(tempfile(fileext = ".h5")), h5py = TRUE)
 test_callback("learning_rate_scheduler", callback_learning_rate_scheduler(schedule = function (index, ...) {
   0.1
