Skip to content

Commit 1da2c32

Browse files
committed
fix-up standard layer args in custom layers
closes #1338
1 parent 3a5a680 commit 1da2c32

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

R/layer-custom.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,16 @@ py_formals <- function(py_obj) {
196196
create_layer_wrapper <- function(Layer, modifiers = NULL, convert = TRUE) {
197197

198198
force(Layer)
199-
force(modifiers)
199+
modifiers <- utils::modifyList(
200+
list(
201+
# include helpers for standard layer args by default,
202+
# but leave an escape hatch allowing users to override/opt-out.
203+
input_shape = as_tf_shape,
204+
batch_input_shape = as_tf_shape,
205+
batch_size = as.integer
206+
),
207+
as.list(modifiers)
208+
)
200209

201210
wrapper <- function(object) {
202211
args <- capture_args(match.call(), modifiers, ignore = "object")
@@ -241,3 +250,11 @@ r_to_py.keras_layer_wrapper <- function(fn, convert = FALSE) {
241250
layer <- r_to_py(layer, convert)
242251
layer
243252
}
253+
254+
255+
as_tf_shape <- function (x) {
256+
if (inherits(x, "tensorflow.python.framework.tensor_shape.TensorShape"))
257+
x
258+
else
259+
shape(dims = x)
260+
}

tests/testthat/test-Layer.R

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,11 @@ test_succeeds("create_layer_wrapper", {
197197
})
198198

199199

200-
201-
202-
203200
test_succeeds("create_layer_wrapper", {
204201

205202
layer_sampler <- new_layer_class(
206203
classname = "Sampler",
207-
call = function(self, z_mean, z_log_var) {
204+
call = function(z_mean, z_log_var) {
208205
epsilon <- k_random_normal(shape = k_shape(z_mean))
209206
z_mean + exp(0.5 * z_log_var) * epsilon
210207
}
@@ -217,3 +214,40 @@ test_succeeds("create_layer_wrapper", {
217214
expect_equal(dim(res), c(128, 2))
218215

219216
})
217+
218+
219+
test_succeeds("custom layers can accept standard layer args like input_shape", {
220+
# https://github.com/rstudio/keras/issues/1338
221+
layer_simple_dense <- new_layer_class(
222+
classname = "SimpleDense",
223+
224+
initialize = function(units, activation = NULL, ...) {
225+
super$initialize(...)
226+
self$units <- as.integer(units)
227+
self$activation <- activation
228+
},
229+
230+
build = function(input_shape) {
231+
input_dim <- input_shape[length(input_shape)]
232+
self$W <- self$add_weight(shape = c(input_dim, self$units),
233+
initializer = "random_normal")
234+
self$b <- self$add_weight(shape = c(self$units),
235+
initializer = "zeros")
236+
},
237+
238+
call = function(inputs) {
239+
y <- tf$matmul(inputs, self$W) + self$b
240+
if (!is.null(self$activation))
241+
y <- self$activation(y)
242+
y
243+
}
244+
)
245+
246+
model <- keras_model_sequential() %>%
247+
layer_simple_dense(20, input_shape = 30) %>%
248+
layer_dense(10)
249+
250+
res <- model(random_array(1, 30))
251+
expect_tensor(res, shape = c(1L, 10L))
252+
253+
})

0 commit comments

Comments
 (0)