Skip to content

Commit 45ca0fc

Browse files
committed
use return_dict=FALSE in evaluate() and related test functions
1 parent 2c99a4f commit 45ca0fc

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

R/model-training.R

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,17 @@ function (object, x = NULL, y = NULL, ..., batch_size = NULL,
269269
verbose = as_model_verbose_arg),
270270
ignore = "object",
271271
force = "verbose")
272-
args[["return_dict"]] <- TRUE
272+
args[["return_dict"]] <- FALSE
273273

274274
if(inherits(args$x, "tensorflow.python.data.ops.dataset_ops.DatasetV2") &&
275275
!is.null(args$batch_size))
276276
stop("batch_size can not be specified with a TF Dataset")
277277

278278
result <- do.call(object$evaluate, args)
279+
if (length(result) > 1L) {
280+
result <- as.list(result)
281+
names(result) <- object$metrics_names
282+
}
279283

280284
tfruns::write_run_metadata("evaluation", unlist(result))
281285

@@ -756,8 +760,15 @@ function (object, x, y = NULL, sample_weight = NULL, ...)
756760
{
757761
result <- object$test_on_batch(as_array(x),
758762
as_array(y),
759-
as_array(sample_weight), ..., return_dict = TRUE)
760-
if (is_scalar(result)) result[[1L]] else result
763+
as_array(sample_weight), ...,
764+
return_dict = FALSE)
765+
if (length(result) > 1L) {
766+
result <- as.list(result)
767+
names(result) <- object$metrics_names
768+
} else if (is_scalar(result)) {
769+
result <- result[[1L]]
770+
}
771+
result
761772
}
762773

763774
# ---- test_on_batch ----
@@ -813,8 +824,15 @@ function (object, x, y = NULL, sample_weight = NULL, class_weight = NULL)
813824
as_array(y),
814825
as_array(sample_weight),
815826
class_weight = as_class_weight(class_weight),
816-
return_dict = TRUE)
817-
if(is_scalar(result)) result[[1L]] else result
827+
return_dict = FALSE)
828+
if (length(result) > 1L) {
829+
result <- as.list(result)
830+
names(result) <- object$metrics_names
831+
} else if (is_scalar(result)) {
832+
result <- result[[1L]]
833+
}
834+
835+
result
818836
}
819837

820838

tools/archive/make.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,6 @@ if(!"source:tools/utils.R" %in% search()) envir::attach_source("tools/utils.R")
111111
# }
112112
#
113113
# TODO: layer_category_encoding()(count_weights) call arg example not working
114-
# TODO: backout usage of `return_dict=TRUE` in evaluate() and friends - the output order is not stable.
115-
# use `setNames(as.list())`
116-
# ## Deferred until upstream bug fixed,
117-
# ## model.metrics_names returns wrong result
118114

119115
## Docs ----
120116

vignettes-src/training_with_built_in_methods.Rmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ when using built-in APIs for training & validation (such as `fit()`,
2323
`evaluate()` and `predict()`).
2424

2525
If you are interested in leveraging `fit()` while specifying your
26-
own training step function, see the
26+
own training step function, see the
2727
[Customizing what happens in `fit()` guide](custom_train_step_in_tensorflow.html).
2828

2929
<!-- guides on customizing what happens in `fit()`: -->
@@ -134,7 +134,7 @@ We evaluate the model on the test data via `evaluate()`:
134134
```{r}
135135
# Evaluate the model on the test data using `evaluate`
136136
results <- model |> evaluate(x_test, y_test, batch_size=128)
137-
results
137+
str(results)
138138
139139
# Generate predictions (probabilities -- the output of the last layer)
140140
# on new data using `predict`

0 commit comments

Comments (0)