Skip to content

Commit ab138a4

Browse files
committed
layer_merge() family can accept layers in ...
1 parent c4fe6e2 commit ab138a4

File tree

14 files changed

+197
-77
lines changed

14 files changed

+197
-77
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ importFrom(reticulate,use_condaenv)
617617
importFrom(reticulate,use_python)
618618
importFrom(reticulate,use_virtualenv)
619619
importFrom(rlang,"%||%")
620+
importFrom(rlang,names2)
620621
importFrom(stats,predict)
621622
importFrom(tensorflow,as_tensor)
622623
importFrom(tensorflow,evaluate)

NEWS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22

33
- New `callback_backup_and_restore()`, for resuming an interrupted `fit()` call.
44

5+
- The merging family of layers (`layer_add`, `layer_concatenate`, etc.) gain the ability
6+
to accept layers in `...`, allowing for easier composition of residual blocks with the pipe `%>%`.
7+
e.g. something like this now works:
8+
```r
9+
block_1_output <- ...
10+
block_2_output <- block_1_output %>%
11+
layer_conv_2d(64, 3, activation = "relu", padding = "same") %>%
12+
layer_add(block_1_output)
13+
```
14+
515
# keras 2.9.0
616

717
- New functions for constructing custom keras subclasses:

R/layers-merge.R

Lines changed: 120 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#' It takes as input a list of tensors, all of the same shape, and returns a
66
#' single tensor (also of the same shape).
77
#'
8-
#' @param inputs A list of input tensors (at least 2). Can be missing.
9-
#' @param ... Standard layer arguments (must be named).
8+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
9+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
1010
#'
1111
#' @return A tensor, the sum of the inputs. If `inputs` is missing, a keras
1212
#' layer instance is returned.
@@ -15,16 +15,20 @@
1515
#'
1616
#' @seealso
1717
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/add>
18-
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Add>
1918
#' + <https://keras.io/api/layers/merging_layers/add>
2019
#'
2120
#' @export
2221
layer_add <- function(inputs, ...) {
23-
callable <- if(missing(inputs)) keras$layers$Add else keras$layers$add
24-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
25-
do.call(callable, args)
22+
if (missing(inputs))
23+
return(keras$layers$Add(...))
24+
if (!is.list(inputs))
25+
inputs <- list(inputs)
26+
dots <- split_dots_named_unnamed(list(...))
27+
inputs <- c(inputs, dots$unnamed)
28+
do.call(keras$layers$add, c(list(inputs), dots$named))
2629
}
2730

31+
2832
# TODO: there should be a common topic where we can use
2933
# @inheritDotParams standard-layer-args
3034

