Skip to content

Commit 442bad2

Browse files
committed
back to return_dict=TRUE
1 parent a0b6822 commit 442bad2

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

R/model-training.R

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,20 @@ function (object, x = NULL, y = NULL, ..., batch_size = NULL,
269269
verbose = as_model_verbose_arg),
270270
ignore = "object",
271271
force = "verbose")
272-
args[["return_dict"]] <- FALSE
272+
273+
## return_dict=TRUE because object$metrics_names returns wrong value
274+
## (e.g., "compile_metrics" instead of "mae")
275+
args[["return_dict"]] <- TRUE
273276

274277
if(inherits(args$x, "tensorflow.python.data.ops.dataset_ops.DatasetV2") &&
275278
!is.null(args$batch_size))
276279
stop("batch_size can not be specified with a TF Dataset")
277280

278281
result <- do.call(object$evaluate, args)
279-
if (length(result) > 1L) {
280-
result <- as.list(result)
281-
names(result) <- object$metrics_names
282-
}
282+
# if (length(result) > 1L) { ## if return_dict=FALSE
283+
# result <- as.list(result)
284+
# names(result) <- object$metrics_names
285+
# }
283286

284287
tfruns::write_run_metadata("evaluation", unlist(result))
285288

@@ -761,11 +764,12 @@ function (object, x, y = NULL, sample_weight = NULL, ...)
761764
result <- object$test_on_batch(as_array(x),
762765
as_array(y),
763766
as_array(sample_weight), ...,
764-
return_dict = FALSE)
765-
if (length(result) > 1L) {
766-
result <- as.list(result)
767-
names(result) <- object$metrics_names
768-
} else if (is_scalar(result)) {
767+
return_dict = TRUE)
768+
# if (length(result) > 1L) {
769+
# result <- as.list(result)
770+
# names(result) <- object$metrics_names
771+
# } else
772+
if (is_scalar(result)) {
769773
result <- result[[1L]]
770774
}
771775
result
@@ -824,11 +828,12 @@ function (object, x, y = NULL, sample_weight = NULL, class_weight = NULL)
824828
as_array(y),
825829
as_array(sample_weight),
826830
class_weight = as_class_weight(class_weight),
827-
return_dict = FALSE)
828-
if (length(result) > 1L) {
829-
result <- as.list(result)
830-
names(result) <- object$metrics_names
831-
} else if (is_scalar(result)) {
831+
return_dict = TRUE)
832+
# if (length(result) > 1L) {
833+
# result <- as.list(result)
834+
# names(result) <- object$metrics_names
835+
# } else
836+
if (is_scalar(result)) {
832837
result <- result[[1L]]
833838
}
834839

0 commit comments

Comments
 (0)