Skip to content

Commit 2f5a8aa

Browse files
authored
Merge pull request #1489 from rstudio/use-py_require
Use `py_require()` to declare Python dependencies.
2 parents 44cac16 + 5608827 commit 2f5a8aa

File tree

5 files changed

+213
-4
lines changed

5 files changed

+213
-4
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Depends:
2929
R (>= 4.0)
3030
Imports:
3131
generics (>= 0.0.1),
32-
reticulate (>= 1.36.0),
32+
reticulate (>= 1.40.0.9000),
3333
tensorflow (>= 2.16.0),
3434
tfruns (>= 1.5.2),
3535
magrittr,
@@ -54,3 +54,5 @@ Suggests:
5454
jpeg
5555
RoxygenNote: 7.3.2
5656
VignetteBuilder: knitr
57+
Remotes:
58+
rstudio/reticulate

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# keras3 (development version)
22

3+
- Keras now uses `reticulate::py_require()` to resolve Python dependencies.
4+
Calling `install_keras()` is no longer required (but is still supported).
5+
6+
- `use_backend()` gains a `gpu` argument, to specify if a GPU-capable set of
7+
dependencies should be resolved by `py_require()`.
8+
39
- The progress bar in `fit()`, `evaluate()` and `predict()` now
410
defaults to not presenting during testthat tests.
511

R/install.R

