Skip to content

Commit 98edbb6

Browse files
committed
reflect Python objects in keras_array()
closes: #1341
1 parent 13ae6cf commit 98edbb6

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
- `model$get_config()` method now returns an R object that can be safely serialized
1616
to rds.
17+
18+
- `keras_array()` now reflects unconverted Python objects. This enables passing
19+
objects like `pandas.Series()` to `fit()` and `evaluate()` methods. (#1341)
1720

1821
# keras 2.9.0
1922

R/utils.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ keras_array <- function(x, dtype = NULL) {
278278
x <- as.list(x)
279279
}
280280

281+
# allow passing things like pandas.Series(), for workarounds like
282+
# https://github.com/rstudio/keras/issues/1341
283+
if(inherits(x, "python.builtin.object"))
284+
return(x)
285+
281286
# recurse for lists
282287
if (is.list(x))
283288
return(lapply(x, keras_array))

tests/testthat/test-model.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,41 @@ test_succeeds("can use functional api with dicts", {
362362
chk(inputs, outputs, unname(x)[c(3,1,2)], unname(y), error = TRUE)
363363

364364
})
365+
366+
367+
368+
test_succeeds("can pass pandas.Series() to fit()", {
369+
#https://github.com/rstudio/keras/issues/1341
370+
n <- 30
371+
p <- 10
372+
373+
w <- runif(n)
374+
y <- runif(n)
375+
X <- matrix(runif(n * p), ncol = p)
376+
377+
make_nn <- function() {
378+
input <- layer_input(p)
379+
output <- input %>%
380+
layer_dense(2 * p, activation = "tanh") %>%
381+
layer_dense(1)
382+
keras_model(inputs = input, outputs = output)
383+
}
384+
385+
nn <- make_nn()
386+
387+
pd <- reticulate::import("pandas", convert = FALSE)
388+
w <- pd$Series(w)
389+
390+
nn %>%
391+
compile(optimizer = optimizer_adam(0.02), loss = "mse",
392+
weighted_metrics = list()) %>% # silence warning
393+
fit(
394+
x = X,
395+
y = y,
396+
sample_weight = w,
397+
weighted_metrics = list(),
398+
epochs = 2,
399+
validation_split = 0.2,
400+
verbose = 0
401+
)
402+
})

0 commit comments

Comments
 (0)