 # limitations under the License.

 """Keras-based TransformerEncoder block layer."""
-from typing import Any, Optional
+from typing import Any, Optional, Sequence
 from absl import logging
 import tensorflow as tf, tf_keras

 from official.modeling import tf_utils
 from official.nlp.modeling.layers import util


+class RMSNorm(tf_keras.layers.Layer):
+  """Root mean square layer normalization layer."""
+
+  def __init__(
+      self,
+      axis: int | Sequence[int] = -1,
+      epsilon: float = 1e-6,
+      **kwargs,
+  ):
+    """Initializes RMSNorm.
+
+    Args:
+      axis: The axis that the input is normalized over.
+      epsilon: A small value added to the mean square for numerical stability.
+      **kwargs: Keyword arguments passed to the base layer.
+    """
+    super().__init__(**kwargs)
+    self.axis = [axis] if isinstance(axis, int) else axis
+    self.epsilon = epsilon
+
+  def build(self, input_shape: tf.TensorShape | Sequence[int | None]):
+    input_shape = tf.TensorShape(input_shape)
+    scale_shape = [1] * input_shape.rank
+    for dim in self.axis:
+      scale_shape[dim] = input_shape[dim]
+    with tf.name_scope(self.name):
+      self.scale = self.add_weight(
+          name="scale",
+          shape=scale_shape,
+          initializer="ones",
+          experimental_autocast=False,
+      )
+    super().build(input_shape)
+
+  def call(self, inputs: tf.Tensor) -> tf.Tensor:
+    input_dtype = inputs.dtype
+    inputs = tf.cast(inputs, tf.float32)
+    var = tf.math.reduce_mean(
+        tf.math.square(inputs), axis=self.axis, keepdims=True
+    )
+    outputs = inputs * tf.math.rsqrt(var + self.epsilon) * self.scale
+    return tf.cast(outputs, input_dtype)
+
+
 @tf_keras.utils.register_keras_serializable(package="Text")
 class TransformerEncoderBlock(tf_keras.layers.Layer):
   """TransformerEncoderBlock layer.
@@ -51,6 +95,7 @@ def __init__(self,
                use_bias=True,
                norm_first=False,
                norm_epsilon=1e-12,
+               use_rms_norm=False,
                output_dropout=0.0,
                attention_dropout=0.0,
                inner_dropout=0.0,
@@ -76,7 +121,7 @@ def __init__(self,
     E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
     Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
     module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
-    overriden by `output_last_dim`.
+    overridden by `output_last_dim`.
     Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
     the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
     Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
@@ -103,6 +148,7 @@ def __init__(self,
         dense layers. If set False, output of attention and intermediate dense
         layers is normalized.
       norm_epsilon: Epsilon value to initialize normalization layers.
+      use_rms_norm: Whether to use RMSNorm instead of LayerNorm.
       output_dropout: Dropout probability for the post-attention and output
         dropout.
       attention_dropout: Dropout probability for within the attention layer.
@@ -154,6 +200,7 @@ def __init__(self,
     self._use_bias = use_bias
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
+    self._use_rms_norm = use_rms_norm
     self._inner_dropout = inner_dropout
     self._use_query_residual = use_query_residual
     self._key_dim = key_dim
@@ -213,25 +260,39 @@ def build(self, input_shape):
         attention_axes=self._attention_axes,
         output_shape=self._output_last_dim,
         name="self_attention",
-        **common_kwargs)
+        **common_kwargs
+    )
     self._attention_dropout = tf_keras.layers.Dropout(
-        rate=self._attention_dropout_rate)
+        rate=self._attention_dropout_rate
+    )
     # Use float32 in layernorm for numeric stability.
     # It is probably safe in mixed_float16, but we haven't validated this yet.
-    self._attention_layer_norm = (
-        tf_keras.layers.LayerNormalization(
-            name="self_attention_layer_norm",
-            axis=-1,
-            epsilon=self._norm_epsilon,
-            dtype=tf.float32))
+    if self._use_rms_norm:
+      self._attention_layer_norm = RMSNorm(
+          epsilon=self._norm_epsilon,
+          name="self_attention_layer_norm",
+      )
+    else:
+      self._attention_layer_norm = tf_keras.layers.LayerNormalization(
+          name="self_attention_layer_norm",
+          axis=-1,
+          epsilon=self._norm_epsilon,
+          dtype=tf.float32,
+      )
     self._attention_layer_norm_kv = self._attention_layer_norm
     if self._diff_q_kv_att_layer_norm:
-      self._attention_layer_norm_kv = (
-          tf_keras.layers.LayerNormalization(
-              name="self_attention_layer_norm_kv",
-              axis=-1,
-              epsilon=self._norm_epsilon,
-              dtype=tf.float32))
+      if self._use_rms_norm:
+        self._attention_layer_norm_kv = RMSNorm(
+            epsilon=self._norm_epsilon,
+            name="self_attention_layer_norm_kv",
+        )
+      else:
+        self._attention_layer_norm_kv = tf_keras.layers.LayerNormalization(
+            name="self_attention_layer_norm_kv",
+            axis=-1,
+            epsilon=self._norm_epsilon,
+            dtype=tf.float32,
+        )

     self._intermediate_dense = tf_keras.layers.EinsumDense(
         einsum_equation,
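End to end, the new flag is consumed like any other constructor argument of `TransformerEncoderBlock`. Below is a minimal construction sketch with placeholder hyperparameters, assuming the layer is exported via `official.nlp.modeling.layers` as in the Model Garden:

```python
import tensorflow as tf
from official.nlp.modeling import layers

# Illustrative values only; the only argument this change adds is use_rms_norm.
block = layers.TransformerEncoderBlock(
    num_attention_heads=8,
    inner_dim=2048,
    inner_activation="relu",
    norm_first=True,
    norm_epsilon=1e-6,
    use_rms_norm=True,  # new flag: RMSNorm replaces LayerNormalization
)

x = tf.random.normal([2, 16, 512])
y = block(x)  # output shape (2, 16, 512)
```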