Skip to content

Commit ca21465

Browse files
authored
Merge pull request #1532 from rstudio/jax-updates
Updates for latest JAX
2 parents 346dfad + 1e11c5d commit ca21465

File tree

8 files changed

+55
-35
lines changed

8 files changed

+55
-35
lines changed

R/install.R

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,9 @@ is_linux <- function() {
134134
#'
135135
#' @param backend string, can be `"tensorflow"`, `"jax"`, `"numpy"`, or
136136
#' `"torch"`.
137-
#' @param gpu bool, whether to use the GPU. If `NA` (default), it will attempt
138-
#' to detect GPU availability on Linux. On M-series Macs, it defaults to
139-
#' `FALSE` for TensorFlow and `TRUE` for JAX. On Windows, it defaults to
140-
#' `FALSE`.
137+
#' @param gpu bool, whether to use the GPU. If `NA` (default), it will
138+
#' attempt to detect GPU availability on Linux. On macOS and Windows it
139+
#' defaults to `FALSE`.
141140
#'
142141
#' @details
143142
#'
@@ -147,14 +146,21 @@ is_linux <- function() {
147146
#' The function should be called after `library(keras3)` and before calling
148147
#' other functions within the package (see below for an example).
149148
#'
149+
#' Note that macOS packages like `tensorflow-metal` and `jax-metal` that
150+
#' purportedly enabled GPU usage on M-series macs all are currently broken
151+
#' and seemingly abandoned.
152+
#'
150153
#' There is experimental support for changing the backend after keras has
151-
#' initialized. using `config_set_backend()`.
154+
#' initialized with `config_set_backend()`. Usage of `config_set_backend` is
155+
#' generall not recommended for regular workflow---restarting the R session
156+
#' is the only reliable way to change the backend.
157+
#'
152158
#' ```r
153159
#' library(keras3)
154160
#' use_backend("tensorflow")
155161
#' ```
156-
#' @returns Called primarily for side effects. Returns the provided `backend`,
157-
#' invisibly.
162+
#' @returns Called primarily for side effects. Returns the provided
163+
#' `backend`, invisibly.
158164
#' @export
159165
use_backend <- function(backend, gpu = NA) {
160166

@@ -197,12 +203,14 @@ use_backend <- function(backend, gpu = NA) {
197203

198204
macOS_jax = {
199205
if (is.na(gpu))
200-
gpu <- TRUE
206+
gpu <- FALSE
201207

202208
if (gpu) {
209+
# jax-metal is abandoned
210+
# https://github.com/jax-ml/jax/issues/34109#issuecomment-3774392604
203211
py_require(c("tensorflow", "jax", "jax-metal"))
204212
} else {
205-
py_require(c("tensorflow", "jax[cpu]"))
213+
py_require(c("tensorflow", "jax")) # jax[cpu] ?
206214
}
207215
},
208216

@@ -363,9 +371,7 @@ py_require_remove_all_torch <- function() {
363371
py_require_tensorflow_cpu <- function() {
364372
if (is_linux()) {
365373

366-
# pin 2.18.* because later versions of 'tensorflow-cpu' are not
367-
# compatible with 'tensorflow-text', used by 'keras-hub'
368-
py_require("tensorflow-cpu==2.18.*")
374+
py_require("tensorflow-cpu")
369375

370376
# set override so tensorflow-text is prevented from pulling in 'tensorflow'
371377
uv_set_override_never_tensorflow()

R/package.R

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,24 @@ keras <- NULL
216216
registerS3method("aperm", backend_tensor_class, op_transpose, baseenv())
217217
registerS3method("all.equal", backend_tensor_class, all.equal.numpy.ndarray, baseenv())
218218

219-
if(keras$config$backend() == "jax") {
220-
for(py_type in import("jax")$Array$`__subclasses__`()) {
221-
s3_classname <- nameOfClass__python.builtin.type(py_type)
222-
registerS3method("@" , s3_classname, at.keras_backend_tensor, baseenv())
223-
registerS3method("@<-" , s3_classname, at_set.keras_backend_tensor, baseenv())
224-
registerS3method("as.array", s3_classname, op_convert_to_array, baseenv())
225-
registerS3method("^" , s3_classname, `^__keras.backend.tensor`, baseenv())
226-
registerS3method("%*%" , s3_classname, op_matmul, baseenv())
227-
}
219+
# "jax._src.core.Tracer"
220+
if (keras$config$backend() == "jax") {
221+
local({
222+
#
223+
jax <- import("jax")
224+
jax_types <- c(
225+
jax$Array$`__subclasses__`(),
226+
jax$core$Tracer
227+
)
228+
for (py_type in jax_types) {
229+
s3_classname <- nameOfClass__python.builtin.type(py_type)
230+
registerS3method("@" , s3_classname, at.keras_backend_tensor, baseenv())
231+
registerS3method("@<-" , s3_classname, at_set.keras_backend_tensor, baseenv())
232+
registerS3method("as.array", s3_classname, op_convert_to_array, baseenv())
233+
registerS3method("^" , s3_classname, `^__keras.backend.tensor`, baseenv())
234+
registerS3method("%*%" , s3_classname, op_matmul, baseenv())
235+
}
236+
})
228237
}
229238
})
230239

man/deserialize_keras_object.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_tfsm.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/metric_mean_absolute_percentage_error.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/op_erf.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/op_gelu.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/use_backend.Rd

Lines changed: 12 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)