@@ -104,32 +104,40 @@ def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
   return norm_x * scale + bias
 
 
-def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
+def batch_norm(x, is_training, momentum, epsilon=1e-9,
+               init_zero=False, name=None):
   """Batch normalization.
 
   Args:
     x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
     is_training: a boolean, whether mode is training.
     momentum: a floating point number, specifying batch norm decay value.
     epsilon: a floating point number.
+    init_zero: a boolean, whether to initialize scale with 0's or 1's.
     name: a string. variable scope.
 
   Returns:
     a mtf.Tensor with same shape as x.
   """
   with tf.variable_scope(name, default_name="batch_norm", values=[x]):
-    batch_dim = x.shape.dims[0]
-    reduced_shape = x.shape - batch_dim
+    if init_zero:
+      gamma_initializer = tf.zeros_initializer()
+    else:
+      gamma_initializer = tf.ones_initializer()
+
+    norm_dim = x.shape.dims[0:3]
+    reduced_shape = x.shape - norm_dim
+
     scale = mtf.get_variable(
         x.mesh,
         "batch_norm_scale",
-        mtf.Shape([batch_dim]),
-        initializer=tf.ones_initializer(),
+        reduced_shape,
+        initializer=gamma_initializer,
         activation_dtype=x.dtype)
     bias = mtf.get_variable(
         x.mesh,
         "batch_norm_bias",
-        mtf.Shape([batch_dim]),
+        reduced_shape,
         initializer=tf.zeros_initializer(),
         activation_dtype=x.dtype)
 
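
As a brief aside on the shape arithmetic introduced above: x.shape - norm_dim drops the reduced dimensions, so scale and bias now hold one value per remaining dimension (e.g. per channel) rather than one per batch entry as before. A minimal sketch of that arithmetic, assuming the standalone mesh_tensorflow package; the dimension names and sizes are invented for illustration:

    import mesh_tensorflow as mtf

    # Hypothetical 4-D activation shape.
    shape = mtf.Shape([mtf.Dimension("batch", 8),
                       mtf.Dimension("height", 4),
                       mtf.Dimension("width", 4),
                       mtf.Dimension("channels", 16)])

    norm_dim = shape.dims[0:3]        # [batch, height, width] are reduced over
    reduced_shape = shape - norm_dim  # Shape[channels]: per-channel gamma/beta
    print(reduced_shape)
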
@@ -150,6 +158,7 @@ def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
       mean = mtf.reduce_mean(x, output_shape=reduced_shape)
       variance = mtf.reduce_mean(
           mtf.square(x - mean), output_shape=reduced_shape)
+
       norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
 
       # Update running mean and running variance.
@@ -161,7 +170,8 @@ def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
     else:
       # At eval and test time, use the running mean and variance.
       norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
-    return norm_x * scale + bias
+
+    return (norm_x * scale) + bias
 
 
 def softmax_cross_entropy_with_logits(logits, targets, vocab_dim):
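
For reference, a minimal NumPy sketch of the training-time path of the revised function (single device, no moving averages); batch_norm_ref, the 4-D input shape, and the reduction over the first three axes mirroring norm_dim = x.shape.dims[0:3] are assumptions made for illustration, not part of the library:

    import numpy as np

    def batch_norm_ref(x, epsilon=1e-9, init_zero=False):
      # Statistics are computed over the first three axes (batch and spatial dims).
      mean = x.mean(axis=(0, 1, 2), keepdims=True)
      variance = ((x - mean) ** 2).mean(axis=(0, 1, 2), keepdims=True)
      norm_x = (x - mean) / np.sqrt(variance + epsilon)

      # gamma/beta live on the non-reduced (channel) dimension; with
      # init_zero=True the layer initially outputs (approximately) all zeros.
      scale = np.zeros(x.shape[3:]) if init_zero else np.ones(x.shape[3:])
      bias = np.zeros(x.shape[3:])
      return (norm_x * scale) + bias

    x = np.random.randn(8, 4, 4, 16).astype(np.float32)
    y = batch_norm_ref(x)
    print(y.mean(), y.std())  # roughly 0 and 1 with the default initializers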