@@ -35,8 +39,8 @@ layer_add <- function(inputs, ...) {
3539
#' returns a single tensor, (`inputs[[1]] - inputs[[2]]`), also of the same
3640
#' shape.
3741
#'
38-
#' @param inputs A list of input tensors (exactly 2). Can be missing.
39-
#' @param ... Standard layer arguments (must be named).
42+
#' @param inputs A input tensor, or list of two input tensors. Can be missing.
43+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
4044
#'
4145
#' @return A tensor, the difference of the inputs. If `inputs` is missing, a
4246
#' keras layer instance is returned.
@@ -46,23 +50,27 @@ layer_add <- function(inputs, ...) {
4650
#'
4751
#' @seealso
4852
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/subtract>
49-
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Subtract>
5053
#' + <https://keras.io/api/layers/merging_layers/subtract>
5154
#'
5255
#' @export
5356
layer_subtract <- function(inputs, ...) {
54-
callable <- if (missing(inputs)) keras$layers$Subtract else keras$layers$subtract
55-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
56-
do.call(callable, args)
57+
if (missing(inputs))
58+
return(keras$layers$Subtract(...))
59+
if (!is.list(inputs))
60+
inputs <- list(inputs)
61+
dots <- split_dots_named_unnamed(list(...))
62+
inputs <- c(inputs, dots$unnamed)
63+
do.call(keras$layers$subtract, c(list(inputs), dots$named))
5764
}
5865

66+
5967
#' Layer that multiplies (element-wise) a list of inputs.
6068
#'
6169
#' It takes as input a list of tensors, all of the same shape, and returns a
6270
#' single tensor (also of the same shape).
6371
#'
64-
#' @param inputs A list of input tensors (at least 2). Can be missing.
65-
#' @param ... Standard layer arguments (must be named).
72+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
73+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
6674
#'
6775
#' @return A tensor, the element-wise product of the inputs. If `inputs` is
6876
#' missing, a keras layer instance is returned.
@@ -72,15 +80,17 @@ layer_subtract <- function(inputs, ...) {
7280
#'
7381
#' @seealso
7482
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/multiply>
75-
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Multiply>
7683
#' + <https://keras.io/api/layers/merging_layers/multiply>
7784
#'
7885
#' @export
7986
layer_multiply <- function(inputs, ...) {
80-
callable <- if (missing(inputs)) keras$layers$Multiply else keras$layers$multiply
81-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
82-
do.call(callable, args)
83-
87+
if (missing(inputs))
88+
return(keras$layers$Multiply(...))
89+
if (!is.list(inputs))
90+
inputs <- list(inputs)
91+
dots <- split_dots_named_unnamed(list(...))
92+
inputs <- c(inputs, dots$unnamed)
93+
do.call(keras$layers$multiply, c(list(inputs), dots$named))
8494
}
8595

8696

@@ -89,8 +99,8 @@ layer_multiply <- function(inputs, ...) {
8999
#' It takes as input a list of tensors, all of the same shape, and returns a
90100
#' single tensor (also of the same shape).
91101
#'
92-
#' @param inputs A list of input tensors (at least 2). Can be missing.
93-
#' @param ... Standard layer arguments (must be named).
102+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
103+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
94104
#'
95105
#' @return A tensor, the average of the inputs. If `inputs` is missing, a keras
96106
#' layer instance is returned.
@@ -105,19 +115,22 @@ layer_multiply <- function(inputs, ...) {
105115
#'
106116
#' @export
107117
layer_average <- function(inputs, ...) {
108-
callable <- if (missing(inputs)) keras$layers$Average else keras$layers$average
109-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
110-
do.call(callable, args)
111-
118+
if (missing(inputs))
119+
return(keras$layers$Average(...))
120+
if (!is.list(inputs))
121+
inputs <- list(inputs)
122+
dots <- split_dots_named_unnamed(list(...))
123+
inputs <- c(inputs, dots$unnamed)
124+
do.call(keras$layers$average, c(list(inputs), dots$named))
112125
}
113126

114127
#' Layer that computes the maximum (element-wise) a list of inputs.
115128
#'
116129
#' It takes as input a list of tensors, all of the same shape, and returns a
117130
#' single tensor (also of the same shape).
118131
#'
119-
#' @param inputs A list of input tensors (at least 2). Can be missing.
120-
#' @param ... Standard layer arguments (must be named).
132+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
133+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
121134
#'
122135
#' @return A tensor, the element-wise maximum of the inputs. If `inputs` is
123136
#' missing, a keras layer instance is returned.
@@ -132,10 +145,13 @@ layer_average <- function(inputs, ...) {
132145
#'
133146
#' @export
134147
layer_maximum <- function(inputs, ...) {
135-
callable <- if (missing(inputs)) keras$layers$Maximum else keras$layers$maximum
136-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
137-
do.call(callable, args)
138-
148+
if (missing(inputs))
149+
return(keras$layers$Maximum(...))
150+
if (!is.list(inputs))
151+
inputs <- list(inputs)
152+
dots <- split_dots_named_unnamed(list(...))
153+
inputs <- c(inputs, dots$unnamed)
154+
do.call(keras$layers$maximum, c(list(inputs), dots$named))
139155
}
140156

141157

@@ -144,8 +160,8 @@ layer_maximum <- function(inputs, ...) {
144160
#' It takes as input a list of tensors, all of the same shape, and returns a
145161
#' single tensor (also of the same shape).
146162
#'
147-
#' @param inputs A list of input tensors (at least 2). Can be missing.
148-
#' @param ... Standard layer arguments (must be named).
163+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
164+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
149165
#'
150166
#' @return A tensor, the element-wise maximum of the inputs. If `inputs` is
151167
#' missing, a keras layer instance is returned.
@@ -159,9 +175,13 @@ layer_maximum <- function(inputs, ...) {
159175
#'
160176
#' @export
161177
layer_minimum <- function(inputs, ...) {
162-
callable <- if (missing(inputs)) keras$layers$Minimum else keras$layers$minimum
163-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer))
164-
do.call(callable, args)
178+
if (missing(inputs))
179+
return(keras$layers$Minimum(...))
180+
if (!is.list(inputs))
181+
inputs <- list(inputs)
182+
dots <- split_dots_named_unnamed(list(...))
183+
inputs <- c(inputs, dots$unnamed)
184+
do.call(keras$layers$minimum, c(list(inputs), dots$named))
165185
}
166186

167187

@@ -171,9 +191,9 @@ layer_minimum <- function(inputs, ...) {
171191
#' concatenation axis, and returns a single tensor, the concatenation of all
172192
#' inputs.
173193
#'
174-
#' @param inputs A list of input tensors (at least 2). Can be missing.
194+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
195+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
175196
#' @param axis Concatenation axis.
176-
#' @param ... Standard layer arguments (must be named).
177197
#'
178198
#' @return A tensor, the concatenation of the inputs alongside axis `axis`. If
179199
#' `inputs` is missing, a keras layer instance is returned.
@@ -186,23 +206,43 @@ layer_minimum <- function(inputs, ...) {
186206
#' + <https://keras.io/api/layers/merging_layers/concatenate>
187207
#'
188208
#' @export
189-
layer_concatenate <- function(inputs, axis = -1, ...) {
190-
callable <- if (missing(inputs)) keras$layers$Concatenate else keras$layers$concatenate
191-
# TODO: this axis should probably be 1-based
192-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer,
193-
axis = as.integer))
194-
do.call(callable, args)
209+
layer_concatenate <- function(inputs, ..., axis = -1) {
210+
if (missing(inputs)) {
211+
args <- capture_args(match.call(), list(axis = as.integer))
212+
return(do.call(keras$layers$Concatenate, args))
213+
}
214+
215+
# TODO: this axis arg should probably be 1-based
216+
217+
if (is.list(inputs)) {
218+
# backcompat: axis used to be in 2nd position, inputs used to accept only a list.
219+
220+
dots <- list(...)
221+
if (length(dots) && names2(dots)[[1]] == "" &&
222+
missing(axis) &&
223+
is.numeric(dots[[1L]]) &&
224+
is_scalar(dots[[1L]])) {
225+
axis <- as.integer(dots[[1L]])
226+
dots[[1L]] <- NULL
227+
}
228+
return(do.call(keras$layers$concatenate,
229+
c(list(inputs), dots, axis = as.integer(axis))))
230+
}
231+
232+
dots <- split_dots_named_unnamed(list(...))
233+
inputs <- c(list(inputs), dots$unnamed)
234+
do.call(keras$layers$concatenate, c(list(inputs), dots$named))
195235
}
196236

197237
#' Layer that computes a dot product between samples in two tensors.
198238
#'
199-
#' @param inputs A list of input tensors (at least 2). Can be missing.
239+
#' @param inputs A input tensor, or list of input tensors. Can be missing.
240+
#' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
200241
#' @param axes Integer or list of integers, axis or axes along which to take the
201242
#' dot product.
202243
#' @param normalize Whether to L2-normalize samples along the dot product axis
203244
#' before taking the dot product. If set to TRUE, then the output of the dot
204245
#' product is the cosine proximity between the two samples.
205-
#' @param ... Standard layer arguments (must be named).
206246
#'
207247
#' @return If `inputs` is supplied: A tensor, the dot product of the samples
208248
#' from the inputs. If `inputs` is missing, a keras layer instance is
@@ -213,13 +253,41 @@ layer_concatenate <- function(inputs, axis = -1, ...) {
213253
#'
214254
#' @seealso
215255
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/dot>
216-
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dot>
217256
#' + <https://keras.io/api/layers/merging_layers/dot/>
218257
#'
219258
#' @export
220-
layer_dot <- function(inputs, axes, normalize = FALSE, ...) {
221-
callable <- if (missing(inputs)) keras$layers$Dot else keras$layers$dot
222-
args <- capture_args(match.call(), list(batch_size = as_nullable_integer,
223-
axes = as.integer))
224-
do.call(callable, args)
259+
#' @importFrom rlang names2
260+
layer_dot <- function(inputs, ..., axes, normalize = FALSE) {
261+
if (missing(inputs)) {
262+
args <- capture_args(match.call(), list(axes = as.integer))
263+
return(do.call(keras$layers$Dot, args))
264+
}
265+
266+
if (is.list(inputs)) {
267+
# backcompat: inputs used to only accept a list of layers, and
268+
# axis, normalize, used to be in 2nd, 3rd position.
269+
dots <- list(...)
270+
if (length(dots) && names2(dots)[[1]] == "" &&
271+
missing(axes)) {
272+
axes <- as.integer(dots[[1L]])
273+
dots[[1L]] <- NULL
274+
}
275+
if (length(dots) && names2(dots)[[1]] == "" &&
276+
missing(normalize)) {
277+
normalize <- as.integer(dots[[1L]])
278+
dots[[1L]] <- NULL
279+
}
280+
args <- c(list(inputs), dots,
281+
axes = as.integer(axes), normalize = normalize)
282+
return(do.call(keras$layers$dot, args))
283+
}
284+
285+
# inputs is not a list
286+
dots <- split_dots_named_unnamed(list(...))
287+
inputs <- c(inputs, dots$unnamed)
288+
args <- c(list(inputs),
289+
dots$named,
290+
axes = as.integer(axes),
291+
normalize = normalize)
292+
do.call(keras$layers$dot, args)
225293
}

R/utils.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,14 @@ is_keras_tensor <- function(x) {
418418
}
419419

420420

421+
split_dots_named_unnamed <- function(dots) {
422+
nms <- names(dots)
423+
if (is.null(nms))
424+
return(list(unnamed = dots, named = list()))
425+
named <- nzchar(nms)
426+
list(unnamed = dots[!named], named = dots[named])
427+
}
428+
421429

422430
assert_all_dots_named <- function(envir = parent.frame(), cl) {
423431

@@ -437,6 +445,8 @@ assert_all_dots_named <- function(envir = parent.frame(), cl) {
437445
# TODO: should there be some default modifiers in capture_args() for standard layer args
438446
# like, input_shape, batch_input_shape, etc.
439447

448+
449+
440450
capture_args <- function(cl, modifiers = NULL, ignore = NULL,
441451
envir = parent.frame(), fn = sys.function(-1)) {
442452

keras.Rproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,5 @@ StripTrailingWhitespace: Yes
1717

1818
BuildType: Package
1919
PackageUseDevtools: Yes
20-
PackageCleanBeforeInstall: Yes
2120
PackageInstallArgs: --no-multiarch --with-keep.source
2221
PackageRoxygenize: rd,collate,namespace

man/layer_add.Rd

Lines changed: 2 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_average.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)