Commit f18e233

update recurrent layers

1 parent 101cc41

5 files changed: +119, -220 lines

NEWS.md

Lines changed: 9 additions & 4 deletions
```diff
@@ -21,10 +21,14 @@
 - `layer_conv_lstm_1d()`
 - `layer_conv_lstm_3d()`
 
-- `layer_lstm()` default value for `recurrent_activation` changed from `"hard_sigmoid"` to `"sigmoid"`.
+- `layer_cudnn_gru()` and `layer_cudnn_lstm()` are deprecated.
+  `layer_gru()` and `layer_lstm()` will automatically use CuDNN if it is available.
 
-- `layer_cudnn_gru()` and `layer_cudnn_lstm()` are deprecated. `layer_gru()` and `layer_lstm()` will
-  automatically use CuDNN if it is available.
+- `layer_lstm()` and `layer_gru()`:
+  default value for `recurrent_activation` changed
+  from `"hard_sigmoid"` to `"sigmoid"`.
+
+- `layer_gru()`: default value `reset_after` changed from `FALSE` to `TRUE`
 
 - New vignette: "Transfer learning and fine-tuning".
 
@@ -54,7 +58,8 @@
   `name`, `trainable`, `weights`.
   Layers updated:
   `layer_global_{max,average}_pooling_{1,2,3}d()`,
-  `time_distributed()`, `bidirectional()`.
+  `time_distributed()`, `bidirectional()`,
+  `layer_gru()`, `layer_lstm()`, `layer_simple_rnn()`
 
 - All the backend functions with a shape argument `k_*(shape =)` now accept
   a mix of integer tensors and R numerics in the supplied list.
```
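
In user code, these NEWS items combine as follows. A hedged sketch using the package's R interface (unit counts and shapes are illustrative only):

```r
library(keras)

# After this commit, layer_gru() defaults to recurrent_activation = "sigmoid"
# and reset_after = TRUE, matching layer_lstm(). With these defaults the
# layers can dispatch to the fast CuDNN kernel automatically on GPU, so the
# dedicated layer_cudnn_*() wrappers are no longer needed.
model <- keras_model_sequential() %>%
  layer_lstm(units = 32, input_shape = c(10, 8)) %>%
  layer_dense(units = 1)

# To reproduce the old behavior (e.g. when loading weights saved under the
# previous defaults), pass the old values explicitly:
old_gru <- layer_gru(
  units = 16,
  recurrent_activation = "hard_sigmoid",
  reset_after = FALSE
)
```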

R/layers-recurrent.R

Lines changed: 102 additions & 130 deletions
```diff
@@ -42,6 +42,7 @@
 #'   linear transformation of the inputs.
 #' @param recurrent_dropout Float between 0 and 1. Fraction of the units to drop
 #'   for the linear transformation of the recurrent state.
+#' @param ... Standard Layer args.
 #'
 #' @template roxlate-recurrent-layer
 #'
@@ -50,46 +51,36 @@
 #'
 #'
 #' @export
-layer_simple_rnn <- function(object, units, activation = "tanh", use_bias = TRUE,
-  return_sequences = FALSE, return_state = FALSE, go_backwards = FALSE, stateful = FALSE, unroll = FALSE,
-  kernel_initializer = "glorot_uniform", recurrent_initializer = "orthogonal", bias_initializer = "zeros",
-  kernel_regularizer = NULL, recurrent_regularizer = NULL, bias_regularizer = NULL, activity_regularizer = NULL,
-  kernel_constraint = NULL, recurrent_constraint = NULL, bias_constraint = NULL,
-  dropout = 0.0, recurrent_dropout = 0.0, input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
-  dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
-
-  args <- list(
-    units = as.integer(units),
-    activation = activation,
-    use_bias = use_bias,
-    return_sequences = return_sequences,
-    go_backwards = go_backwards,
-    stateful = stateful,
-    unroll = unroll,
-    kernel_initializer = kernel_initializer,
-    recurrent_initializer = recurrent_initializer,
-    bias_initializer = bias_initializer,
-    kernel_regularizer = kernel_regularizer,
-    recurrent_regularizer = recurrent_regularizer,
-    bias_regularizer = bias_regularizer,
-    activity_regularizer = activity_regularizer,
-    kernel_constraint = kernel_constraint,
-    recurrent_constraint = recurrent_constraint,
-    bias_constraint = bias_constraint,
-    dropout = dropout,
-    recurrent_dropout = recurrent_dropout,
-    input_shape = normalize_shape(input_shape),
-    batch_input_shape = normalize_shape(batch_input_shape),
-    batch_size = as_nullable_integer(batch_size),
-    dtype = dtype,
-    name = name,
-    trainable = trainable,
-    weights = weights
-  )
-
-  if (keras_version() >= "2.0.5")
-    args$return_state <- return_state
-
+layer_simple_rnn <-
+  function(object,
+           units,
+           activation = "tanh",
+           use_bias = TRUE,
+           return_sequences = FALSE,
+           return_state = FALSE,
+           go_backwards = FALSE,
+           stateful = FALSE,
+           unroll = FALSE,
+           kernel_initializer = "glorot_uniform",
+           recurrent_initializer = "orthogonal",
+           bias_initializer = "zeros",
+           kernel_regularizer = NULL,
+           recurrent_regularizer = NULL,
+           bias_regularizer = NULL,
+           activity_regularizer = NULL,
+           kernel_constraint = NULL,
+           recurrent_constraint = NULL,
+           bias_constraint = NULL,
+           dropout = 0.0,
+           recurrent_dropout = 0.0,
+           ...)
+  {
+    args <- capture_args(match.call(), list(
+      units = as.integer,
+      input_shape = normalize_shape,
+      batch_input_shape = normalize_shape,
+      batch_size = as_nullable_integer
+    ), ignore = "object")
   create_layer(keras$layers$SimpleRNN, object, args)
 }
 
```
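The substantive change in each of these wrappers is the same: the long hand-written `args` list (and the `keras_version()` guards) is replaced by a `capture_args()` helper that collects only the arguments the caller actually supplied, applies per-argument coercions, and drops `object`. The helper itself is internal to the package; a minimal standalone sketch of the pattern, using a hypothetical `capture_args2()`, could look like this:

```r
# Hypothetical sketch of the capture_args() pattern; the package-internal
# implementation may differ in details.
capture_args2 <- function(call, modifiers = list(), ignore = NULL,
                          envir = parent.frame(2L)) {
  call[[1L]] <- quote(list)  # rewrite `f(units = 32, ...)` into `list(units = 32, ...)`
  args <- eval(call, envir)  # evaluate supplied arguments in the caller's frame
  args <- args[setdiff(names(args), ignore)]           # drop e.g. `object`
  for (nm in intersect(names(args), names(modifiers)))
    args[[nm]] <- modifiers[[nm]](args[[nm]])          # coerce, e.g. as.integer
  args
}

f <- function(object, units, dropout = 0.0, ...) {
  capture_args2(match.call(), list(units = as.integer), ignore = "object")
}
str(f(NULL, units = 32, dropout = 0.1))
#> List of 2
#>  $ units  : int 32
#>  $ dropout: num 0.1
```

Because `match.call()` only records arguments the caller supplied, untouched defaults are simply omitted and left to the underlying Python layer, which is what lets the wrappers drop the `keras_version() >= "2.0.5"` guard for `return_state`.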
```diff
@@ -135,61 +126,51 @@ layer_simple_rnn <- function(object, units, activation = "tanh", use_bias = TRUE
 #'   Networks](https://arxiv.org/abs/1512.05287)
 #'
 #' @export
-layer_gru <- function(object, units, activation = "tanh", recurrent_activation = "hard_sigmoid", use_bias = TRUE,
-  return_sequences = FALSE, return_state = FALSE, go_backwards = FALSE, stateful = FALSE, unroll = FALSE,
-  time_major = FALSE, reset_after = FALSE,
-  kernel_initializer = "glorot_uniform", recurrent_initializer = "orthogonal", bias_initializer = "zeros",
-  kernel_regularizer = NULL, recurrent_regularizer = NULL, bias_regularizer = NULL, activity_regularizer = NULL,
-  kernel_constraint = NULL, recurrent_constraint = NULL, bias_constraint = NULL,
-  dropout = 0.0, recurrent_dropout = 0.0, input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
-  dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
-
-  args <- list(
-    units = as.integer(units),
-    activation = activation,
-    recurrent_activation = recurrent_activation,
-    use_bias = use_bias,
-    return_sequences = return_sequences,
-    go_backwards = go_backwards,
-    stateful = stateful,
-    unroll = unroll,
-    time_major = time_major,
-    kernel_initializer = kernel_initializer,
-    recurrent_initializer = recurrent_initializer,
-    bias_initializer = bias_initializer,
-    kernel_regularizer = kernel_regularizer,
-    recurrent_regularizer = recurrent_regularizer,
-    bias_regularizer = bias_regularizer,
-    activity_regularizer = activity_regularizer,
-    kernel_constraint = kernel_constraint,
-    recurrent_constraint = recurrent_constraint,
-    bias_constraint = bias_constraint,
-    dropout = dropout,
-    recurrent_dropout = recurrent_dropout,
-    input_shape = normalize_shape(input_shape),
-    batch_input_shape = normalize_shape(batch_input_shape),
-    batch_size = as_nullable_integer(batch_size),
-    dtype = dtype,
-    name = name,
-    trainable = trainable,
-    weights = weights
-  )
-
-  if (keras_version() >= "2.0.5")
-    args$return_state <- return_state
-
-  if (keras_version() >= "2.1.5")
-    args$reset_after <- reset_after
-
+layer_gru <-
+  function(object,
+           units,
+           activation = "tanh",
+           recurrent_activation = "sigmoid",
+           use_bias = TRUE,
+           return_sequences = FALSE,
+           return_state = FALSE,
+           go_backwards = FALSE,
+           stateful = FALSE,
+           unroll = FALSE,
+           time_major = FALSE,
+           reset_after = TRUE,
+           kernel_initializer = "glorot_uniform",
+           recurrent_initializer = "orthogonal",
+           bias_initializer = "zeros",
+           kernel_regularizer = NULL,
+           recurrent_regularizer = NULL,
+           bias_regularizer = NULL,
+           activity_regularizer = NULL,
+           kernel_constraint = NULL,
+           recurrent_constraint = NULL,
+           bias_constraint = NULL,
+           dropout = 0.0,
+           recurrent_dropout = 0.0,
+           ...)
+  {
+    args <- capture_args(match.call(), list(
+      units = as.integer,
+      input_shape = normalize_shape,
+      batch_input_shape = normalize_shape,
+      batch_size = as_nullable_integer
+    ), ignore = "object")
   create_layer(keras$layers$GRU, object, args)
 }
 
 
+
+
 #' Fast GRU implementation backed by [CuDNN](https://developer.nvidia.com/cudnn).
 #'
 #' Can only be run on GPU, with the TensorFlow backend.
 #'
 #' @inheritParams layer_simple_rnn
+#' @inheritParams layer_dense
 #'
 #' @family recurrent layers
 #'
```
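
With `layer_cudnn_gru()` on the way out, the conditions for the automatic CuDNN dispatch matter. The list below comes from the upstream TF Keras documentation, not from this diff, so treat it as an assumption:

```r
# Conditions for the fused CuDNN kernel (per upstream TF Keras docs; an
# assumption here, details can vary by TensorFlow version): activation =
# "tanh", recurrent_activation = "sigmoid", recurrent_dropout = 0,
# unroll = FALSE, use_bias = TRUE, and reset_after = TRUE for GRU.
fast_gru <- layer_gru(units = 64)  # all new defaults: CuDNN-eligible on GPU

# Any deviation silently falls back to the generic (slower) kernel:
slow_gru <- layer_gru(units = 64, recurrent_dropout = 0.2)
```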
```diff
@@ -264,50 +245,40 @@ layer_cudnn_gru <- function(object, units,
 #' @family recurrent layers
 #'
 #' @export
-layer_lstm <- function(object, units, activation = "tanh", recurrent_activation = "sigmoid", use_bias = TRUE,
-  return_sequences = FALSE, return_state = FALSE, go_backwards = FALSE, stateful = FALSE,
-  time_major = FALSE, unroll = FALSE,
-  kernel_initializer = "glorot_uniform", recurrent_initializer = "orthogonal", bias_initializer = "zeros",
-  unit_forget_bias = TRUE, kernel_regularizer = NULL, recurrent_regularizer = NULL, bias_regularizer = NULL,
-  activity_regularizer = NULL, kernel_constraint = NULL, recurrent_constraint = NULL, bias_constraint = NULL,
-  dropout = 0.0, recurrent_dropout = 0.0, input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
-  dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
-
-  args <- list(
-    units = as.integer(units),
-    activation = activation,
-    recurrent_activation = recurrent_activation,
-    use_bias = use_bias,
-    return_sequences = return_sequences,
-    go_backwards = go_backwards,
-    stateful = stateful,
-    time_major = time_major,
-    unroll = unroll,
-    kernel_initializer = kernel_initializer,
-    recurrent_initializer = recurrent_initializer,
-    bias_initializer = bias_initializer,
-    unit_forget_bias = unit_forget_bias,
-    kernel_regularizer = kernel_regularizer,
-    recurrent_regularizer = recurrent_regularizer,
-    bias_regularizer = bias_regularizer,
-    activity_regularizer = activity_regularizer,
-    kernel_constraint = kernel_constraint,
-    recurrent_constraint = recurrent_constraint,
-    bias_constraint = bias_constraint,
-    dropout = dropout,
-    recurrent_dropout = recurrent_dropout,
-    input_shape = normalize_shape(input_shape),
-    batch_input_shape = normalize_shape(batch_input_shape),
-    batch_size = as_nullable_integer(batch_size),
-    dtype = dtype,
-    name = name,
-    trainable = trainable,
-    weights = weights
-  )
-
-  if (keras_version() >= "2.0.5")
-    args$return_state <- return_state
-
+layer_lstm <-
+  function(object,
+           units,
+           activation = "tanh",
+           recurrent_activation = "sigmoid",
+           use_bias = TRUE,
+           return_sequences = FALSE,
+           return_state = FALSE,
+           go_backwards = FALSE,
+           stateful = FALSE,
+           time_major = FALSE,
+           unroll = FALSE,
+           kernel_initializer = "glorot_uniform",
+           recurrent_initializer = "orthogonal",
+           bias_initializer = "zeros",
+           unit_forget_bias = TRUE,
+           kernel_regularizer = NULL,
+           recurrent_regularizer = NULL,
+           bias_regularizer = NULL,
+           activity_regularizer = NULL,
+           kernel_constraint = NULL,
+           recurrent_constraint = NULL,
+           bias_constraint = NULL,
+           dropout = 0.0,
+           recurrent_dropout = 0.0,
+           ...
+           )
+  {
+    args <- capture_args(match.call(), list(
+      units = as.integer,
+      input_shape = normalize_shape,
+      batch_input_shape = normalize_shape,
+      batch_size = as_nullable_integer
+    ), ignore = "object")
   create_layer(keras$layers$LSTM, object, args)
 }
```
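
Note that `input_shape`, `batch_size`, `name`, and the other standard layer arguments have left the explicit signature; they now arrive via `...` yet keep their coercions through the modifier list passed to `capture_args()`. A usage sketch (argument values illustrative):

```r
# Standard layer args are forwarded through `...` but still coerced:
# input_shape passes through normalize_shape(), batch_size through
# as_nullable_integer(), exactly as in the modifier list above.
lstm <- layer_lstm(
  units = 32,
  input_shape = c(100, 5),  # forwarded via `...`
  name = "encoder_lstm"     # forwarded via `...`
)
```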

```diff
@@ -316,6 +287,7 @@ layer_lstm <- function(object, units, activation = "tanh", recurrent_activation
 #' Can only be run on GPU, with the TensorFlow backend.
 #'
 #' @inheritParams layer_lstm
+#' @inheritParams layer_dense
 #'
 #' @section References:
 #' - [Long short-term memory](https://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper)
```

man/layer_gru.Rd

Lines changed: 4 additions & 30 deletions
Some generated files are not rendered by default.
