@@ -52,7 +52,7 @@ keras <- NULL
5252.onLoad <- function (libname , pkgname ) {
5353
5454 # tensorflow:::.onLoad() registers some reticulate class filter hooks
55- # we need to identify tensors reliably.
55+ # we need to identify tensorflow tensors reliably.
5656 requireNamespace(" tensorflow" , quietly = TRUE )
5757 maybe_register_S3_methods()
5858
@@ -61,13 +61,17 @@ keras <- NULL
6161 if (! is.null(keras_python ))
6262 Sys.setenv(RETICULATE_PYTHON = keras_python )
6363
64- # default backend is tensorflow for now
65- # the tensorflow R package calls `py_require()` to ensure GPU is usable on Linux
6664 py_require(c(
6765 " keras" , " pydot" , " scipy" , " pandas" , " Pillow" ,
68- " ipython" , " tensorflow_datasets"
66+ " ipython" # , "tensorflow_datasets"
6967 ))
7068
69+ # default backend is tensorflow for now
70+ # the tensorflow R package calls `py_require()` to ensure GPU is usable on Linux
71+ # use_backend() includes py_require(action = "remove") calls to undo
72+ # what tensorflow:::.onLoad() did. Keep them in sync!
73+ use_backend(Sys.getenv(" KERAS_BACKEND" , " tensorflow" ))
74+
7175 # delay load keras
7276 try(keras <<- import(" keras" , delay_load = list (
7377
0 commit comments