Skip to content

Commit 2439e0c

Browse files
Merge pull request #692 from rstudio/bugfix/tf1.13-r-custom-generator
Wrap R functions in a custom generator
2 parents ed8c28f + b483b61 commit 2439e0c

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

R/model.R

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,13 @@ as_generator.tensorflow.python.data.ops.dataset_ops.DatasetV2 <-
934934
as_generator.tensorflow.python.data.ops.dataset_ops.Dataset
935935

936936
as_generator.function <- function(x) {
937-
reticulate::py_iterator(function() keras_array(x()))
937+
python_path <- system.file("python", package = "keras")
938+
tools <- reticulate::import_from_path("kerastools", path = python_path)
939+
iter <- reticulate::py_iterator(function() {
940+
elem <- x()
941+
reticulate::tuple(elem[1], elem[2])
942+
})
943+
tools$generator$iter_generator(iter)
938944
}
939945

940946
as_generator.keras_preprocessing.sequence.TimeseriesGenerator <- function(x) {

inst/python/kerastools/generator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11

22
import itertools
33

4+
5+
def iter_generator(iter):
6+
7+
def gen():
8+
while 1:
9+
yield iter.next()
10+
11+
return gen()
12+
413
def dataset_generator(dataset, session):
514

615
iter = dataset.make_one_shot_iterator()

0 commit comments

Comments
 (0)