Skip to content

Commit d728d31

Browse files
authored
Merge pull request #1527 from rstudio/fixes
Fixes for JAX updates
2 parents 0f4a4fa + 616740f commit d728d31

File tree

11 files changed

+69
-15
lines changed

11 files changed

+69
-15
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ S3method("[[",python.builtin.super)
1212
S3method("[[",python_builtin_super_getter)
1313
S3method(Arg,keras.src.backend.Tensor)
1414
S3method(Arg,keras.src.backend.common.keras_tensor.KerasTensor)
15+
S3method(Ops,jax._src.export.shape_poly._DimExpr)
1516
S3method(Summary,keras_shape)
1617
S3method(all,equal.numpy.ndarray)
1718
S3method(as.array,jax.Array)
@@ -40,6 +41,7 @@ S3method(as.numeric,keras.src.backend.common.variables.KerasVariable)
4041
S3method(base::all.equal,keras.src.backend.Tensor)
4142
S3method(base::all.equal,keras.src.backend.common.keras_tensor.KerasTensor)
4243
S3method(base::all.equal,keras.src.backend.common.variables.KerasVariable)
44+
S3method(base::as.array,PIL.Image.Image)
4345
S3method(compile,keras.src.models.model.Model)
4446
S3method(destructure,keras_shape)
4547
S3method(evaluate,keras.src.models.model.Model)

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
- Added elastic deformation utilities for images: `layer_random_elastic_transform()`
2424
and the lower-level `op_image_elastic_transform()`.
2525

26+
- Added `as.array()` support for `PIL.Image.Image` objects.
27+
2628
- Transposed convolution utilities now follow the latest Keras API:
2729
`op_conv_transpose()` defaults `strides = 1` and the `layer_conv_*_transpose()`
2830
layers expose `output_padding` for precise shape control.
@@ -38,6 +40,11 @@
3840

3941
- `layer_layer_normalization()` removes the `rms_scaling` argument.
4042

