
Commit a9d6a74

Niki Parmar authored and Copybara-Service committed

ResNet changes for MeshTF

PiperOrigin-RevId: 218251907
1 parent: a759cfb

File tree: 1 file changed, +17 −7 lines


mesh_tensorflow/layers.py

Lines changed: 17 additions & 7 deletions
@@ -104,32 +104,40 @@ def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
     return norm_x * scale + bias
 
 
-def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
+def batch_norm(x, is_training, momentum, epsilon=1e-9,
+               init_zero=False, name=None):
   """Batch normalization.
 
   Args:
     x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
     is_training: a boolean, whether mode is training.
     momentum: a floating point number, specifying batch norm decay value.
     epsilon: a floating point number.
+    init_zero: a boolean, whether to initialize scale with 0's or 1's.
     name: a string. variable scope.
 
   Returns:
     a mtf.Tensor with same shape as x.
   """
   with tf.variable_scope(name, default_name="batch_norm", values=[x]):
-    batch_dim = x.shape.dims[0]
-    reduced_shape = x.shape - batch_dim
+    if init_zero:
+      gamma_initializer = tf.zeros_initializer()
+    else:
+      gamma_initializer = tf.ones_initializer()
+
+    norm_dim = x.shape.dims[0:3]
+    reduced_shape = x.shape - norm_dim
+
     scale = mtf.get_variable(
         x.mesh,
         "batch_norm_scale",
-        mtf.Shape([batch_dim]),
-        initializer=tf.ones_initializer(),
+        reduced_shape,
+        initializer=gamma_initializer,
         activation_dtype=x.dtype)
     bias = mtf.get_variable(
         x.mesh,
         "batch_norm_bias",
-        mtf.Shape([batch_dim]),
+        reduced_shape,
         initializer=tf.zeros_initializer(),
         activation_dtype=x.dtype)
 
@@ -150,6 +158,7 @@ def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
       mean = mtf.reduce_mean(x, output_shape=reduced_shape)
       variance = mtf.reduce_mean(
           mtf.square(x - mean), output_shape=reduced_shape)
+
       norm_x = (x - mean) * mtf.rsqrt(variance + epsilon)
 
       # Update running mean and running variance.
@@ -161,7 +170,8 @@ def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
     else:
       # At eval and test time, use the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
-    return norm_x * scale + bias
+
+    return (norm_x * scale) + bias
 
 
 def softmax_cross_entropy_with_logits(logits, targets, vocab_dim):
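
For context, a minimal usage sketch of the revised function. The import path follows the repository layout (layers.py under the mesh_tensorflow package); the tensor x, its dimension layout, and the momentum value are illustrative assumptions, not part of this commit:

    import mesh_tensorflow as mtf

    # x: an activation tensor whose first three dims are (batch, height, width),
    # matching norm_dim = x.shape.dims[0:3]. Statistics are reduced over those
    # dims, so scale and bias now live on reduced_shape (per-channel) rather
    # than on a single batch dimension as before.
    y = mtf.layers.batch_norm(
        x,
        is_training=True,
        momentum=0.99,      # batch norm decay value (illustrative)
        epsilon=1e-9,
        init_zero=True,     # zero-initialize scale (gamma)
        name="batch_norm")

With init_zero=True the scale starts at zero, so the layer initially outputs only its bias. In ResNets this is commonly applied to the final batch norm of each residual branch so that every block begins as an approximate identity mapping, which tends to stabilize early training; that is presumably the motivation for the new flag in these "ResNet changes for MeshTF".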
