Skip to content

Commit f4d21cb

Browse files
committed
normalize_shape() accepts tensors
1 parent 8e29360 commit f4d21cb

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
Standard layer arguments include: `input_shape`, `batch_input_shape`, `batch_size`, `dtype`, `name`, `trainable`, `weights`.
3838
Layers updated: `layer_global_{max,average}_pooling_{1,2,3}d()`, `time_distributed()`, `bidirectional()`.
3939

40+
- All the backend function with a shape argument `k_*(shape =)` that now accept a
41+
a mix of integer tensors and R numerics in the supplied list.
42+
4043
- `k_random_uniform()` now automatically coerces `minval` and `maxval` to the output dtype.
4144

4245
# keras 2.6.1

R/backend.R

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3103,10 +3103,5 @@ backend_normalize_shape <- function(shape) {
31033103
if (inherits(shape, "python.builtin.object"))
31043104
return(shape)
31053105

3106-
if (is.list(shape)) {
3107-
if (any(sapply(unlist(shape), function(x) inherits(x, "python.builtin.object"))))
3108-
return(shape)
3109-
}
3110-
31113106
normalize_shape(shape)
31123107
}

R/layers-core.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,13 @@ normalize_shape <- function(shape) {
433433
# if it's a list or a numeric vector then convert to integer
434434
if (is.list(shape) || is.numeric(shape)) {
435435
shape <- lapply(shape, function(value) {
436+
437+
# Pass through python objects unmodified, only coerce R objects
438+
# supplied shapes, e.g., to tf$random$normal, can be a list that's a mix
439+
# of scalar integer tensors and regular integers
440+
if(inherits(value, "python.builtin.object"))
441+
return(value)
442+
436443
if (!is.null(value))
437444
as.integer(value)
438445
else

tests/testthat/test-backend.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,12 @@ test_backend("k_set_epsilon", {
132132
k_set_epsilon(1e-07)
133133
expect_equal(k_epsilon(), 1e-07)
134134
})
135+
136+
137+
test_backend("k_random_normal", {
138+
chk <- function(x) expect_tensor(x, shape = c(1L, 2L, 3L))
139+
chk(k_random_normal(c(1,2,3)))
140+
chk(k_random_normal(as_tensor(c(1,2,3), "int32")))
141+
chk(k_random_normal(lapply(c(1,2,3), as_tensor, "int32")))
142+
chk(k_random_normal(list(1, as_tensor(2, "int32"), 3)))
143+
})

vignettes/new-guides/customizing_what_happens_in_fit.Rmd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,8 +356,6 @@ GAN(keras$Model) %py_class% {
356356
# Sample random points in the latent space
357357
batch_size <- tf$shape(real_images)[1]
358358
359-
# TODO: shape() should be able to handle a scalar tensor as input
360-
# also, backend functions like k_random_normal
361359
random_latent_vectors <- tf$random$normal(list(batch_size, self$latent_dim))
362360
363361
# Decode them to fake images

0 commit comments

Comments
 (0)