Lines changed: 194 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ is_linux <- function() {
127127
#' Configure a Keras backend
128128
#'
129129
#' @param backend string, can be `"tensorflow"`, `"jax"`, `"numpy"`, or `"torch"`.
130+
#' @param gpu bool, whether to use the GPU.
130131
#'
131132
#' @details
132133
#' These functions allow configuring which backend keras will use.
@@ -143,20 +144,153 @@ is_linux <- function() {
143144
#' ```
144145
#' @returns Called primarily for side effects. Returns the provided `backend`, invisibly.
145146
#' @export
146-
use_backend <- function(backend) {
147+
use_backend <- function(backend, gpu = NA) {
147148

148149
if (is_keras_loaded()) {
149150
if (config_backend() != backend)
150151
stop("The keras backend must be set before keras has inititialized. Please restart the R session.")
151152
}
153+
152154
Sys.setenv(KERAS_BACKEND = backend)
153155

154-
if (reticulate::py_available())
156+
if (reticulate::py_available()) {
155157
reticulate::import("os")$environ$update(list(KERAS_BACKEND = backend))
158+
}
159+
160+
161+
switch(
162+
paste0(get_os(), "_", backend),
163+
164+
macOS_tensorflow = {
165+
166+
if (is.na(gpu))
167+
gpu <- TRUE
168+
169+
if (gpu) {
170+
py_require("tensorflow", action = "remove")
171+
py_require(c("tensorflow-macos", "tensorflow-metal"), python_version = "<3.12")
172+
} else {
173+
py_require(action = "remove", c("tensorflow-macos", "tensorflow-metal"))
174+
py_require("tensorflow")
175+
}
176+
177+
},
178+
179+
macOS_jax = {
180+
181+
py_require(c("tensorflow-metal", "tensorflow-macos"),
182+
action = "remove")
183+
184+
if (is.na(gpu))
185+
gpu <- TRUE
186+
187+
if (gpu) {
188+
py_require(c("tensorflow", "jax", "jax-metal"))
189+
} else {
190+
py_require("tensorflow", "jax[cpu]")
191+
}
192+
},
193+
194+
macOS_torch = {
195+
if(isTRUE(gpu))
196+
warning("GPU usage not supported on macOS. Please use a different backend to use the GPU (jax)")
197+
198+
py_require(c("tensorflow-metal", "tensorflow-macos"),
199+
action = "remove")
200+
201+
py_require(c("tensorflow", "torch", "torchvision", "torchaudio"))
202+
},
203+
204+
macOS_numpy = {
205+
py_require(c("tensorflow-metal", "tensorflow-macos"), action = "remove")
206+
py_require(c("tensorflow", "numpy"))
207+
},
208+
209+
Linux_tensorflow = {
210+
211+
if (is.na(gpu))
212+
gpu <- has_gpu()
213+
214+
if (gpu) {
215+
py_require(action = "remove", c("tensorflow", "tensorflow-cpu"))
216+
py_require("tensorflow[and-cuda]")
217+
} else {
218+
py_require(action = "remove", c("tensorflow", "tensorflow[and-cuda]"))
219+
py_require("tensorflow-cpu")
220+
}
221+
},
222+
223+
Linux_jax = {
224+
py_require(c("tensorflow", "tensorflow[and-cuda]"),
225+
action = "remove")
226+
227+
if (is.na(gpu))
228+
gpu <- has_gpu()
229+
230+
if (gpu) {
231+
py_require(c("tensorflow-cpu", "jax[cuda12]"))
232+
} else {
233+
py_require(c("tensorflow-cpu", "jax[cpu]"))
234+
}
235+
},
236+
237+
Linux_torch = {
238+
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
239+
240+
if (is.na(gpu))
241+
gpu <- has_gpu()
242+
243+
if (gpu) {
244+
py_require(c("tensorflow-cpu", "torch", "torchvision", "torchaudio"))
245+
} else {
246+
Sys.setenv("UV_INDEX" = "https://download.pytorch.org/whl/cpu")
247+
py_require(c("tensorflow-cpu", "torch", "torchvision", "torchaudio"))
248+
# additional_args = c("--index", "https://download.pytorch.org/whl/cpu"))
249+
}
250+
},
251+
252+
Linux_numpy = {
253+
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
254+
py_require(c("tensorflow-cpu", "numpy"))
255+
},
256+
257+
Windows_tensorflow = {
258+
if(isTRUE(gpu)) warning("GPU usage not supported on Windows. Please use WSL.")
259+
py_require("tensorflow")
260+
},
261+
262+
Windows_jax = {
263+
if(isTRUE(gpu)) warning("GPU usage not supported on Windows. Please use WSL.")
264+
py_require("jax")
265+
},
266+
267+
Windows_torch = {
268+
if (is.na(gpu))
269+
gpu <- has_gpu()
270+
271+
if (gpu) {
272+
Sys.setenv("UV_INDEX" = "https://download.pytorch.org/whl/cu126")
273+
py_require(c("torch", "torchvision", "torchaudio"))
274+
# additional_args = c("--index", "https://download.pytorch.org/whl/cu126"))
275+
} else {
276+
py_require(c("torch", "torchvision", "torchaudio"))
277+
}
278+
},
279+
280+
Windows_numpy = {
281+
py_require("numpy")
282+
}
283+
)
284+
156285
invisible(backend)
157286
}
158287

159288

289+
290+
get_os <- function() {
291+
if (is_windows()) "Windows" else if (is_mac_arm64()) "macOS" else "Linux"
292+
}
293+
160294
is_keras_loaded <- function() {
161295
# package .onLoad() has run (can be FALSE if in devtools::load_all())
162296
!is.null(keras) &&
@@ -171,6 +305,64 @@ is_keras_loaded <- function() {
171305
}
172306

173307

308+
has_gpu <- function() {
309+
310+
has_nvidia_gpu <- function() {
311+
lspci_listed <- tryCatch(
312+
as.logical(length(
313+
system("lspci | grep -i nvidia", intern = TRUE)
314+
)),
315+
# warning emitted by system for non-0 exit status
316+
warning = function(w) FALSE,
317+
error = function(e) FALSE
318+
)
319+
320+
if (lspci_listed)
321+
return(TRUE)
322+
323+
# lspci doens't list GPUs on WSL Linux, but nvidia-smi does.
324+
nvidia_smi_listed <- tryCatch(
325+
system("nvidia-smi -L", intern = TRUE),
326+
warning = function(w) character(),
327+
error = function(e) character()
328+
)
329+
if (isTRUE(any(grepl("^GPU [0-9]: ", nvidia_smi_listed))))
330+
return(TRUE)
331+
FALSE
332+
}
333+
334+
is_linux() && has_nvidia_gpu()
335+
336+
}
337+
338+
339+
get_py_requirements <- function() {
340+
python_version <- ">=3.10"
341+
packages <- "tensorflow"
342+
343+
if(is_linux()) {
344+
345+
if(has_gpu()) {
346+
packages <- "tensorflow[and-cuda]"
347+
} else {
348+
packages <- "tensorflow-cpu"
349+
}
350+
351+
} else if (is_mac_arm64()) {
352+
353+
use_gpu <- FALSE
354+
if (use_gpu) {
355+
packages <- c("tensorflow-macos", "tensorflow-metal")
356+
python_version <- ">=3.9,<=3.11"
357+
}
358+
359+
} else if (is_windows()) {
360+
361+
}
362+
363+
list(packages = packages, python_version = python_version)
364+
}
365+
174366

175367
python_module_dir <- function(python, module, stderr = TRUE) {
176368

R/package.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ keras <- NULL
6161
if (!is.null(keras_python))
6262
Sys.setenv(RETICULATE_PYTHON = keras_python)
6363

64+
# default backend is tensorflow for now
65+
# the tensorflow R package calls `py_require()` to ensure GPU is usable on Linux
66+
py_require(c(
67+
"keras", "pydot", "scipy", "pandas", "Pillow",
68+
"ipython", "tensorflow_datasets"
69+
))
70+
6471
# delay load keras
6572
try(keras <<- import("keras", delay_load = list(
6673

man/use_backend.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)