Skip to content

Commit 9d18db5

Browse files
committed
fix: create_layer_wrapper() should include args with a NULL default
1 parent 7ea3cae commit 9d18db5

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
- `image_load()` gains a `color_mode` argument.
3535

36+
- Fixed issue where `create_layer_wrapper()` would not include arguments
37+
with a `NULL` default value in the returned wrapper.
38+
3639
- Deprecated functions are no longer included in the package documentation index.
3740

3841
# keras 2.7.0

R/layer-custom.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ compat_custom_KerasLayer_handler <- function(layer_class, args) {
124124

125125

126126
py_formals <- function(py_obj) {
127-
# returns python fn formals as a list (formals(),
128-
# but for py functions/methods
127+
# returns python fn formals as a list
128+
# like base::formals(), but for py functions/methods
129129
inspect <- reticulate::import("inspect")
130130
sig <- if (inspect$isclass(py_obj)) {
131131
inspect$signature(py_obj$`__init__`)
@@ -161,7 +161,7 @@ py_formals <- function(py_obj) {
161161
next
162162
}
163163

164-
args[[name]] <- default
164+
args[name] <- list(default) # default can be NULL
165165
}
166166
args
167167
}

tests/testthat/test-Layer.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,43 @@ test_succeeds("Can inherit from an R custom layer", {
155155
l <- layer2(x = 2)
156156
expect_equal(as.numeric(l(1)), 12)
157157
})
158+
159+
160+
test_succeeds("create_layer_wrapper", {
161+
162+
SimpleDense(keras$layers$Layer) %py_class% {
163+
initialize <- function(units, activation = NULL) {
164+
super$initialize()
165+
self$units <- as.integer(units)
166+
self$activation <- activation
167+
}
168+
169+
build <- function(input_shape) {
170+
input_dim <- as.integer(input_shape) %>% tail(1)
171+
self$W <- self$add_weight(shape = c(input_dim, self$units),
172+
initializer = "random_normal")
173+
self$b <- self$add_weight(shape = c(self$units),
174+
initializer = "zeros")
175+
}
176+
177+
call <- function(inputs) {
178+
y <- tf$matmul(inputs, self$W) + self$b
179+
if (!is.null(self$activation))
180+
y <- self$activation(y)
181+
y
182+
}
183+
}
184+
185+
layer_simple_dense <- create_layer_wrapper(SimpleDense)
186+
187+
expect_identical(formals(layer_simple_dense),
188+
formals(function(object, units, activation = NULL) {}))
189+
190+
model <- keras_model_sequential() %>%
191+
layer_simple_dense(32, activation = "relu") %>%
192+
layer_simple_dense(64, activation = "relu") %>%
193+
layer_simple_dense(32, activation = "relu") %>%
194+
layer_simple_dense(10, activation = "softmax")
195+
196+
expect_equal(length(model$layers), 4L)
197+
})

0 commit comments

Comments
 (0)