Skip to content

Commit 4be23b2

Browse files
committed
Accept NA in layer_input(shape = c(NA))
Now returns a tensor with shape (None, None)
1 parent a193a67 commit 4be23b2

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

R/layers-core.R

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,23 +431,30 @@ normalize_shape <- function(shape) {
431431
return(shape)
432432

433433
# if it's a list or a numeric vector then convert to integer
434-
if (is.list(shape) || is.numeric(shape)) {
435-
shape <- lapply(shape, function(value) {
434+
# NA's in are accepted as NULL
435+
# also accept c(NA), as if it was a numeric
436+
if (is.list(shape) || is.numeric(shape) ||
437+
(is.logical(shape) && all(is.na(shape)))) {
436438

439+
shape <- lapply(shape, function(value) {
437440
# Pass through python objects unmodified, only coerce R objects
438441
# supplied shapes, e.g., to tf$random$normal, can be a list that's a mix
439442
# of scalar integer tensors and regular integers
440-
if(inherits(value, "python.builtin.object"))
443+
if (inherits(value, "python.builtin.object"))
441444
return(value)
442445

446+
# accept NA,NA_integer_,NA_real_ as NULL
447+
if ((is_scalar(value) && is.na(value)))
448+
return(NULL)
449+
443450
if (!is.null(value))
444451
as.integer(value)
445452
else
446453
NULL
447454
})
448455
}
449456

450-
if(inherits(shape, "tensorflow.python.framework.tensor_shape.TensorShape"))
457+
if (inherits(shape, "tensorflow.python.framework.tensor_shape.TensorShape"))
451458
shape <- as.list(shape$as_list()) # unpack for tuple()
452459

453460
# coerce to tuple so it's iterable

R/utils.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,6 @@ capture_args <- function(cl, modifiers = NULL, ignore = NULL,
448448

449449
args
450450
}
451+
452+
453+
is_scalar <- function(x) identical(length(x), 1L)

tests/testthat/test-model.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ test_succeeds("can call model with R objects", {
129129
test_succeeds("layer_input() ", {
130130
# can take dtype = Dtype
131131
layer_input(shape = 1, dtype = tf$string)
132+
133+
expect_identical(as.list(layer_input(shape = c(NA))$shape),
134+
list(NULL, NULL))
132135
})
133136

134137

0 commit comments

Comments
 (0)