
Commit a6d3214

No public description
PiperOrigin-RevId: 605086430
1 parent 8e9f80a commit a6d3214

File tree

official/nlp/modeling/layers/transformer_encoder_block.py
official/nlp/modeling/layers/transformer_encoder_block_test.py

2 files changed: +99, -16 lines changed


official/nlp/modeling/layers/transformer_encoder_block.py

Lines changed: 77 additions & 16 deletions
@@ -13,14 +13,58 @@
 # limitations under the License.
 
 """Keras-based TransformerEncoder block layer."""
-from typing import Any, Optional
+from typing import Any, Optional, Sequence
 
 from absl import logging
 import tensorflow as tf, tf_keras
 
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import util
 
 
+class RMSNorm(tf_keras.layers.Layer):
+  """Root mean square layer normalization layer."""
+
+  def __init__(
+      self,
+      axis: int | Sequence[int] = -1,
+      epsilon: float = 1e-6,
+      **kwargs,
+  ):
+    """Initializes RMSNorm.
+
+    Args:
+      axis: The axis that the input is normalized over.
+      epsilon: A small value added to the mean square for numerical stability.
+      **kwargs: Keyword arguments passed to the base layer.
+    """
+    super().__init__(**kwargs)
+    self.axis = [axis] if isinstance(axis, int) else axis
+    self.epsilon = epsilon
+
+  def build(self, input_shape: tf.TensorShape | Sequence[int | None]):
+    input_shape = tf.TensorShape(input_shape)
+    scale_shape = [1] * input_shape.rank
+    for dim in self.axis:
+      scale_shape[dim] = input_shape[dim]
+    with tf.name_scope(self.name):
+      self.scale = self.add_weight(
+          name="scale",
+          shape=scale_shape,
+          initializer="ones",
+          experimental_autocast=False,
+      )
+    super().build(input_shape)
+
+  def call(self, inputs: tf.Tensor) -> tf.Tensor:
+    input_dtype = inputs.dtype
+    inputs = tf.cast(inputs, tf.float32)
+    var = tf.math.reduce_mean(
+        tf.math.square(inputs), axis=self.axis, keepdims=True
+    )
+    outputs = inputs * tf.math.rsqrt(var + self.epsilon) * self.scale
+    return tf.cast(outputs, input_dtype)
+
+
 @tf_keras.utils.register_keras_serializable(package="Text")
 class TransformerEncoderBlock(tf_keras.layers.Layer):
   """TransformerEncoderBlock layer.
@@ -51,6 +95,7 @@ def __init__(self,
                use_bias=True,
                norm_first=False,
                norm_epsilon=1e-12,
+               use_rms_norm=False,
                output_dropout=0.0,
                attention_dropout=0.0,
                inner_dropout=0.0,
@@ -76,7 +121,7 @@ def __init__(self,
     E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
     Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
     module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
-    overriden by `output_last_dim`.
+    overridden by `output_last_dim`.
     Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
     the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
     Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
@@ -103,6 +148,7 @@ def __init__(self,
         dense layers. If set False, output of attention and intermediate dense
         layers is normalized.
       norm_epsilon: Epsilon value to initialize normalization layers.
+      use_rms_norm: Whether to use RMSNorm instead of LayerNorm.
       output_dropout: Dropout probability for the post-attention and output
        dropout.
       attention_dropout: Dropout probability for within the attention layer.
@@ -154,6 +200,7 @@ def __init__(self,
     self._use_bias = use_bias
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
+    self._use_rms_norm = use_rms_norm
     self._inner_dropout = inner_dropout
     self._use_query_residual = use_query_residual
     self._key_dim = key_dim
@@ -213,25 +260,39 @@ def build(self, input_shape):
         attention_axes=self._attention_axes,
         output_shape=self._output_last_dim,
         name="self_attention",
-        **common_kwargs)
+        **common_kwargs
+    )
     self._attention_dropout = tf_keras.layers.Dropout(
-        rate=self._attention_dropout_rate)
+        rate=self._attention_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     # It is probably safe in mixed_float16, but we haven't validated this yet.
-    self._attention_layer_norm = (
-        tf_keras.layers.LayerNormalization(
-            name="self_attention_layer_norm",
-            axis=-1,
-            epsilon=self._norm_epsilon,
-            dtype=tf.float32))
+    if self._use_rms_norm:
+      self._attention_layer_norm = RMSNorm(
+          epsilon=self._norm_epsilon,
+          name="self_attention_layer_norm",
+      )
+    else:
+      self._attention_layer_norm = tf_keras.layers.LayerNormalization(
+          name="self_attention_layer_norm",
+          axis=-1,
+          epsilon=self._norm_epsilon,
+          dtype=tf.float32,
+      )
     self._attention_layer_norm_kv = self._attention_layer_norm
     if self._diff_q_kv_att_layer_norm:
-      self._attention_layer_norm_kv = (
-          tf_keras.layers.LayerNormalization(
-              name="self_attention_layer_norm_kv",
-              axis=-1,
-              epsilon=self._norm_epsilon,
-              dtype=tf.float32))
+      if self._use_rms_norm:
+        self._attention_layer_norm_kv = RMSNorm(
+            epsilon=self._norm_epsilon,
+            name="self_attention_layer_norm_kv",
+        )
+      else:
+        self._attention_layer_norm_kv = tf_keras.layers.LayerNormalization(
+            name="self_attention_layer_norm_kv",
+            axis=-1,
+            epsilon=self._norm_epsilon,
+            dtype=tf.float32,
+        )
 
     self._intermediate_dense = tf_keras.layers.EinsumDense(
         einsum_equation,
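
Note on the added layer: RMSNorm rescales each input vector by the reciprocal root mean square of its elements and a learned per-feature scale, with no mean-centering and no offset term, unlike LayerNormalization. A minimal sketch of the same computation with plain TensorFlow ops (illustrative only; `x` and `eps` are placeholder names, and the learned scale is taken to be all ones, as at initialization):

import tensorflow as tf

# Toy input: a batch of 2 vectors with 4 features each.
x = tf.constant([[1.0, 2.0, 3.0, 4.0],
                 [0.5, -0.5, 1.5, -1.5]])
eps = 1e-6  # matches the layer's default epsilon

# Same math as RMSNorm.call with scale == ones:
#   y = x * rsqrt(mean(x**2, axis=-1) + eps)
mean_square = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
y = x * tf.math.rsqrt(mean_square + eps)

# Each row of y now has root mean square of approximately 1.
print(tf.sqrt(tf.reduce_mean(tf.square(y), axis=-1)))  # ~[1., 1.]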

official/nlp/modeling/layers/transformer_encoder_block_test.py

Lines changed: 22 additions & 0 deletions
@@ -429,6 +429,28 @@ def test_use_bias_norm_first(self):
     output = encoder_block(inputs)
     self.assertEqual(output.shape, (2, 4, hidden_size))
 
+  def test_use_rms_norm(self):
+    num_attention_heads = 2
+    hidden_size = 16
+    encoder_block = TransformerEncoderBlock(
+        num_attention_heads=num_attention_heads,
+        inner_dim=32,
+        inner_activation='relu',
+        output_dropout=0.1,
+        attention_dropout=0.1,
+        use_bias=False,
+        use_rms_norm=True,
+        norm_epsilon=1e-6,
+        inner_dropout=0.1,
+        attention_initializer=tf_keras.initializers.RandomUniform(
+            minval=0., maxval=1.))
+    # Forward path.
+    dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
+    dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
+    inputs = [dummy_tensor, dummy_mask]
+    output = encoder_block(inputs)
+    self.assertEqual(output.shape, (2, 4, hidden_size))
+
   def test_norm_first_false_and_diff_q_kv_att_layer_norm_true_raises(self):
     some_num_attention_heads = 2
     some_inner_dim = 32
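
Usage note: the new test above also doubles as the recipe for callers, since passing `use_rms_norm=True` (with a suitable `norm_epsilon`) makes the block build its attention layer norms as RMSNorm instead of float32 LayerNormalization. A minimal sketch, assuming this commit's copy of `official.nlp.modeling.layers` is importable; the argument values are illustrative and mirror the test:

import tensorflow as tf
from official.nlp.modeling import layers

# Build an encoder block that uses RMSNorm for the attention layer norms.
block = layers.TransformerEncoderBlock(
    num_attention_heads=2,
    inner_dim=32,
    inner_activation='relu',
    use_rms_norm=True,   # flag introduced by this commit
    norm_epsilon=1e-6,
)

# A single tensor input (no attention mask) keeps the input shape.
out = block(tf.zeros([2, 4, 16]))
print(out.shape)  # (2, 4, 16)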

0 commit comments
