@@ -269,17 +269,20 @@ function (object, x = NULL, y = NULL, ..., batch_size = NULL,
269269 verbose = as_model_verbose_arg ),
270270 ignore = " object" ,
271271 force = " verbose" )
272- args [[" return_dict" ]] <- FALSE
272+
273+ # # return_dict=TRUE because object$metrics_names returns wrong value
274+ # # (e.g., "compile_metrics" instead of "mae")
275+ args [[" return_dict" ]] <- TRUE
273276
274277 if (inherits(args $ x , " tensorflow.python.data.ops.dataset_ops.DatasetV2" ) &&
275278 ! is.null(args $ batch_size ))
276279 stop(" batch_size can not be specified with a TF Dataset" )
277280
278281 result <- do.call(object $ evaluate , args )
279- if (length(result ) > 1L ) {
280- result <- as.list(result )
281- names(result ) <- object $ metrics_names
282- }
282+ # if (length(result) > 1L) { ## if return_dict=FALSE
283+ # result <- as.list(result)
284+ # names(result) <- object$metrics_names
285+ # }
283286
284287 tfruns :: write_run_metadata(" evaluation" , unlist(result ))
285288
@@ -761,11 +764,12 @@ function (object, x, y = NULL, sample_weight = NULL, ...)
761764 result <- object $ test_on_batch(as_array(x ),
762765 as_array(y ),
763766 as_array(sample_weight ), ... ,
764- return_dict = FALSE )
765- if (length(result ) > 1L ) {
766- result <- as.list(result )
767- names(result ) <- object $ metrics_names
768- } else if (is_scalar(result )) {
767+ return_dict = TRUE )
768+ # if (length(result) > 1L) {
769+ # result <- as.list(result)
770+ # names(result) <- object$metrics_names
771+ # } else
772+ if (is_scalar(result )) {
769773 result <- result [[1L ]]
770774 }
771775 result
@@ -824,11 +828,12 @@ function (object, x, y = NULL, sample_weight = NULL, class_weight = NULL)
824828 as_array(y ),
825829 as_array(sample_weight ),
826830 class_weight = as_class_weight(class_weight ),
827- return_dict = FALSE )
828- if (length(result ) > 1L ) {
829- result <- as.list(result )
830- names(result ) <- object $ metrics_names
831- } else if (is_scalar(result )) {
831+ return_dict = TRUE )
832+ # if (length(result) > 1L) {
833+ # result <- as.list(result)
834+ # names(result) <- object$metrics_names
835+ # } else
836+ if (is_scalar(result )) {
832837 result <- result [[1L ]]
833838 }
834839
0 commit comments