Skip to content

Commit 484dff4

Browse files
committed
fix recalling use_backend("jax") with changed gpu value
1 parent 130ec29 commit 484dff4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

R/install.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ use_backend <- function(backend, gpu = NA) {
226226
},
227227

228228
Linux_jax = {
229-
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
229+
py_require(c("tensorflow", "tensorflow[and-cuda]", "jax[cuda12]", "jax[cpu]"), action = "remove")
230230

231231
if (is.na(gpu))
232232
gpu <- has_gpu()

0 commit comments

Comments
 (0)