Skip to content

Commit ae12926

Browse files
committed
use as_py_array() in loss functions
1 parent 241a8a7 commit ae12926

File tree

1 file changed

+42
-119
lines changed

1 file changed

+42
-119
lines changed

R/losses.R

Lines changed: 42 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,9 @@ loss_binary_crossentropy <-
118118
function (y_true, y_pred, from_logits = FALSE, label_smoothing = 0,
119119
axis = -1L, ..., reduction = "sum_over_batch_size", name = "binary_crossentropy")
120120
{
121-
args <- capture_args(list(axis = as_axis, y_true = function (x)
122-
if (is_py_object(x))
123-
x
124-
else np_array(x), y_pred = function (x)
125-
if (is_py_object(x))
126-
x
127-
else np_array(x)))
121+
args <- capture_args(list(axis = as_axis,
122+
y_true = as_py_array,
123+
y_pred = as_py_array))
128124
callable <- if (missing(y_true) && missing(y_pred))
129125
keras$losses$BinaryCrossentropy
130126
else keras$losses$binary_crossentropy
@@ -330,13 +326,9 @@ function (y_true, y_pred, apply_class_balancing = FALSE,
330326
alpha = 0.25, gamma = 2, from_logits = FALSE, label_smoothing = 0,
331327
axis = -1L, ..., reduction = "sum_over_batch_size", name = "binary_focal_crossentropy")
332328
{
333-
args <- capture_args(list(axis = as_axis, y_true = function (x)
334-
if (is_py_object(x))
335-
x
336-
else np_array(x), y_pred = function (x)
337-
if (is_py_object(x))
338-
x
339-
else np_array(x)))
329+
args <- capture_args(list(axis = as_axis,
330+
y_true = as_py_array,
331+
y_pred = as_py_array))
340332
callable <- if (missing(y_true) && missing(y_pred))
341333
keras$losses$BinaryFocalCrossentropy
342334
else keras$losses$binary_focal_crossentropy
@@ -440,13 +432,9 @@ loss_categorical_crossentropy <-
440432
function (y_true, y_pred, from_logits = FALSE, label_smoothing = 0,
441433
axis = -1L, ..., reduction = "sum_over_batch_size", name = "categorical_crossentropy")
442434
{
443-
args <- capture_args(list(axis = as_axis, y_true = function (x)
444-
if (is_py_object(x))
445-
x
446-
else np_array(x), y_pred = function (x)
447-
if (is_py_object(x))
448-
x
449-
else np_array(x)))
435+
args <- capture_args(list(axis = as_axis,
436+
y_true = as_py_array,
437+
y_pred = as_py_array))
450438
callable <- if (missing(y_true) && missing(y_pred))
451439
keras$losses$CategoricalCrossentropy
452440
else keras$losses$categorical_crossentropy
@@ -597,13 +585,9 @@ function (y_true, y_pred, alpha = 0.25, gamma = 2,
597585
from_logits = FALSE, label_smoothing = 0, axis = -1L, ...,
598586
reduction = "sum_over_batch_size", name = "categorical_focal_crossentropy")
599587
{
600-
args <- capture_args(list(axis = as_axis, y_true = function (x)
601-
if (is_py_object(x))
602-
x
603-
else np_array(x), y_pred = function (x)
604-
if (is_py_object(x))
605-
x
606-
else np_array(x)))
588+
args <- capture_args(list(axis = as_axis,
589+
y_true = as_py_array,
590+
y_pred = as_py_array))
607591
callable <- if (missing(y_true) && missing(y_pred))
608592
keras$losses$CategoricalFocalCrossentropy
609593
else keras$losses$categorical_focal_crossentropy
@@ -662,13 +646,8 @@ loss_categorical_hinge <-
662646
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
663647
name = "categorical_hinge")
664648
{
665-
args <- capture_args(list(y_true = function (x)
666-
if (is_py_object(x))
667-
x
668-
else np_array(x), y_pred = function (x)
669-
if (is_py_object(x))
670-
x
671-
else np_array(x)))
649+
args <- capture_args(list(y_true = as_py_array,
650+
y_pred = as_py_array))
672651
callable <- if (missing(y_true) && missing(y_pred))
673652
keras$losses$CategoricalHinge
674653
else keras$losses$categorical_hinge
@@ -735,13 +714,9 @@ loss_cosine_similarity <-
735714
function (y_true, y_pred, axis = -1L, ..., reduction = "sum_over_batch_size",
736715
name = "cosine_similarity")
737716
{
738-
args <- capture_args(list(axis = as_axis, y_true = function (x)
739-
if (is_py_object(x))
740-
x
741-
else np_array(x), y_pred = function (x)
742-
if (is_py_object(x))
743-
x
744-
else np_array(x)))
717+
args <- capture_args(list(axis = as_axis,
718+
y_true = as_py_array,
719+
y_pred = as_py_array))
745720
callable <- if (missing(y_true) && missing(y_pred))
746721
keras$losses$CosineSimilarity
747722
else keras$losses$cosine_similarity
@@ -849,13 +824,8 @@ loss_hinge <-
849824
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
850825
name = "hinge")
851826
{
852-
args <- capture_args(list(y_true = function (x)
853-
if (is_py_object(x))
854-
x
855-
else np_array(x), y_pred = function (x)
856-
if (is_py_object(x))
857-
x
858-
else np_array(x)))
827+
args <- capture_args(list(y_true = as_py_array,
828+
y_pred = as_py_array))
859829
callable <- if (missing(y_true) && missing(y_pred))
860830
keras$losses$Hinge
861831
else keras$losses$hinge
@@ -921,13 +891,8 @@ loss_huber <-
921891
function (y_true, y_pred, delta = 1, ..., reduction = "sum_over_batch_size",
922892
name = "huber_loss")
923893
{
924-
args <- capture_args(list(y_true = function (x)
925-
if (is_py_object(x))
926-
x
927-
else np_array(x), y_pred = function (x)
928-
if (is_py_object(x))
929-
x
930-
else np_array(x)))
894+
args <- capture_args(list(y_true = as_py_array,
895+
y_pred = as_py_array))
931896
callable <- if (missing(y_true) && missing(y_pred))
932897
keras$losses$Huber
933898
else keras$losses$huber
@@ -987,13 +952,9 @@ loss_kl_divergence <-
987952
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
988953
name = "kl_divergence")
989954
{
990-
args <- capture_args(list(y_true = function (x)
991-
if (is_py_object(x))
992-
x
993-
else np_array(x), y_pred = function (x)
994-
if (is_py_object(x))
995-
x
996-
else np_array(x)))
955+
args <- capture_args(list(axis = as_axis,
956+
y_true = as_py_array,
957+
y_pred = as_py_array))
997958
callable <- if (missing(y_true) && missing(y_pred))
998959
keras$losses$KLDivergence
999960
else keras$losses$kl_divergence
@@ -1053,13 +1014,8 @@ loss_log_cosh <-
10531014
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
10541015
name = "log_cosh")
10551016
{
1056-
args <- capture_args(list(y_true = function (x)
1057-
if (is_py_object(x))
1058-
x
1059-
else np_array(x), y_pred = function (x)
1060-
if (is_py_object(x))
1061-
x
1062-
else np_array(x)))
1017+
args <- capture_args(list(y_true = as_py_array,
1018+
y_pred = as_py_array))
10631019
callable <- if (missing(y_true) && missing(y_pred))
10641020
keras$losses$LogCosh
10651021
else keras$losses$log_cosh
@@ -1114,13 +1070,8 @@ loss_mean_absolute_error <-
11141070
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
11151071
name = "mean_absolute_error")
11161072
{
1117-
args <- capture_args(list(y_true = function (x)
1118-
if (is_py_object(x))
1119-
x
1120-
else np_array(x), y_pred = function (x)
1121-
if (is_py_object(x))
1122-
x
1123-
else np_array(x)))
1073+
args <- capture_args(list(y_true = as_py_array,
1074+
y_pred = as_py_array))
11241075
callable <- if (missing(y_true) && missing(y_pred))
11251076
keras$losses$MeanAbsoluteError
11261077
else keras$losses$mean_absolute_error
@@ -1180,13 +1131,8 @@ loss_mean_absolute_percentage_error <-
11801131
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
11811132
name = "mean_absolute_percentage_error")
11821133
{
1183-
args <- capture_args(list(y_true = function (x)
1184-
if (is_py_object(x))
1185-
x
1186-
else np_array(x), y_pred = function (x)
1187-
if (is_py_object(x))
1188-
x
1189-
else np_array(x)))
1134+
args <- capture_args(list(y_true = as_py_array,
1135+
y_pred = as_py_array))
11901136
callable <- if (missing(y_true) && missing(y_pred))
11911137
keras$losses$MeanAbsolutePercentageError
11921138
else keras$losses$mean_absolute_percentage_error
@@ -1241,13 +1187,8 @@ loss_mean_squared_error <-
12411187
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
12421188
name = "mean_squared_error")
12431189
{
1244-
args <- capture_args(list(y_true = function (x)
1245-
if (is_py_object(x))
1246-
x
1247-
else np_array(x), y_pred = function (x)
1248-
if (is_py_object(x))
1249-
x
1250-
else np_array(x)))
1190+
args <- capture_args(list(y_true = as_py_array,
1191+
y_pred = as_py_array))
12511192
callable <- if (missing(y_true) && missing(y_pred))
12521193
keras$losses$MeanSquaredError
12531194
else keras$losses$mean_squared_error
@@ -1306,13 +1247,8 @@ loss_mean_squared_logarithmic_error <-
13061247
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
13071248
name = "mean_squared_logarithmic_error")
13081249
{
1309-
args <- capture_args(list(y_true = function (x)
1310-
if (is_py_object(x))
1311-
x
1312-
else np_array(x), y_pred = function (x)
1313-
if (is_py_object(x))
1314-
x
1315-
else np_array(x)))
1250+
args <- capture_args(list(y_true = as_py_array,
1251+
y_pred = as_py_array))
13161252
callable <- if (missing(y_true) && missing(y_pred))
13171253
keras$losses$MeanSquaredLogarithmicError
13181254
else keras$losses$mean_squared_logarithmic_error
@@ -1368,13 +1304,8 @@ loss_poisson <-
13681304
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
13691305
name = "poisson")
13701306
{
1371-
args <- capture_args(list(y_true = function (x)
1372-
if (is_py_object(x))
1373-
x
1374-
else np_array(x), y_pred = function (x)
1375-
if (is_py_object(x))
1376-
x
1377-
else np_array(x)))
1307+
args <- capture_args(list(y_true = as_py_array,
1308+
y_pred = as_py_array))
13781309
callable <- if (missing(y_true) && missing(y_pred))
13791310
keras$losses$Poisson
13801311
else keras$losses$poisson
@@ -1485,13 +1416,10 @@ loss_sparse_categorical_crossentropy <-
14851416
function (y_true, y_pred, from_logits = FALSE, ignore_class = NULL,
14861417
axis = -1L, ..., reduction = "sum_over_batch_size", name = "sparse_categorical_crossentropy")
14871418
{
1488-
args <- capture_args(list(ignore_class = as_integer, y_true = function (x)
1489-
if (is_py_object(x))
1490-
x
1491-
else np_array(x), y_pred = function (x)
1492-
if (is_py_object(x))
1493-
x
1494-
else np_array(x), axis = as_axis))
1419+
args <- capture_args(list(ignore_class = as_integer,
1420+
axis = as_axis,
1421+
y_true = as_py_array,
1422+
y_pred = as_py_array))
14951423
callable <- if (missing(y_true) && missing(y_pred))
14961424
keras$losses$SparseCategoricalCrossentropy
14971425
else keras$losses$sparse_categorical_crossentropy
@@ -1551,13 +1479,8 @@ loss_squared_hinge <-
15511479
function (y_true, y_pred, ..., reduction = "sum_over_batch_size",
15521480
name = "squared_hinge")
15531481
{
1554-
args <- capture_args(list(y_true = function (x)
1555-
if (is_py_object(x))
1556-
x
1557-
else np_array(x), y_pred = function (x)
1558-
if (is_py_object(x))
1559-
x
1560-
else np_array(x)))
1482+
args <- capture_args(list(y_true = as_py_array,
1483+
y_pred = as_py_array))
15611484
callable <- if (missing(y_true) && missing(y_pred))
15621485
keras$losses$SquaredHinge
15631486
else keras$losses$squared_hinge

0 commit comments

Comments
 (0)