@@ -260,28 +260,30 @@ def _Normalize(self, theta, inputs, paddings):

     Args:
       theta: A NestedMap of layer params.
-      inputs: [b, t, 1, d].
+      inputs: [b, t, d].
       paddings: [b, t].

     Returns:
       A Tensor of shape [b, t, d].
     """
     if isinstance(self.norm, bn_layers.GroupNormLayer):
-      assert self.norm.params.input_rank == 4
+      gn_input_rank = self.norm.params.input_rank
+      if gn_input_rank == 4:
+        tf.logging.info(
+            'Using GroupNormLayer with input_rank=4, causing extra reshapes. '
+            'Set norm.params.input_rank=3.')
+        inputs = tf.expand_dims(inputs, 2)
       inputs, _ = self.norm.FProp(theta.norm, inputs, paddings)
-      # [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
+      if gn_input_rank == 4:
+        inputs = tf.squeeze(inputs, 2)
+    elif isinstance(self.norm, bn_layers.BatchNormLayer):
+      inputs = self.norm.FProp(theta.norm, inputs, paddings)
+    elif isinstance(self.norm, layers.LayerNorm):
+      inputs = self.norm.FProp(theta.norm, inputs)
     else:
-      # [b, t, 1, d] -> [b, t, d]
-      inputs = tf.squeeze(inputs, 2)
-      if isinstance(self.norm, bn_layers.BatchNormLayer):
-        inputs = self.norm.FProp(theta.norm, inputs, paddings)
-      elif isinstance(self.norm, layers.LayerNorm):
-        inputs = self.norm.FProp(theta.norm, inputs)
-      else:
-        raise NotImplementedError(
-            'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
-            'are supported.')
+      raise NotImplementedError(
+          'Only bn_layers.{BatchNormLayer,GroupNormLayer}, layers.LayerNorm '
+          'are supported.')
     return self._CastToFPropDtype(inputs)

   def FProp(self, theta, inputs, paddings):
@@ -335,6 +337,8 @@ def FProp(self, theta, inputs, paddings):

     inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                     adapted_blf_dims_mapping)
+    # [b, t, 1, d] --> [b, t, d]
+    inputs = tf.squeeze(inputs, 2)
     inputs = self._Normalize(theta, inputs, paddings)
     inputs = gshard_utils.MeshSplit(inputs, p.device_mesh,
                                     p.activation_split_dims_mapping.blf)
@@ -363,24 +367,37 @@ def zero_state(self, batch_size):
     return py_utils.NestedMap()

   def _NormalizeStep(self, theta, inputs, paddings, state0, state1):
-    if hasattr(self.norm, 'StreamStep'):
-      # TODO(jamesqin): support 3d inputs.
-      # At present it's guaranteed GroupNorm.
-      assert (isinstance(self.norm, bn_layers.GroupNormLayer) and
-              self.norm.params.input_rank == 4)
+    """Applies normalization in a streaming fashion.
+
+    Args:
+      theta: A NestedMap of layer params.
+      inputs: [b, t, d].
+      paddings: [b, t].
+      state0: A NestedMap of tensors of the same struct as returned by
+        zero_state().
+      state1: A NestedMap of tensors of the same struct as state0. On output
+        state1.norm_state might be updated with the new state.
+
+    Returns:
+      An (inputs, paddings) tuple; inputs is a Tensor of shape [b, t, d].
383+ """
384+ if isinstance (self .norm , bn_layers .GroupNormLayer ):
385+ gn_input_rank = self .norm .params .input_rank
386+ if gn_input_rank == 4 :
387+ tf .logging .info (
388+ 'Using GroupNormLayer with input_rank=4, causing extra reshapes. '
389+ 'Set norm.params.input_rank=3.' )
390+ inputs = tf .expand_dims (inputs , 2 )
371391 inputs , paddings , norm_state1 = self .norm .StreamStep (
372392 theta .norm , inputs , paddings , state0 .norm_state )
373- # [b, t, d]
374- inputs = tf .squeeze (inputs , 2 )
393+ if gn_input_rank == 4 :
394+ inputs = tf .squeeze (inputs , 2 )
375395 state1 .norm_state = norm_state1
396+ elif isinstance (self .norm , layers .LayerNorm ):
397+ inputs = self .norm .FProp (theta .norm , inputs )
376398 else :
377- # [b, t, 1, d] -> [b, t, d]
378- inputs = tf .squeeze (inputs , 2 )
379- if isinstance (self .norm , layers .LayerNorm ):
380- inputs = self .norm .FProp (theta .norm , inputs )
381- else :
382- raise NotImplementedError (
383- 'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.' )
399+ raise NotImplementedError (
400+ 'Only bn_layers.GroupNormLayer, layers.LayerNorm are supported.' )
384401 # [b, t, d]
385402 return inputs , paddings
386403
@@ -395,7 +412,7 @@ def StreamStep(self, theta, inputs, paddings, state0):
         zero_state().

     Returns:
-      outputs: A NestedMap of tensors consisting:
+      output: the same shape as inputs, with normalized values.
       padding: the same as input paddings.
       state1: A NestedMap of tensors of the same struct as state0.
     """
@@ -424,6 +441,8 @@ def StreamStep(self, theta, inputs, paddings, state0):
       inputs, paddings, conv_state1 = self.depthwise_conv1d.StreamStep(
           theta.depthwise_conv1d, inputs, paddings, state0.conv_state)
       state1.conv_state = conv_state1
+      # [b, t, 1, d] -> [b, t, d]
+      inputs = tf.squeeze(inputs, 2)
       # [b, t, d]
       inputs, paddings = self._NormalizeStep(theta, inputs, paddings, state0,
                                              state1)
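
Aside from the docstring fixes, the substantive change above is that the GroupNorm branches in _Normalize and _NormalizeStep now pay for the [b, t, d] <-> [b, t, 1, d] reshapes only when the norm layer is still configured with input_rank=4; with input_rank=3 the rank-3 activations flow straight through. Below is a minimal standalone sketch of that dispatch pattern in plain TensorFlow; the normalize_maybe_rank4 helper and the Keras LayerNormalization stand-in are illustrative assumptions, not lingvo code.

import tensorflow as tf

def normalize_maybe_rank4(norm_fn, inputs, norm_input_rank=3):
  """Applies norm_fn to [b, t, d] inputs.

  If norm_fn only accepts rank-4 [b, t, 1, d] tensors (norm_input_rank == 4),
  wrap the call with expand_dims/squeeze; otherwise skip both reshapes.
  """
  if norm_input_rank == 4:
    inputs = tf.expand_dims(inputs, 2)   # [b, t, d] -> [b, t, 1, d]
  outputs = norm_fn(inputs)
  if norm_input_rank == 4:
    outputs = tf.squeeze(outputs, 2)     # [b, t, 1, d] -> [b, t, d]
  return outputs

# Usage with a stand-in normalization op:
x = tf.random.normal([2, 5, 8])                            # [b, t, d]
layer_norm = tf.keras.layers.LayerNormalization(axis=-1)
y = normalize_maybe_rank4(layer_norm, x, norm_input_rank=3)
assert y.shape == x.shape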