Skip to content

Commit f29f702

Browse files
committed
Merge branch 'main' of https://github.com/rstudio/keras into main
2 parents 98edbb6 + f33be71 commit f29f702

File tree

8 files changed

+75
-11
lines changed

8 files changed

+75
-11
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# keras (development version)
22

3+
- Fixed issue where `input_shape` supplied to custom layers defined with `new_layer_class()`
4+
would result in an error (#1338)
5+
36
- New `callback_backup_and_restore()`, for resuming an interrupted `fit()` call.
47

58
- The merging family of layers (`layer_add`, `layer_concatenate`, etc.) gain the ability

R/install.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ install_keras <- function(method = c("auto", "virtualenv", "conda"),
3333
method = method,
3434
conda = conda,
3535
version = version,
36-
extra_packages = c("pandas", "Pillow",
37-
"pydot",
38-
"tensorflow-hub",
39-
"tensorflow-datasets",
36+
extra_packages = c(default_extra_packages(),
4037
extra_packages),
4138
...))
4239
}
@@ -66,6 +63,7 @@ default_version <- numeric_version("2.9")
6663
default_extra_packages <- function(tensorflow_version = "default") {
6764
pkgs <- c(
6865
"tensorflow-hub",
66+
"tensorflow-datasets",
6967
"scipy",
7068
"requests",
7169
"pyyaml",

R/layer-custom.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,16 @@ py_formals <- function(py_obj) {
196196
create_layer_wrapper <- function(Layer, modifiers = NULL, convert = TRUE) {
197197

198198
force(Layer)
199-
force(modifiers)
199+
modifiers <- utils::modifyList(
200+
list(
201+
# include helpers for standard layer args by default,
202+
# but leave an escape hatch allowing users to override/opt-out.
203+
input_shape = as_tf_shape,
204+
batch_input_shape = as_tf_shape,
205+
batch_size = as.integer
206+
),
207+
as.list(modifiers)
208+
)
200209

201210
wrapper <- function(object) {
202211
args <- capture_args(match.call(), modifiers, ignore = "object")
@@ -241,3 +250,11 @@ r_to_py.keras_layer_wrapper <- function(fn, convert = FALSE) {
241250
layer <- r_to_py(layer, convert)
242251
layer
243252
}
253+
254+
255+
as_tf_shape <- function (x) {
256+
if (inherits(x, "tensorflow.python.framework.tensor_shape.TensorShape"))
257+
x
258+
else
259+
shape(dims = x)
260+
}

R/layer-methods.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#'
1313
#' @param object Layer or model object
1414
#' @param config Object with layer or model configuration
15+
#' @param custom_objects list of custom objects needed to instantiate the layer,
16+
#' e.g., custom layers defined by `new_layer_class()` or similar.
1517
#'
1618
#' @return `get_config()` returns an object with the configuration,
1719
#' `from_config()` returns a re-instantiation of the object.

man/get_config.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/install_keras.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper-utils.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 1)
1+
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 1)
2+
# 0 = all messages are logged (default behavior)
3+
# 1 = INFO messages are not printed
4+
# 2 = INFO and WARNING messages are not printed
5+
# 3 = INFO, WARNING, and ERROR messages are not printed
26

37

48
# Sys.setenv(RETICULATE_PYTHON = "~/.local/share/r-miniconda/envs/tf-2.7-cpu/bin/python")

tests/testthat/test-Layer.R

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,11 @@ test_succeeds("create_layer_wrapper", {
197197
})
198198

199199

200-
201-
202-
203200
test_succeeds("create_layer_wrapper", {
204201

205202
layer_sampler <- new_layer_class(
206203
classname = "Sampler",
207-
call = function(self, z_mean, z_log_var) {
204+
call = function(z_mean, z_log_var) {
208205
epsilon <- k_random_normal(shape = k_shape(z_mean))
209206
z_mean + exp(0.5 * z_log_var) * epsilon
210207
}
@@ -217,3 +214,43 @@ test_succeeds("create_layer_wrapper", {
217214
expect_equal(dim(res), c(128, 2))
218215

219216
})
217+
218+
219+
test_succeeds("custom layers can accept standard layer args like input_shape", {
220+
# https://github.com/rstudio/keras/issues/1338
221+
layer_simple_dense <- new_layer_class(
222+
classname = "SimpleDense",
223+
224+
initialize = function(units, activation = NULL, ...) {
225+
super$initialize(...)
226+
self$units <- as.integer(units)
227+
self$activation <- activation
228+
},
229+
230+
build = function(input_shape) {
231+
input_dim <- input_shape[length(input_shape)]
232+
self$W <- self$add_weight(shape = c(input_dim, self$units),
233+
initializer = "random_normal")
234+
self$b <- self$add_weight(shape = c(self$units),
235+
initializer = "zeros")
236+
},
237+
238+
call = function(inputs) {
239+
y <- tf$matmul(inputs, self$W) + self$b
240+
if (!is.null(self$activation))
241+
y <- self$activation(y)
242+
y
243+
}
244+
)
245+
246+
model <- keras_model_sequential() %>%
247+
layer_simple_dense(20, input_shape = 30) %>%
248+
layer_dense(10)
249+
250+
expect_identical(dim(model$input), c(NA_integer_, 30L))
251+
expect_true(model$built)
252+
253+
res <- model(random_array(1, 30))
254+
expect_tensor(res, shape = c(1L, 10L))
255+
256+
})

0 commit comments

Comments
 (0)