
Commit 36df2f3

lingvo-bot authored and copybara-github committed
Add support for GroupNormLayer with input_rank=3
PiperOrigin-RevId: 477580174
1 parent 535a9c3 commit 36df2f3
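
In effect, the convolution block's normalization can now consume rank-3 [b, t, d] activations directly instead of requiring a rank-4 [b, t, 1, d] layout. A minimal sketch of opting in (the surrounding conformer wiring is omitted; the dim and num_groups values are hypothetical, and input_rank=3 is the setting this commit adds support for):

from lingvo.core import bn_layers

# GroupNormLayer params: with input_rank=3 the layer normalizes
# [b, t, d] inputs as-is, avoiding the expand_dims/squeeze round-trip
# that the input_rank=4 path performs.
norm_p = bn_layers.GroupNormLayer.Params().Set(
    name='group_norm',
    dim=512,        # hypothetical model dimension
    num_groups=32,  # hypothetical group count
    input_rank=3)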

File tree: 1 file changed, +48 −29 lines changed


lingvo/core/conformer_layer.py

Lines changed: 48 additions & 29 deletions
@@ -260,28 +260,30 @@ def _Normalize(self, theta, inputs, paddings):
 
     Args:
       theta: A NestedMap of layer params.
-      inputs: [b, t, 1, d].
+      inputs: [b, t, d].
       paddings: [b, t].
 
     Returns:
       A Tensor of shape [b, t, d].
     """
     if isinstance(self.norm, bn_layers.GroupNormLayer):
-      assert self.norm.params.input_rank == 4
+      gn_input_rank = self.norm.params.input_rank
+      if gn_input_rank == 4:
+        tf.logging.info(
+            'Using GroupNormLayer with input_rank=4, causing extra reshapes. '
+            'Set norm.params.input_rank=3.')
+        inputs = tf.expand_dims(inputs, 2)
       inputs, _ = self.norm.FProp(theta.norm, inputs, paddings)
-      # [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
+      if gn_input_rank == 4:
+        inputs = tf.squeeze(inputs, 2)
+    elif isinstance(self.norm, bn_layers.BatchNormLayer):
+      inputs = self.norm.FProp(theta.norm, inputs, paddings)
+    elif isinstance(self.norm, layers.LayerNorm):
+      inputs = self.norm.FProp(theta.norm, inputs)
     else:
-      # [b, t, 1, d] -> [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
-      if isinstance(self.norm, bn_layers.BatchNormLayer):
-        inputs = self.norm.FProp(theta.norm, inputs, paddings)
-      elif isinstance(self.norm, layers.LayerNorm):
-        inputs = self.norm.FProp(theta.norm, inputs)
-      else:
-        raise NotImplementedError(
-            'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
-            'are supported.')
+      raise NotImplementedError(
+          'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
+          'are supported.')
     return self._CastToFPropDtype(inputs)
 
   def FProp(self, theta, inputs, paddings):
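
The rank-4 path above survives only as a compatibility shim: inputs are expanded to [b, t, 1, d] before the norm's FProp and squeezed back afterwards, with a log message pointing users at input_rank=3. A shape-only illustration in plain TensorFlow (tf.identity stands in for the normalization itself):

import tensorflow as tf

x = tf.random.normal([2, 5, 8])      # [b, t, d] activations

# input_rank=4 compatibility path: expand, normalize, squeeze.
x4 = tf.expand_dims(x, 2)            # [b, t, 1, d]
y4 = tf.squeeze(tf.identity(x4), 2)  # back to [b, t, d]

# input_rank=3 path: the tensor is consumed as-is, no reshapes.
y3 = tf.identity(x)

assert y3.shape == y4.shape == x.shape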
@@ -335,6 +337,8 @@ def FProp(self, theta, inputs, paddings):
 
     inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                     adapted_blf_dims_mapping)
+    # [b, t, 1, d] --> [b, t, d]
+    inputs = tf.squeeze(inputs, 2)
     inputs = self._Normalize(theta, inputs, paddings)
     inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                     p.activation_split_dims_mapping.blf)
@@ -363,24 +367,37 @@ def zero_state(self, batch_size):
     return py_utils.NestedMap()
 
   def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
-    if hasattr(self.norm, 'StreamStep'):
-      # TODO(jamesqin): support 3d inputs.
-      # At present it's guaranteed GroupNorm.
-      assert (isinstance(self.norm, bn_layers.GroupNormLayer) and
-              self.norm.params.input_rank == 4)
+    """Applies normalization in a streaming fashion.
+
+    Args:
+      theta: A NestedMap of layer params.
+      inputs: [b, t, d].
+      paddings: [b, t].
+      state0: A NestedMap of tensors of the same struct as returned by
+        zero_state().
+      state1: A NestedMap of tensors of the same struct as state0. On output
+        state1.norm_state might be updated with the new state.
+
+    Returns:
+      A Tensor of shape [b, t, d].
+    """
+    if isinstance(self.norm, bn_layers.GroupNormLayer):
+      gn_input_rank = self.norm.params.input_rank
+      if gn_input_rank == 4:
+        tf.logging.info(
+            'Using GroupNormLayer with input_rank=4, causing extra reshapes. '
+            'Set norm.params.input_rank=3.')
+        inputs = tf.expand_dims(inputs, 2)
       inputs, paddings, norm_state1 = self.norm.StreamStep(
           theta.norm, inputs, paddings, state0.norm_state)
-      # [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
+      if gn_input_rank == 4:
+        inputs = tf.squeeze(inputs, 2)
       state1.norm_state = norm_state1
+    elif isinstance(self.norm, layers.LayerNorm):
+      inputs = self.norm.FProp(theta.norm, inputs)
     else:
-      # [b, t, 1, d] -> [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
-      if isinstance(self.norm, layers.LayerNorm):
-        inputs = self.norm.FProp(theta.norm, inputs)
-      else:
-        raise NotImplementedError(
-            'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.')
+      raise NotImplementedError(
+          'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.')
     # [b, t, d]
     return inputs, paddings
 
@@ -395,7 +412,7 @@ def StreamStep(self, theta, inputs, paddings, state0):
         zero_state().
 
     Returns:
-      outputs: A NestedMap of tensors consisting:
+      output: the same shape as inputs, with normalized values.
       padding: the same as input paddings.
       state1: A NestedMap of tensors of the same struct as state0.
     """
@@ -424,6 +441,8 @@ def StreamStep(self, theta, inputs, paddings, state0):
     inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
         theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
     state1.conv_state = conv_state1
+    # [b, t, 1, d] -> [b, t, d]
+    inputs = tf.squeeze(inputs, 2)
     # [b, t, d]
     inputs, paddings = self._NormalizeStep(theta, inputs, paddings, state0,
                                            state1)
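
With the squeeze moved into StreamStep right after the depthwise convolution, _NormalizeStep now receives the same [b, t, d] layout in streaming that _Normalize receives in FProp. A shape-only sketch of the flow (the expand_dims placement before the conv is assumed for illustration; the real conv and norm calls are elided):

import tensorflow as tf

def stream_step_shapes(x):   # x: [b, t, d]
  x = tf.expand_dims(x, 2)   # assumed: conv path operates on [b, t, 1, d]
  # ... depthwise_conv1d.StreamStep emits [b, t, 1, d] ...
  x = tf.squeeze(x, 2)       # new: back to rank 3 before _NormalizeStep
  # ... _NormalizeStep now always sees [b, t, d] ...
  return x

assert stream_step_shapes(tf.zeros([2, 4, 8])).shape == [2, 4, 8]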
