 #' @param gamma_regularizer Optional regularizer for the gamma weight.
 #' @param beta_constraint Optional constraint for the beta weight.
 #' @param gamma_constraint Optional constraint for the gamma weight.
-#'
+#' @param renorm Whether to use Batch Renormalization
+#'   (https://arxiv.org/abs/1702.03275). This adds extra variables during
+#'   training. The inference is the same for either value of this parameter.
+#' @param renorm_clipping A named list or dictionary that may map the keys `rmax`,
+#'   `rmin`, and `dmax` to scalar Tensors used to clip the renorm correction. The
+#'   correction `(r, d)` is used as `corrected_value = normalized_value * r + d`,
+#'   with `r` clipped to `[rmin, rmax]` and `d` to `[-dmax, dmax]`. Missing `rmax`,
+#'   `rmin`, and `dmax` are set to `Inf`, `0`, and `Inf`, respectively.
+#' @param renorm_momentum Momentum used to update the moving means and standard
+#'   deviations with renorm. Unlike `momentum`, this affects training and should
+#'   be neither too small (which would add noise) nor too large (which would
+#'   give stale estimates). Note that `momentum` is still applied to get the means
+#'   and variances for inference.
+#' @param fused If `TRUE`, use a faster, fused implementation, or raise a ValueError
+#'   if the fused implementation cannot be used. If `NULL`, use the faster
+#'   implementation if possible. If `FALSE`, do not use the fused implementation.
+#' @param virtual_batch_size An integer. By default, `virtual_batch_size` is `NULL`,
+#'   which means batch normalization is performed across the whole batch.
+#'   When `virtual_batch_size` is not `NULL`, instead perform "Ghost Batch
+#'   Normalization", which creates virtual sub-batches that are each normalized
+#'   separately (with shared gamma, beta, and moving statistics). Must divide
+#'   the actual batch size during execution.
+#' @param adjustment A function taking the Tensor containing the (dynamic) shape
+#'   of the input tensor and returning a pair `(scale, bias)` to apply to the
+#'   normalized values (before gamma and beta), only during training.
+#'   For example, if `axis == -1`,
+#'   \code{adjustment <- function(shape) {
+#'     tuple(tf$random$uniform(shape[-1:NULL, style = "python"], 0.93, 1.07),
+#'           tf$random$uniform(shape[-1:NULL, style = "python"], -0.1, 0.1))
+#'   }}
+#'   will scale the normalized value by up to 7% up or down, then shift the
+#'   result by up to 0.1 (with independent scaling and bias for each feature
+#'   but shared across all examples), and finally apply gamma and/or beta. If
+#'   `NULL`, no adjustment is applied. Cannot be specified if
+#'   `virtual_batch_size` is specified.
 #' @section Input shape: Arbitrary. Use the keyword argument `input_shape` (list
 #'   of integers, does not include the samples axis) when using this layer as
 #'   the first layer in a model.
@@ -36,13 +70,17 @@
 #' - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
 #'
 #' @export
-layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsilon = 0.001, center = TRUE, scale = TRUE,
-                                      beta_initializer = "zeros", gamma_initializer = "ones",
-                                      moving_mean_initializer = "zeros", moving_variance_initializer = "ones",
-                                      beta_regularizer = NULL, gamma_regularizer = NULL,
-                                      beta_constraint = NULL, gamma_constraint = NULL,
-                                      input_shape = NULL, batch_input_shape = NULL, batch_size = NULL,
-                                      dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
+layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsilon = 0.001, center = TRUE, scale = TRUE,
+                                      beta_initializer = "zeros", gamma_initializer = "ones",
+                                      moving_mean_initializer = "zeros", moving_variance_initializer = "ones",
+                                      beta_regularizer = NULL, gamma_regularizer = NULL, beta_constraint = NULL,
+                                      gamma_constraint = NULL, renorm = FALSE, renorm_clipping = NULL,
+                                      renorm_momentum = 0.99, fused = NULL, virtual_batch_size = NULL,
+                                      adjustment = NULL, input_shape = NULL, batch_input_shape = NULL,
+                                      batch_size = NULL, dtype = NULL, name = NULL, trainable = NULL, weights = NULL) {
+
+  stopifnot(is.null(adjustment) || is.function(adjustment))
+
   create_layer(keras$layers$BatchNormalization, object, list(
     axis = as.integer(axis),
     momentum = momentum,
@@ -57,12 +95,18 @@ layer_batch_normalization <- function(object, axis = -1L, momentum = 0.99, epsil
     gamma_regularizer = gamma_regularizer,
     beta_constraint = beta_constraint,
     gamma_constraint = gamma_constraint,
+    renorm = renorm,
+    renorm_clipping = renorm_clipping,
+    renorm_momentum = renorm_momentum,
+    fused = fused,
     input_shape = normalize_shape(input_shape),
     batch_input_shape = normalize_shape(batch_input_shape),
     batch_size = as_nullable_integer(batch_size),
     dtype = dtype,
     name = name,
     trainable = trainable,
+    virtual_batch_size = as_nullable_integer(virtual_batch_size),
+    adjustment = adjustment,
     weights = weights
   ))
-}
+}
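
Below is a minimal usage sketch, not part of this commit, showing how the newly documented arguments could be passed once this change is in place. It assumes the keras R package with a working TensorFlow backend; the layer sizes, clipping values, and `virtual_batch_size` used here are illustrative only.

# Hypothetical example: Batch Renormalization and Ghost Batch Normalization
library(keras)

model <- keras_model_sequential() %>%
  layer_dense(units = 64, activation = "relu", input_shape = c(784)) %>%
  # Batch Renormalization with clipped corrections (see renorm_clipping above);
  # plain numerics are assumed to be accepted in place of scalar Tensors
  layer_batch_normalization(
    renorm = TRUE,
    renorm_clipping = list(rmax = 3, rmin = 1 / 3, dmax = 5),
    renorm_momentum = 0.99
  ) %>%
  layer_dense(units = 32, activation = "relu") %>%
  # "Ghost Batch Normalization": virtual_batch_size must divide the actual
  # batch size used during training (e.g. batch_size = 32 in fit())
  layer_batch_normalization(virtual_batch_size = 8L) %>%
  layer_dense(units = 10, activation = "softmax")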