Skip to content

Commit 1fb96c3

Browse files
committed
.onLoad set backend
1 parent 1a8234c commit 1fb96c3

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

R/package.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ keras <- NULL
5252
.onLoad <- function(libname, pkgname) {
5353

5454
# tensorflow:::.onLoad() registers some reticulate class filter hooks
55-
# we need to identify tensors reliably.
55+
# we need to identify tensorflow tensors reliably.
5656
requireNamespace("tensorflow", quietly = TRUE)
5757
maybe_register_S3_methods()
5858

@@ -61,13 +61,17 @@ keras <- NULL
6161
if (!is.null(keras_python))
6262
Sys.setenv(RETICULATE_PYTHON = keras_python)
6363

64-
# default backend is tensorflow for now
65-
# the tensorflow R package calls `py_require()` to ensure GPU is usable on Linux
6664
py_require(c(
6765
"keras", "pydot", "scipy", "pandas", "Pillow",
68-
"ipython", "tensorflow_datasets"
66+
"ipython" #, "tensorflow_datasets"
6967
))
7068

69+
# default backend is tensorflow for now
70+
# the tensorflow R package calls `py_require()` to ensure GPU is usable on Linux
71+
# use_backend() includes py_require(action = "remove") calls to undo
72+
# what tensorflow:::.onLoad() did. Keep them in sync!
73+
use_backend(Sys.getenv("KERAS_BACKEND", "tensorflow"))
74+
7175
# delay load keras
7276
try(keras <<- import("keras", delay_load = list(
7377

0 commit comments

Comments
 (0)