@@ -127,6 +127,7 @@ is_linux <- function() {
127127# ' Configure a Keras backend
128128# '
129129# ' @param backend string, can be `"tensorflow"`, `"jax"`, `"numpy"`, or `"torch"`.
130+ # ' @param gpu bool, whether to use the GPU.
130131# '
131132# ' @details
132133# ' These functions allow configuring which backend keras will use.
@@ -143,20 +144,153 @@ is_linux <- function() {
143144# ' ```
144145# ' @returns Called primarily for side effects. Returns the provided `backend`, invisibly.
145146# ' @export
146- use_backend <- function (backend ) {
147+ use_backend <- function (backend , gpu = NA ) {
147148
148149 if (is_keras_loaded()) {
149150 if (config_backend() != backend )
150151 stop(" The keras backend must be set before keras has inititialized. Please restart the R session." )
151152 }
153+
152154 Sys.setenv(KERAS_BACKEND = backend )
153155
154- if (reticulate :: py_available())
156+ if (reticulate :: py_available()) {
155157 reticulate :: import(" os" )$ environ $ update(list (KERAS_BACKEND = backend ))
158+ }
159+
160+
161+ switch (
162+ paste0(get_os(), " _" , backend ),
163+
164+ macOS_tensorflow = {
165+
166+ if (is.na(gpu ))
167+ gpu <- TRUE
168+
169+ if (gpu ) {
170+ py_require(" tensorflow" , action = " remove" )
171+ py_require(c(" tensorflow-macos" , " tensorflow-metal" ), python_version = " <3.12" )
172+ } else {
173+ py_require(action = " remove" , c(" tensorflow-macos" , " tensorflow-metal" ))
174+ py_require(" tensorflow" )
175+ }
176+
177+ },
178+
179+ macOS_jax = {
180+
181+ py_require(c(" tensorflow-metal" , " tensorflow-macos" ),
182+ action = " remove" )
183+
184+ if (is.na(gpu ))
185+ gpu <- TRUE
186+
187+ if (gpu ) {
188+ py_require(c(" tensorflow" , " jax" , " jax-metal" ))
189+ } else {
190+ py_require(" tensorflow" , " jax[cpu]" )
191+ }
192+ },
193+
194+ macOS_torch = {
195+ if (isTRUE(gpu ))
196+ warning(" GPU usage not supported on macOS. Please use a different backend to use the GPU (jax)" )
197+
198+ py_require(c(" tensorflow-metal" , " tensorflow-macos" ),
199+ action = " remove" )
200+
201+ py_require(c(" tensorflow" , " torch" , " torchvision" , " torchaudio" ))
202+ },
203+
204+ macOS_numpy = {
205+ py_require(c(" tensorflow-metal" , " tensorflow-macos" ), action = " remove" )
206+ py_require(c(" tensorflow" , " numpy" ))
207+ },
208+
209+ Linux_tensorflow = {
210+
211+ if (is.na(gpu ))
212+ gpu <- has_gpu()
213+
214+ if (gpu ) {
215+ py_require(action = " remove" , c(" tensorflow" , " tensorflow-cpu" ))
216+ py_require(" tensorflow[and-cuda]" )
217+ } else {
218+ py_require(action = " remove" , c(" tensorflow" , " tensorflow[and-cuda]" ))
219+ py_require(" tensorflow-cpu" )
220+ }
221+ },
222+
223+ Linux_jax = {
224+ py_require(c(" tensorflow" , " tensorflow[and-cuda]" ),
225+ action = " remove" )
226+
227+ if (is.na(gpu ))
228+ gpu <- has_gpu()
229+
230+ if (gpu ) {
231+ py_require(c(" tensorflow-cpu" , " jax[cuda12]" ))
232+ } else {
233+ py_require(c(" tensorflow-cpu" , " jax[cpu]" ))
234+ }
235+ },
236+
237+ Linux_torch = {
238+ py_require(c(" tensorflow" , " tensorflow[and-cuda]" ), action = " remove" )
239+
240+ if (is.na(gpu ))
241+ gpu <- has_gpu()
242+
243+ if (gpu ) {
244+ py_require(c(" tensorflow-cpu" , " torch" , " torchvision" , " torchaudio" ))
245+ } else {
246+ Sys.setenv(" UV_INDEX" = " https://download.pytorch.org/whl/cpu" )
247+ py_require(c(" tensorflow-cpu" , " torch" , " torchvision" , " torchaudio" ))
248+ # additional_args = c("--index", "https://download.pytorch.org/whl/cpu"))
249+ }
250+ },
251+
252+ Linux_numpy = {
253+ py_require(c(" tensorflow" , " tensorflow[and-cuda]" ), action = " remove" )
254+ py_require(c(" tensorflow-cpu" , " numpy" ))
255+ },
256+
257+ Windows_tensorflow = {
258+ if (isTRUE(gpu )) warning(" GPU usage not supported on Windows. Please use WSL." )
259+ py_require(" tensorflow" )
260+ },
261+
262+ Windows_jax = {
263+ if (isTRUE(gpu )) warning(" GPU usage not supported on Windows. Please use WSL." )
264+ py_require(" jax" )
265+ },
266+
267+ Windows_torch = {
268+ if (is.na(gpu ))
269+ gpu <- has_gpu()
270+
271+ if (gpu ) {
272+ Sys.setenv(" UV_INDEX" = " https://download.pytorch.org/whl/cu126" )
273+ py_require(c(" torch" , " torchvision" , " torchaudio" ))
274+ # additional_args = c("--index", "https://download.pytorch.org/whl/cu126"))
275+ } else {
276+ py_require(c(" torch" , " torchvision" , " torchaudio" ))
277+ }
278+ },
279+
280+ Windows_numpy = {
281+ py_require(" numpy" )
282+ }
283+ )
284+
156285 invisible (backend )
157286}
158287
159288
289+
290+ get_os <- function () {
291+ if (is_windows()) " Windows" else if (is_mac_arm64()) " macOS" else " Linux"
292+ }
293+
160294is_keras_loaded <- function () {
161295 # package .onLoad() has run (can be FALSE if in devtools::load_all())
162296 ! is.null(keras ) &&
@@ -171,6 +305,64 @@ is_keras_loaded <- function() {
171305}
172306
173307
308+ has_gpu <- function () {
309+
310+ has_nvidia_gpu <- function () {
311+ lspci_listed <- tryCatch(
312+ as.logical(length(
313+ system(" lspci | grep -i nvidia" , intern = TRUE )
314+ )),
315+ # warning emitted by system for non-0 exit status
316+ warning = function (w ) FALSE ,
317+ error = function (e ) FALSE
318+ )
319+
320+ if (lspci_listed )
321+ return (TRUE )
322+
323+ # lspci doens't list GPUs on WSL Linux, but nvidia-smi does.
324+ nvidia_smi_listed <- tryCatch(
325+ system(" nvidia-smi -L" , intern = TRUE ),
326+ warning = function (w ) character (),
327+ error = function (e ) character ()
328+ )
329+ if (isTRUE(any(grepl(" ^GPU [0-9]: " , nvidia_smi_listed ))))
330+ return (TRUE )
331+ FALSE
332+ }
333+
334+ is_linux() && has_nvidia_gpu()
335+
336+ }
337+
338+
339+ get_py_requirements <- function () {
340+ python_version <- " >=3.10"
341+ packages <- " tensorflow"
342+
343+ if (is_linux()) {
344+
345+ if (has_gpu()) {
346+ packages <- " tensorflow[and-cuda]"
347+ } else {
348+ packages <- " tensorflow-cpu"
349+ }
350+
351+ } else if (is_mac_arm64()) {
352+
353+ use_gpu <- FALSE
354+ if (use_gpu ) {
355+ packages <- c(" tensorflow-macos" , " tensorflow-metal" )
356+ python_version <- " >=3.9,<=3.11"
357+ }
358+
359+ } else if (is_windows()) {
360+
361+ }
362+
363+ list (packages = packages , python_version = python_version )
364+ }
365+
174366
175367python_module_dir <- function (python , module , stderr = TRUE ) {
176368
0 commit comments