Skip to content

Commit 2c99a4f

Browse files
committed
use as_py_array() in metrics
1 parent ae12926 commit 2c99a4f

File tree

1 file changed

+44
-153
lines changed

1 file changed

+44
-153
lines changed

R/metrics.R

Lines changed: 44 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,9 @@ metric_binary_focal_crossentropy <-
7474
function (y_true, y_pred, apply_class_balancing = FALSE, alpha = 0.25,
7575
gamma = 2, from_logits = FALSE, label_smoothing = 0, axis = -1L)
7676
{
77-
args <- capture_args(list(
78-
y_true = function (x)
79-
if (is_py_object(x)) x
80-
else np_array(x),
81-
y_pred = function (x)
82-
if (is_py_object(x)) x
83-
else np_array(x), axis = as_axis)
84-
)
77+
args <- capture_args(list(axis = as_axis,
78+
y_true = as_py_array,
79+
y_pred = as_py_array))
8580
do.call(keras$metrics$binary_focal_crossentropy, args)
8681
}
8782

@@ -145,13 +140,9 @@ metric_categorical_focal_crossentropy <-
145140
function (y_true, y_pred, alpha = 0.25, gamma = 2, from_logits = FALSE,
146141
label_smoothing = 0, axis = -1L)
147142
{
148-
args <- capture_args(list(y_true = function (x)
149-
if (is_py_object(x))
150-
x
151-
else np_array(x), y_pred = function (x)
152-
if (is_py_object(x))
153-
x
154-
else np_array(x), axis = as_axis))
143+
args <- capture_args(list(axis = as_axis,
144+
y_true = as_py_array,
145+
y_pred = as_py_array))
155146
do.call(keras$metrics$categorical_focal_crossentropy, args)
156147
}
157148

