
Commit 687d341

Merge pull request #958 from ifrit98/master:
expose renorm, renorm_clipping to layer_batch_normalization

2 parents: e3f62ae + 412e264

File tree: 3 files changed, +101 -10 lines

DESCRIPTION

Lines changed: 1 addition & 1 deletion

@@ -43,5 +43,5 @@ Suggests:
     jpeg
 SystemRequirements: Keras >= 2.0 (https://keras.io)
 Roxygen: list(markdown = TRUE, r6 = FALSE)
-RoxygenNote: 7.0.1
+RoxygenNote: 7.0.2
 VignetteBuilder: knitr

R/layers-normalization.R

Lines changed: 53 additions & 9 deletions

@@ -25,7 +25,41 @@
 #' @param gamma_regularizer Optional regularizer for the gamma weight.
 #' @param beta_constraint Optional constraint for the beta weight.
 #' @param gamma_constraint Optional constraint for the gamma weight.
-#'
+#' @param renorm Whether to use Batch Renormalization
+#'   (https://arxiv.org/abs/1702.03275). This adds extra variables during
+#'   training. The inference is the same for either value of this parameter.
+#' @param renorm_clipping A named list or dictionary that may map keys `rmax`,
+#'   `rmin`, `dmax` to scalar Tensors used to clip the renorm correction. The
+#'   correction `(r, d)` is used as `corrected_value = normalized_value * r + d`,
+#'   with `r` clipped to `[rmin, rmax]` and `d` to `[-dmax, dmax]`. Missing `rmax`,
+#'   `rmin`, `dmax` are set to `Inf`, `0`, `Inf`, respectively.
+#' @param renorm_momentum Momentum used to update the moving means and standard
+#'   deviations with renorm. Unlike `momentum`, this affects training and should
+#'   be neither too small (which would add noise) nor too large (which would
+#'   give stale estimates). Note that `momentum` is still applied to get the means
+#'   and variances for inference.
+#' @param fused If `TRUE`, use a faster, fused implementation, or raise a ValueError
+#'   if the fused implementation cannot be used. If `NULL`, use the faster
+#'   implementation if possible. If `FALSE`, do not use the fused implementation.
+#' @param virtual_batch_size An integer. By default, `virtual_batch_size` is `NULL`,
+#'   which means batch normalization is performed across the whole batch.
+#'   When `virtual_batch_size` is not `NULL`, instead perform "Ghost Batch
+#'   Normalization", which creates virtual sub-batches that are each normalized
+#'   separately (with shared gamma, beta, and moving statistics). Must divide
+#'   the actual batch size during execution.
+#' @param adjustment A function taking the Tensor containing the (dynamic) shape
+#'   of the input tensor and returning a pair `(scale, bias)` to apply to the
+#'   normalized values (before gamma and beta), only during training.
+#'   For example, if `axis == -1`,
+#'   \code{adjustment <- function(shape) {
+#'     tuple(tf$random$uniform(shape[-1:NULL, style = "python"], 0.93, 1.07),
+#'           tf$random$uniform(shape[-1:NULL, style = "python"], -0.1, 0.1))
+#'   }}
+#'   will scale the normalized value by up to 7% up or down, then shift the
+#'   result by up to 0.1 (with independent scaling and bias for each feature
+#'   but shared across all examples), and finally apply gamma and/or beta.
+#'   If `NULL`, no adjustment is applied. Cannot be specified if
+#'   `virtual_batch_size` is specified.
 #' @section Input shape: Arbitrary. Use the keyword argument `input_shape` (list
 #'   of integers, does not include the samples axis) when using this layer as
 #'   the first layer in a model.

@@ -36,13 +70,17 @@
 #' - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
 #'
 #' @export
-layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsilon = 0.001, center = TRUE, scale = TRUE,
-                                      beta_initializer = "zeros", gamma_initializer = "ones",
-                                      moving_mean_initializer = "zeros", moving_variance_initializer = "ones",
-                                      beta_regularizer = NULL, gamma_regularizer = NULL,
-                                      beta_constraint = NULL, gamma_constraint = NULL,
-                                      input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
-                                      dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
+layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsilon = 0.001, center = TRUE, scale = TRUE,
+                                      beta_initializer = "zeros", gamma_initializer = "ones",
+                                      moving_mean_initializer = "zeros", moving_variance_initializer = "ones",
+                                      beta_regularizer = NULL, gamma_regularizer = NULL, beta_constraint = NULL,
+                                      gamma_constraint = NULL, renorm = FALSE, renorm_clipping = NULL,
+                                      renorm_momentum = 0.99, fused = NULL, virtual_batch_size = NULL,
+                                      adjustment = NULL, input_shape = NULL, batch_input_shape = NULL,
+                                      batch_size = NULL, dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
+
+  stopifnot(is.null(adjustment) || is.function(adjustment))
+
   create_layer(keras$layers$BatchNormalization, object, list(
     axis = as.integer(axis),
     momentum = momentum,

@@ -57,12 +95,18 @@ layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsil
     gamma_regularizer = gamma_regularizer,
     beta_constraint = beta_constraint,
     gamma_constraint = gamma_constraint,
+    renorm = renorm,
+    renorm_clipping = renorm_clipping,
+    renorm_momentum = renorm_momentum,
+    fused = fused,
     input_shape = normalize_shape(input_shape),
     batch_input_shape = normalize_shape(batch_input_shape),
     batch_size = as_nullable_integer(batch_size),
     dtype = dtype,
     name = name,
     trainable = trainable,
+    virtual_batch_size = as_nullable_integer(virtual_batch_size),
+    adjustment = adjustment,
     weights = weights
   ))
-}
+}
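For context (not part of the diff): a minimal usage sketch of the newly exposed arguments. The surrounding model (keras_model_sequential(), layer_dense()) and the specific clipping and sub-batch values are illustrative assumptions, not taken from this change.

library(keras)

model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = 784) %>%
  # Batch Renormalization: clip the (r, d) correction described in the docs above
  layer_batch_normalization(
    renorm = TRUE,
    renorm_clipping = list(rmax = 3, rmin = 1 / 3, dmax = 5),  # illustrative values
    renorm_momentum = 0.99
  ) %>%
  layer_dense(units = 32, activation = "relu") %>%
  # "Ghost Batch Normalization": each virtual sub-batch of 16 examples is
  # normalized separately; 16 must divide the actual batch size used in fit()
  layer_batch_normalization(virtual_batch_size = 16) %>%
  layer_dense(units = 10, activation = "softmax")

Here fused is left at its NULL default, which, per the documentation above, lets the backend use the faster fused implementation only where it is applicable.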

man/layer_batch_normalization.Rd

Lines changed: 47 additions & 0 deletions
(Generated documentation file; diff not rendered by default.)
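The adjustment argument documented above is the least obvious of the new parameters, since an R function is handed to the backend. A sketch expanding the roxygen example, assuming the tensorflow R package's tf object, python-style tensor slicing, and reticulate's tuple() helper; the scale/shift ranges mirror the documentation and the model around it is hypothetical.

library(keras)
library(tensorflow)
library(reticulate)

# Applied only during training, to the normalized values before gamma/beta.
# Returns (scale, bias); cannot be combined with virtual_batch_size.
adjustment <- function(shape) {
  feature_dim <- shape[-1:NULL, style = "python"]  # trailing (feature) dimension
  tuple(
    tf$random$uniform(feature_dim, 0.93, 1.07),  # random scale, up to +/- 7%
    tf$random$uniform(feature_dim, -0.1, 0.1)    # random shift, up to +/- 0.1
  )
}

model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = 128) %>%
  layer_batch_normalization(adjustment = adjustment) %>%
  layer_dense(units = 10, activation = "softmax")

The stopifnot() guard added in this commit rejects anything other than NULL or a function for adjustment.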
