55# ' It takes as input a list of tensors, all of the same shape, and returns a
66# ' single tensor (also of the same shape).
77# '
8- # ' @param inputs A list of input tensors (at least 2) . Can be missing.
9- # ' @param ... Standard layer arguments (must be named) .
8+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
9+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
1010# '
1111# ' @return A tensor, the sum of the inputs. If `inputs` is missing, a keras
1212# ' layer instance is returned.
1515# '
1616# ' @seealso
1717# ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/add>
18- # ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Add>
1918# ' + <https://keras.io/api/layers/merging_layers/add>
2019# '
2120# ' @export
2221layer_add <- function (inputs , ... ) {
23- callable <- if (missing(inputs )) keras $ layers $ Add else keras $ layers $ add
24- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
25- do.call(callable , args )
22+ if (missing(inputs ))
23+ return (keras $ layers $ Add(... ))
24+ if (! is.list(inputs ))
25+ inputs <- list (inputs )
26+ dots <- split_dots_named_unnamed(list (... ))
27+ inputs <- c(inputs , dots $ unnamed )
28+ do.call(keras $ layers $ add , c(list (inputs ), dots $ named ))
2629}
2730
31+
2832# TODO: there should be a common topic where we can use
2933# @inheritDotParams standard-layer-args
3034
@@ -35,8 +39,8 @@ layer_add <- function(inputs, ...) {
3539# ' returns a single tensor, (`inputs[[1]] - inputs[[2]]`), also of the same
3640# ' shape.
3741# '
38- # ' @param inputs A list of input tensors (exactly 2) . Can be missing.
39- # ' @param ... Standard layer arguments (must be named) .
42+ # ' @param inputs A input tensor, or list of two input tensors. Can be missing.
43+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
4044# '
4145# ' @return A tensor, the difference of the inputs. If `inputs` is missing, a
4246# ' keras layer instance is returned.
@@ -46,23 +50,27 @@ layer_add <- function(inputs, ...) {
4650# '
4751# ' @seealso
4852# ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/subtract>
49- # ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Subtract>
5053# ' + <https://keras.io/api/layers/merging_layers/subtract>
5154# '
5255# ' @export
5356layer_subtract <- function (inputs , ... ) {
54- callable <- if (missing(inputs )) keras $ layers $ Subtract else keras $ layers $ subtract
55- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
56- do.call(callable , args )
57+ if (missing(inputs ))
58+ return (keras $ layers $ Subtract(... ))
59+ if (! is.list(inputs ))
60+ inputs <- list (inputs )
61+ dots <- split_dots_named_unnamed(list (... ))
62+ inputs <- c(inputs , dots $ unnamed )
63+ do.call(keras $ layers $ subtract , c(list (inputs ), dots $ named ))
5764}
5865
66+
5967# ' Layer that multiplies (element-wise) a list of inputs.
6068# '
6169# ' It takes as input a list of tensors, all of the same shape, and returns a
6270# ' single tensor (also of the same shape).
6371# '
64- # ' @param inputs A list of input tensors (at least 2) . Can be missing.
65- # ' @param ... Standard layer arguments (must be named) .
72+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
73+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
6674# '
6775# ' @return A tensor, the element-wise product of the inputs. If `inputs` is
6876# ' missing, a keras layer instance is returned.
@@ -72,15 +80,17 @@ layer_subtract <- function(inputs, ...) {
7280# '
7381# ' @seealso
7482# ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/multiply>
75- # ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Multiply>
7683# ' + <https://keras.io/api/layers/merging_layers/multiply>
7784# '
7885# ' @export
7986layer_multiply <- function (inputs , ... ) {
80- callable <- if (missing(inputs )) keras $ layers $ Multiply else keras $ layers $ multiply
81- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
82- do.call(callable , args )
83-
87+ if (missing(inputs ))
88+ return (keras $ layers $ Multiply(... ))
89+ if (! is.list(inputs ))
90+ inputs <- list (inputs )
91+ dots <- split_dots_named_unnamed(list (... ))
92+ inputs <- c(inputs , dots $ unnamed )
93+ do.call(keras $ layers $ multiply , c(list (inputs ), dots $ named ))
8494}
8595
8696
@@ -89,8 +99,8 @@ layer_multiply <- function(inputs, ...) {
8999# ' It takes as input a list of tensors, all of the same shape, and returns a
90100# ' single tensor (also of the same shape).
91101# '
92- # ' @param inputs A list of input tensors (at least 2) . Can be missing.
93- # ' @param ... Standard layer arguments (must be named) .
102+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
103+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
94104# '
95105# ' @return A tensor, the average of the inputs. If `inputs` is missing, a keras
96106# ' layer instance is returned.
@@ -105,19 +115,22 @@ layer_multiply <- function(inputs, ...) {
105115# '
106116# ' @export
107117layer_average <- function (inputs , ... ) {
108- callable <- if (missing(inputs )) keras $ layers $ Average else keras $ layers $ average
109- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
110- do.call(callable , args )
111-
118+ if (missing(inputs ))
119+ return (keras $ layers $ Average(... ))
120+ if (! is.list(inputs ))
121+ inputs <- list (inputs )
122+ dots <- split_dots_named_unnamed(list (... ))
123+ inputs <- c(inputs , dots $ unnamed )
124+ do.call(keras $ layers $ average , c(list (inputs ), dots $ named ))
112125}
113126
114127# ' Layer that computes the maximum (element-wise) a list of inputs.
115128# '
116129# ' It takes as input a list of tensors, all of the same shape, and returns a
117130# ' single tensor (also of the same shape).
118131# '
119- # ' @param inputs A list of input tensors (at least 2) . Can be missing.
120- # ' @param ... Standard layer arguments (must be named) .
132+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
133+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
121134# '
122135# ' @return A tensor, the element-wise maximum of the inputs. If `inputs` is
123136# ' missing, a keras layer instance is returned.
@@ -132,10 +145,13 @@ layer_average <- function(inputs, ...) {
132145# '
133146# ' @export
134147layer_maximum <- function (inputs , ... ) {
135- callable <- if (missing(inputs )) keras $ layers $ Maximum else keras $ layers $ maximum
136- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
137- do.call(callable , args )
138-
148+ if (missing(inputs ))
149+ return (keras $ layers $ Maximum(... ))
150+ if (! is.list(inputs ))
151+ inputs <- list (inputs )
152+ dots <- split_dots_named_unnamed(list (... ))
153+ inputs <- c(inputs , dots $ unnamed )
154+ do.call(keras $ layers $ maximum , c(list (inputs ), dots $ named ))
139155}
140156
141157
@@ -144,8 +160,8 @@ layer_maximum <- function(inputs, ...) {
144160# ' It takes as input a list of tensors, all of the same shape, and returns a
145161# ' single tensor (also of the same shape).
146162# '
147- # ' @param inputs A list of input tensors (at least 2) . Can be missing.
148- # ' @param ... Standard layer arguments (must be named) .
163+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
164+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments .
149165# '
150166# ' @return A tensor, the element-wise maximum of the inputs. If `inputs` is
151167# ' missing, a keras layer instance is returned.
@@ -159,9 +175,13 @@ layer_maximum <- function(inputs, ...) {
159175# '
160176# ' @export
161177layer_minimum <- function (inputs , ... ) {
162- callable <- if (missing(inputs )) keras $ layers $ Minimum else keras $ layers $ minimum
163- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ))
164- do.call(callable , args )
178+ if (missing(inputs ))
179+ return (keras $ layers $ Minimum(... ))
180+ if (! is.list(inputs ))
181+ inputs <- list (inputs )
182+ dots <- split_dots_named_unnamed(list (... ))
183+ inputs <- c(inputs , dots $ unnamed )
184+ do.call(keras $ layers $ minimum , c(list (inputs ), dots $ named ))
165185}
166186
167187
@@ -171,9 +191,9 @@ layer_minimum <- function(inputs, ...) {
171191# ' concatenation axis, and returns a single tensor, the concatenation of all
172192# ' inputs.
173193# '
174- # ' @param inputs A list of input tensors (at least 2). Can be missing.
194+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
195+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
175196# ' @param axis Concatenation axis.
176- # ' @param ... Standard layer arguments (must be named).
177197# '
178198# ' @return A tensor, the concatenation of the inputs alongside axis `axis`. If
179199# ' `inputs` is missing, a keras layer instance is returned.
@@ -186,23 +206,43 @@ layer_minimum <- function(inputs, ...) {
186206# ' + <https://keras.io/api/layers/merging_layers/concatenate>
187207# '
188208# ' @export
189- layer_concatenate <- function (inputs , axis = - 1 , ... ) {
190- callable <- if (missing(inputs )) keras $ layers $ Concatenate else keras $ layers $ concatenate
191- # TODO: this axis should probably be 1-based
192- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ,
193- axis = as.integer ))
194- do.call(callable , args )
209+ layer_concatenate <- function (inputs , ... , axis = - 1 ) {
210+ if (missing(inputs )) {
211+ args <- capture_args(match.call(), list (axis = as.integer ))
212+ return (do.call(keras $ layers $ Concatenate , args ))
213+ }
214+
215+ # TODO: this axis arg should probably be 1-based
216+
217+ if (is.list(inputs )) {
218+ # backcompat: axis used to be in 2nd position, inputs used to accept only a list.
219+
220+ dots <- list (... )
221+ if (length(dots ) && names2(dots )[[1 ]] == " " &&
222+ missing(axis ) &&
223+ is.numeric(dots [[1L ]]) &&
224+ is_scalar(dots [[1L ]])) {
225+ axis <- as.integer(dots [[1L ]])
226+ dots [[1L ]] <- NULL
227+ }
228+ return (do.call(keras $ layers $ concatenate ,
229+ c(list (inputs ), dots , axis = as.integer(axis ))))
230+ }
231+
232+ dots <- split_dots_named_unnamed(list (... ))
233+ inputs <- c(list (inputs ), dots $ unnamed )
234+ do.call(keras $ layers $ concatenate , c(list (inputs ), dots $ named ))
195235}
196236
197237# ' Layer that computes a dot product between samples in two tensors.
198238# '
199- # ' @param inputs A list of input tensors (at least 2). Can be missing.
239+ # ' @param inputs A input tensor, or list of input tensors. Can be missing.
240+ # ' @param ... Unnamed args are treated as additional `inputs`. Named arguments are passed on as standard layer arguments.
200241# ' @param axes Integer or list of integers, axis or axes along which to take the
201242# ' dot product.
202243# ' @param normalize Whether to L2-normalize samples along the dot product axis
203244# ' before taking the dot product. If set to TRUE, then the output of the dot
204245# ' product is the cosine proximity between the two samples.
205- # ' @param ... Standard layer arguments (must be named).
206246# '
207247# ' @return If `inputs` is supplied: A tensor, the dot product of the samples
208248# ' from the inputs. If `inputs` is missing, a keras layer instance is
@@ -213,13 +253,41 @@ layer_concatenate <- function(inputs, axis = -1, ...) {
213253# '
214254# ' @seealso
215255# ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/dot>
216- # ' + <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dot>
217256# ' + <https://keras.io/api/layers/merging_layers/dot/>
218257# '
219258# ' @export
220- layer_dot <- function (inputs , axes , normalize = FALSE , ... ) {
221- callable <- if (missing(inputs )) keras $ layers $ Dot else keras $ layers $ dot
222- args <- capture_args(match.call(), list (batch_size = as_nullable_integer ,
223- axes = as.integer ))
224- do.call(callable , args )
259+ # ' @importFrom rlang names2
260+ layer_dot <- function (inputs , ... , axes , normalize = FALSE ) {
261+ if (missing(inputs )) {
262+ args <- capture_args(match.call(), list (axes = as.integer ))
263+ return (do.call(keras $ layers $ Dot , args ))
264+ }
265+
266+ if (is.list(inputs )) {
267+ # backcompat: inputs used to only accept a list of layers, and
268+ # axis, normalize, used to be in 2nd, 3rd position.
269+ dots <- list (... )
270+ if (length(dots ) && names2(dots )[[1 ]] == " " &&
271+ missing(axes )) {
272+ axes <- as.integer(dots [[1L ]])
273+ dots [[1L ]] <- NULL
274+ }
275+ if (length(dots ) && names2(dots )[[1 ]] == " " &&
276+ missing(normalize )) {
277+ normalize <- as.integer(dots [[1L ]])
278+ dots [[1L ]] <- NULL
279+ }
280+ args <- c(list (inputs ), dots ,
281+ axes = as.integer(axes ), normalize = normalize )
282+ return (do.call(keras $ layers $ dot , args ))
283+ }
284+
285+ # inputs is not a list
286+ dots <- split_dots_named_unnamed(list (... ))
287+ inputs <- c(inputs , dots $ unnamed )
288+ args <- c(list (inputs ),
289+ dots $ named ,
290+ axes = as.integer(axes ),
291+ normalize = normalize )
292+ do.call(keras $ layers $ dot , args )
225293}
0 commit comments