@@ -203,13 +194,7 @@ function (y_true, y_pred, alpha = 0.25, gamma = 2, from_logits = FALSE,
203194
metric_huber <-
204195
function (y_true, y_pred, delta = 1)
205196
{
206-
args <- capture_args(list(y_true = function (x)
207-
if (is_py_object(x))
208-
x
209-
else np_array(x), y_pred = function (x)
210-
if (is_py_object(x))
211-
x
212-
else np_array(x)))
197+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
213198
do.call(keras$metrics$huber, args)
214199
}
215200

@@ -255,13 +240,7 @@ function (y_true, y_pred, delta = 1)
255240
metric_log_cosh <-
256241
function (y_true, y_pred)
257242
{
258-
args <- capture_args(list(y_true = function (x)
259-
if (is_py_object(x))
260-
x
261-
else np_array(x), y_pred = function (x)
262-
if (is_py_object(x))
263-
x
264-
else np_array(x)))
243+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
265244
do.call(keras$metrics$log_cosh, args)
266245
}
267246

@@ -339,13 +318,7 @@ metric_binary_accuracy <-
339318
function (y_true, y_pred, threshold = 0.5, ..., name = "binary_accuracy",
340319
dtype = NULL)
341320
{
342-
args <- capture_args(list(y_true = function (x)
343-
if (is_py_object(x))
344-
x
345-
else np_array(x), y_pred = function (x)
346-
if (is_py_object(x))
347-
x
348-
else np_array(x)))
321+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
349322
callable <- if (missing(y_true) && missing(y_pred))
350323
keras$metrics$BinaryAccuracy
351324
else keras$metrics$binary_accuracy
@@ -426,13 +399,7 @@ metric_categorical_accuracy <-
426399
function (y_true, y_pred, ..., name = "categorical_accuracy",
427400
dtype = NULL)
428401
{
429-
args <- capture_args(list(y_true = function (x)
430-
if (is_py_object(x))
431-
x
432-
else np_array(x), y_pred = function (x)
433-
if (is_py_object(x))
434-
x
435-
else np_array(x)))
402+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
436403
callable <- if (missing(y_true) && missing(y_pred))
437404
keras$metrics$CategoricalAccuracy
438405
else keras$metrics$categorical_accuracy
@@ -510,13 +477,8 @@ metric_sparse_categorical_accuracy <-
510477
function (y_true, y_pred, ..., name = "sparse_categorical_accuracy",
511478
dtype = NULL)
512479
{
513-
args <- capture_args(list(y_true = function (x)
514-
if (is_py_object(x))
515-
x
516-
else np_array(x), y_pred = function (x)
517-
if (is_py_object(x))
518-
x
519-
else np_array(x)))
480+
args <- capture_args(list(y_true = as_py_array,
481+
y_pred = as_py_array))
520482
callable <- if (missing(y_true) && missing(y_pred))
521483
keras$metrics$SparseCategoricalAccuracy
522484
else keras$metrics$sparse_categorical_accuracy
@@ -590,13 +552,9 @@ metric_sparse_top_k_categorical_accuracy <-
590552
function (y_true, y_pred, k = 5L, ..., name = "sparse_top_k_categorical_accuracy",
591553
dtype = NULL)
592554
{
593-
args <- capture_args(list(k = as_integer, y_true = function (x)
594-
if (is_py_object(x))
595-
x
596-
else np_array(x), y_pred = function (x)
597-
if (is_py_object(x))
598-
x
599-
else np_array(x)))
555+
args <- capture_args(list(k = as_integer,
556+
y_true = as_py_array,
557+
y_pred = as_py_array))
600558
callable <- if (missing(y_true) && missing(y_pred))
601559
keras$metrics$SparseTopKCategoricalAccuracy
602560
else keras$metrics$sparse_top_k_categorical_accuracy
@@ -669,13 +627,9 @@ metric_top_k_categorical_accuracy <-
669627
function (y_true, y_pred, k = 5L, ..., name = "top_k_categorical_accuracy",
670628
dtype = NULL)
671629
{
672-
args <- capture_args(list(
673-
k = as_integer,
674-
y_true = function(x)
675-
if (is_py_object(x)) x else np_array(x),
676-
y_pred = function(x)
677-
if (is_py_object(x)) x else np_array(x)
678-
))
630+
args <- capture_args(list(k = as_integer,
631+
y_true = as_py_array,
632+
y_pred = as_py_array))
679633
callable <- if (missing(y_true) && missing(y_pred))
680634
keras$metrics$TopKCategoricalAccuracy
681635
else keras$metrics$top_k_categorical_accuracy
@@ -1891,13 +1845,7 @@ metric_categorical_hinge <-
18911845
function (y_true, y_pred, ..., name = "categorical_hinge",
18921846
dtype = NULL)
18931847
{
1894-
args <- capture_args(list(y_true = function (x)
1895-
if (is_py_object(x))
1896-
x
1897-
else np_array(x), y_pred = function (x)
1898-
if (is_py_object(x))
1899-
x
1900-
else np_array(x)))
1848+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
19011849
callable <- if (missing(y_true) && missing(y_pred))
19021850
keras$metrics$CategoricalHinge
19031851
else keras$metrics$categorical_hinge
@@ -1961,13 +1909,7 @@ function (y_true, y_pred, ..., name = "categorical_hinge",
19611909
metric_hinge <-
19621910
function (y_true, y_pred, ..., name = "hinge", dtype = NULL)
19631911
{
1964-
args <- capture_args(list(y_true = function (x)
1965-
if (is_py_object(x))
1966-
x
1967-
else np_array(x), y_pred = function (x)
1968-
if (is_py_object(x))
1969-
x
1970-
else np_array(x)))
1912+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
19711913
callable <- if (missing(y_true) && missing(y_pred))
19721914
keras$metrics$Hinge
19731915
else keras$metrics$hinge
@@ -2032,13 +1974,7 @@ metric_squared_hinge <-
20321974
function (y_true, y_pred, ..., name = "squared_hinge",
20331975
dtype = NULL)
20341976
{
2035-
args <- capture_args(list(y_true = function (x)
2036-
if (is_py_object(x))
2037-
x
2038-
else np_array(x), y_pred = function (x)
2039-
if (is_py_object(x))
2040-
x
2041-
else np_array(x)))
1977+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
20421978
callable <- if (missing(y_true) && missing(y_pred))
20431979
keras$metrics$SquaredHinge
20441980
else keras$metrics$squared_hinge
@@ -2450,9 +2386,10 @@ metric_one_hot_iou <-
24502386
function (..., num_classes, target_class_ids, name = NULL, dtype = NULL,
24512387
ignore_class = NULL, sparse_y_pred = FALSE, axis = -1L)
24522388
{
2453-
args <- capture_args(list(ignore_class = as_integer, axis = as_axis,
2454-
num_classes = as_integer, target_class_ids = function (x)
2455-
lapply(x, as_integer)))
2389+
args <- capture_args(list(
2390+
ignore_class = as_integer,
2391+
axis = as_axis, num_classes = as_integer,
2392+
target_class_ids = function (x) lapply(x, as_integer)))
24562393
do.call(keras$metrics$OneHotIoU, args)
24572394
}
24582395

@@ -2633,14 +2570,9 @@ metric_binary_crossentropy <-
26332570
function (y_true, y_pred, from_logits = FALSE, label_smoothing = 0,
26342571
axis = -1L, ..., name = "binary_crossentropy", dtype = NULL)
26352572
{
2636-
args <- capture_args(list(label_smoothing = as_integer,
2637-
y_true = function (x)
2638-
if (is_py_object(x))
2639-
x
2640-
else np_array(x), y_pred = function (x)
2641-
if (is_py_object(x))
2642-
x
2643-
else np_array(x), axis = as_axis))
2573+
args <- capture_args(list(axis = as_axis,
2574+
y_true = as_py_array,
2575+
y_pred = as_py_array))
26442576
callable <- if (missing(y_true) && missing(y_pred))
26452577
keras$metrics$BinaryCrossentropy
26462578
else keras$metrics$binary_crossentropy
@@ -2736,14 +2668,9 @@ metric_categorical_crossentropy <-
27362668
function (y_true, y_pred, from_logits = FALSE, label_smoothing = 0,
27372669
axis = -1L, ..., name = "categorical_crossentropy", dtype = NULL)
27382670
{
2739-
args <- capture_args(list(label_smoothing = as_integer,
2740-
axis = as_axis, y_true = function (x)
2741-
if (is_py_object(x))
2742-
x
2743-
else np_array(x), y_pred = function (x)
2744-
if (is_py_object(x))
2745-
x
2746-
else np_array(x)))
2671+
args <- capture_args(list(axis = as_axis,
2672+
y_true = as_py_array,
2673+
y_pred = as_py_array))
27472674
callable <- if (missing(y_true) && missing(y_pred))
27482675
keras$metrics$CategoricalCrossentropy
27492676
else keras$metrics$categorical_crossentropy
@@ -2817,13 +2744,7 @@ metric_kl_divergence <-
28172744
function (y_true, y_pred, ..., name = "kl_divergence",
28182745
dtype = NULL)
28192746
{
2820-
args <- capture_args(list(y_true = function (x)
2821-
if (is_py_object(x))
2822-
x
2823-
else np_array(x), y_pred = function (x)
2824-
if (is_py_object(x))
2825-
x
2826-
else np_array(x)))
2747+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
28272748
callable <- if (missing(y_true) && missing(y_pred))
28282749
keras$metrics$KLDivergence
28292750
else keras$metrics$kl_divergence
@@ -2895,13 +2816,7 @@ function (y_true, y_pred, ..., name = "kl_divergence",
28952816
metric_poisson <-
28962817
function (y_true, y_pred, ..., name = "poisson", dtype = NULL)
28972818
{
2898-
args <- capture_args(list(y_true = function (x)
2899-
if (is_py_object(x))
2900-
x
2901-
else np_array(x), y_pred = function (x)
2902-
if (is_py_object(x))
2903-
x
2904-
else np_array(x)))
2819+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
29052820
callable <- if (missing(y_true) && missing(y_pred))
29062821
keras$metrics$Poisson
29072822
else keras$metrics$poisson
@@ -2994,13 +2909,10 @@ function (y_true, y_pred, from_logits = FALSE, ignore_class = NULL,
29942909
axis = -1L, ..., name = "sparse_categorical_crossentropy",
29952910
dtype = NULL)
29962911
{
2997-
args <- capture_args(list(axis = as_axis, y_true = function (x)
2998-
if (is_py_object(x))
2999-
x
3000-
else np_array(x), y_pred = function (x)
3001-
if (is_py_object(x))
3002-
x
3003-
else np_array(x), ignore_class = as_integer))
2912+
args <- capture_args(list(axis = as_axis,
2913+
ignore_class = as_integer,
2914+
y_true = as_py_array,
2915+
y_pred = as_py_array))
30042916
callable <- if (missing(y_true) && missing(y_pred))
30052917
keras$metrics$SparseCategoricalCrossentropy
30062918
else keras$metrics$sparse_categorical_crossentropy
@@ -3345,13 +3257,7 @@ metric_mean_absolute_error <-
33453257
function (y_true, y_pred, ..., name = "mean_absolute_error",
33463258
dtype = NULL)
33473259
{
3348-
args <- capture_args(list(y_true = function (x)
3349-
if (is_py_object(x))
3350-
x
3351-
else np_array(x), y_pred = function (x)
3352-
if (is_py_object(x))
3353-
x
3354-
else np_array(x)))
3260+
args <- capture_args(list(y_true = as_py_array, y_pred = as_py_array))
33553261
callable <- if (missing(y_true) && missing(y_pred))
33563262
keras$metrics$MeanAbsoluteError
33573263
else keras$metrics$mean_absolute_error
@@ -3424,13 +3330,8 @@ metric_mean_absolute_percentage_error <-
34243330
function (y_true, y_pred, ..., name = "mean_absolute_percentage_error",
34253331
dtype = NULL)
34263332
{
3427-
args <- capture_args(list(y_true = function (x)
3428-
if (is_py_object(x))
3429-
x
3430-
else np_array(x), y_pred = function (x)
3431-
if (is_py_object(x))
3432-
x
3433-
else np_array(x)))
3333+
args <- capture_args(list(y_true = as_py_array,
3334+
y_pred = as_py_array))
34343335
callable <- if (missing(y_true) && missing(y_pred))
34353336
keras$metrics$MeanAbsolutePercentageError
34363337
else keras$metrics$mean_absolute_percentage_error
@@ -3484,13 +3385,8 @@ metric_mean_squared_error <-
34843385
function (y_true, y_pred, ..., name = "mean_squared_error",
34853386
dtype = NULL)
34863387
{
3487-
args <- capture_args(list(y_true = function (x)
3488-
if (is_py_object(x))
3489-
x
3490-
else np_array(x), y_pred = function (x)
3491-
if (is_py_object(x))
3492-
x
3493-
else np_array(x)))
3388+
args <- capture_args(list(y_true = as_py_array,
3389+
y_pred = as_py_array))
34943390
callable <- if (missing(y_true) && missing(y_pred))
34953391
keras$metrics$MeanSquaredError
34963392
else keras$metrics$mean_squared_error
@@ -3563,13 +3459,8 @@ metric_mean_squared_logarithmic_error <-
35633459
function (y_true, y_pred, ..., name = "mean_squared_logarithmic_error",
35643460
dtype = NULL)
35653461
{
3566-
args <- capture_args(list(y_true = function (x)
3567-
if (is_py_object(x))
3568-
x
3569-
else np_array(x), y_pred = function (x)
3570-
if (is_py_object(x))
3571-
x
3572-
else np_array(x)))
3462+
args <- capture_args(list(y_true = as_py_array,
3463+
y_pred = as_py_array))
35733464
callable <- if (missing(y_true) && missing(y_pred))
35743465
keras$metrics$MeanSquaredLogarithmicError
35753466
else keras$metrics$mean_squared_logarithmic_error

0 commit comments

Comments
 (0)