43+
- Merging layers now capture `...` with tidy dots (fixes #1525).
44+
45+
- Fixed Ops on JAX `_DimExpr` so symbolic shapes survive arithmetic with R
46+
double scalars.
47+
4148
- `layer_reshape()` can now accept `-1` as a sentinel for an automatically calculated axis size.
4249

4350
- `layer_torch_module_wrapper()` gains an `output_shape` argument to help Keras

R/jax-methods.R

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,20 @@ type_sum.keras.src.backend.jax.core.JaxVariable <- type_sum.keras.src.backend.ja
9494

9595
# "keras.src.backend.Variable" too?
9696
# "keras.src.backend.common.variables.Variable" too?
97+
98+
#' @exportS3Method Ops jax._src.export.shape_poly._DimExpr
99+
Ops.jax._src.export.shape_poly._DimExpr <- function(e1, e2) {
100+
if (missing(e2)) {
101+
return(e1)
102+
}
103+
conv <- function(x) {
104+
if (is.double(x) && isTRUE(all(x == suppressWarnings(as.integer(x))))) {
105+
storage.mode(x) <- "integer"
106+
}
107+
x
108+
}
109+
e1 <- conv(e1)
110+
e2 <- conv(e2)
111+
NextMethod()
112+
}
113+

R/layers-merging.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function (inputs, ...)
5252
args <- capture_args(list(input_shape = normalize_shape,
5353
batch_size = as_integer, batch_input_shape = normalize_shape),
5454
ignore = c("...", "inputs"))
55-
dots <- split_dots_named_unnamed(list(...))
55+
dots <- split_dots_named_unnamed(list2(...))
5656
if (missing(inputs))
5757
inputs <- NULL
5858
else if (!is.null(inputs) && !is.list(inputs))
@@ -116,7 +116,7 @@ function (inputs, ...)
116116
args <- capture_args(list(input_shape = normalize_shape,
117117
batch_size = as_integer, batch_input_shape = normalize_shape),
118118
ignore = c("...", "inputs"))
119-
dots <- split_dots_named_unnamed(list(...))
119+
dots <- split_dots_named_unnamed(list2(...))
120120
if (missing(inputs))
121121
inputs <- NULL
122122
else if (!is.null(inputs) && !is.list(inputs))
@@ -177,7 +177,7 @@ function (inputs, ..., axis = -1L)
177177
args <- capture_args(list(axis = as_axis, input_shape = normalize_shape,
178178
batch_size = as_integer, batch_input_shape = normalize_shape),
179179
ignore = c("...", "inputs"))
180-
dots <- split_dots_named_unnamed(list(...))
180+
dots <- split_dots_named_unnamed(list2(...))
181181
if (missing(inputs))
182182
inputs <- NULL
183183
else if (!is.null(inputs) && !is.list(inputs))
@@ -260,7 +260,7 @@ function (inputs, ..., axes, normalize = FALSE)
260260
args <- capture_args(list(axes = as_axis, input_shape = normalize_shape,
261261
batch_size = as_integer, batch_input_shape = normalize_shape),
262262
ignore = c("...", "inputs"))
263-
dots <- split_dots_named_unnamed(list(...))
263+
dots <- split_dots_named_unnamed(list2(...))
264264
if (missing(inputs))
265265
inputs <- NULL
266266
else if (!is.null(inputs) && !is.list(inputs))
@@ -322,7 +322,7 @@ function (inputs, ...)
322322
args <- capture_args(list(input_shape = normalize_shape,
323323
batch_size = as_integer, batch_input_shape = normalize_shape),
324324
ignore = c("...", "inputs"))
325-
dots <- split_dots_named_unnamed(list(...))
325+
dots <- split_dots_named_unnamed(list2(...))
326326
if (missing(inputs))
327327
inputs <- NULL
328328
else if (!is.null(inputs) && !is.list(inputs))
@@ -384,7 +384,7 @@ function (inputs, ...)
384384
args <- capture_args(list(input_shape = normalize_shape,
385385
batch_size = as_integer, batch_input_shape = normalize_shape),
386386
ignore = c("...", "inputs"))
387-
dots <- split_dots_named_unnamed(list(...))
387+
dots <- split_dots_named_unnamed(list2(...))
388388
if (missing(inputs))
389389
inputs <- NULL
390390
else if (!is.null(inputs) && !is.list(inputs))
@@ -446,7 +446,7 @@ function (inputs, ...)
446446
args <- capture_args(list(input_shape = normalize_shape,
447447
batch_size = as_integer, batch_input_shape = normalize_shape),
448448
ignore = c("...", "inputs"))
449-
dots <- split_dots_named_unnamed(list(...))
449+
dots <- split_dots_named_unnamed(list2(...))
450450
if (missing(inputs))
451451
inputs <- NULL
452452
else if (!is.null(inputs) && !is.list(inputs))
@@ -509,7 +509,7 @@ function (inputs, ...)
509509
args <- capture_args(list(input_shape = normalize_shape,
510510
batch_size = as_integer, batch_input_shape = normalize_shape),
511511
ignore = c("...", "inputs"))
512-
dots <- split_dots_named_unnamed(list(...))
512+
dots <- split_dots_named_unnamed(list2(...))
513513
if (missing(inputs))
514514
inputs <- NULL
515515
else if (!is.null(inputs) && !is.list(inputs))

R/r-utils.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ split_dots_named_unnamed <- function(dots) {
2727
if (is.null(nms))
2828
return(list(unnamed = dots, named = list()))
2929
named <- nzchar(nms)
30-
list(unnamed = dots[!named], named = dots[named])
30+
unnamed <- dots[!named]
31+
names(unnamed) <- NULL
32+
list(unnamed = unnamed, named = dots[named])
3133
}
3234

3335
drop_nulls <- function(x, i = NULL) {

R/s3-methods.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,7 @@ py_to_r__keras.src.utils.tracking.TrackedSet <- function(x) import("builtins")$l
110110
# }
111111
# rm(list = c("generic", "cls"))
112112

113+
#' @exportS3Method base::as.array
114+
as.array.PIL.Image.Image <- function(x, ...) {
115+
as.array(image_to_array(x, ...))
116+
}

man/layer_discretization.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_tfsm.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/op_angle.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
test_that("DimExpr Ops keeps symbolic dims when R uses double scalars", {
2+
skip_if_not(reticulate::py_module_available("jax"))
3+
4+
export <- reticulate::import("jax.export", convert = FALSE)
5+
dim <- export$symbolic_shape("n")[[1]]
6+
7+
expr <- dim - 1 # 1 is a double in R; Ops method should coerce to int
8+
9+
expect_s3_class(expr, "jax._src.export.shape_poly._DimExpr")
10+
expect_match(reticulate::py_str(expr), "n - 1")
11+
expect_false(any(grepl("Array", class(expr))))
12+
})

0 commit comments

Comments
 (0)