
Commit a7e6097

wanxinxw authored and tensorflower-gardener committed
Added block diagonal feedforward layer.
This layer replaces the weight matrix of the output_dense layer with a block diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can be added optionally to improve layer expressibility.

PiperOrigin-RevId: 418828099
1 parent 60f6d6c commit a7e6097
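
For a rough sense of the savings (a back-of-the-envelope sketch; the BERT-base-like sizes below are illustrative and not taken from this commit), the output projection's parameter count drops by a factor of num_blocks, while the optional mixing layer adds only a num_blocks x num_blocks kernel:

    # Illustrative sizes, assumed for this sketch (BERT-base-like).
    hidden_size, intermediate_size, num_blocks = 768, 3072, 4

    # Full dense output projection: one (intermediate_size, hidden_size) kernel.
    dense_params = intermediate_size * hidden_size  # 2,359,296

    # Block diagonal: num_blocks kernels of shape
    # (intermediate_size // num_blocks, hidden_size // num_blocks).
    block_params = num_blocks * (intermediate_size // num_blocks) * (
        hidden_size // num_blocks)  # 589,824, i.e. 4x fewer

    # The optional mixing layer only adds a (num_blocks, num_blocks) kernel.
    mixing_params = num_blocks * num_blocks  # 16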

File tree

3 files changed: +286 -0 lines changed


official/nlp/modeling/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from official.nlp.modeling.layers.attention import *
 from official.nlp.modeling.layers.bigbird_attention import BigBirdAttention
 from official.nlp.modeling.layers.bigbird_attention import BigBirdMasks
+from official.nlp.modeling.layers.block_diag_feedforward import BlockDiagFeedforward
 from official.nlp.modeling.layers.cls_head import *
 from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
 from official.nlp.modeling.layers.gaussian_process import RandomFeatureGaussianProcess
official/nlp/modeling/layers/block_diag_feedforward.py

Lines changed: 166 additions & 0 deletions

@@ -0,0 +1,166 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based block diagonal feedforward layer."""
# pylint: disable=g-classes-have-attributes
from typing import Optional

import tensorflow as tf


class BlockDiagFeedforward(tf.keras.layers.Layer):
  """Block diagonal feedforward layer.

  This layer replaces the weight matrix of the output_dense layer with a block
  diagonal matrix to save layer parameters and FLOPs. A linear mixing layer can
  be added optionally to improve layer expressibility.

  Args:
    intermediate_size: Size of the intermediate layer.
    intermediate_activation: Activation for the intermediate layer.
    dropout: Dropout probability for the output dropout.
    num_blocks: The number of blocks for the block diagonal matrix of the
      output_dense layer.
    apply_mixing: Apply linear mixing if True.
    kernel_initializer: Initializer for dense layer kernels.
    bias_initializer: Initializer for dense layer biases.
    kernel_regularizer: Regularizer for dense layer kernels.
    bias_regularizer: Regularizer for dense layer biases.
    activity_regularizer: Regularizer for dense layer activity.
    kernel_constraint: Constraint for dense layer kernels.
    bias_constraint: Constraint for dense layer biases.
  """

  def __init__(
      self,
      intermediate_size: int,
      intermediate_activation: str,
      dropout: float,
      num_blocks: int = 1,
      apply_mixing: bool = True,
      kernel_initializer: str = "glorot_uniform",
      bias_initializer: str = "zeros",
      kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
      kernel_constraint: Optional[tf.keras.constraints.Constraint] = None,
      bias_constraint: Optional[tf.keras.constraints.Constraint] = None,
      **kwargs):  # pylint: disable=g-doc-args
    super(BlockDiagFeedforward, self).__init__(**kwargs)
    self._intermediate_size = intermediate_size
    self._intermediate_activation = intermediate_activation
    self._dropout = dropout
    self._num_blocks = num_blocks
    self._apply_mixing = apply_mixing

    if intermediate_size % num_blocks != 0:
      raise ValueError("intermediate_size (%d) isn't a multiple of num_blocks "
                       "(%d)." % (intermediate_size, num_blocks))

    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)

  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]

    common_kwargs = dict(
        kernel_initializer=self._kernel_initializer,
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

    # Projects [batch, seq, hidden] to
    # [batch, seq, num_blocks, intermediate_size // num_blocks].
    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        "abc,cde->abde",
        output_shape=(None, self._num_blocks,
                      self._intermediate_size // self._num_blocks),
        bias_axes="de",
        name="intermediate",
        **common_kwargs)

    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      policy = tf.float32
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._intermediate_activation, dtype=policy)

    # Each block (axis d) gets its own kernel, so the equivalent flattened
    # weight matrix is block diagonal.
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        "abde,deo->abdo",
        output_shape=(None, self._num_blocks,
                      hidden_size // self._num_blocks),
        bias_axes="do",
        name="output",
        **common_kwargs)

    if self._apply_mixing:
      # Linearly recombines the block outputs with a small
      # (num_blocks, num_blocks) kernel.
      self._output_mixing = tf.keras.layers.experimental.EinsumDense(
          "abdo,de->abeo",
          output_shape=(None, self._num_blocks,
                        hidden_size // self._num_blocks),
          name="output_mixing",
          **common_kwargs)
    self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))

    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout)

  def get_config(self):
    config = {
        "intermediate_size": self._intermediate_size,
        "intermediate_activation": self._intermediate_activation,
        "dropout": self._dropout,
        "num_blocks": self._num_blocks,
        "apply_mixing": self._apply_mixing,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
    }
    base_config = super(BlockDiagFeedforward, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
    intermediate_output = self._intermediate_dense(inputs)
    intermediate_output = self._intermediate_activation_layer(
        intermediate_output)
    layer_output = self._output_dense(intermediate_output)
    if self._apply_mixing:
      layer_output = self._output_mixing(layer_output)
    layer_output = self._output_reshape(layer_output)
    layer_output = self._output_dropout(layer_output)

    return layer_output
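
To see why the "abde,deo->abdo" einsum in build() realizes a block diagonal weight matrix, here is a small self-contained NumPy check (a sketch with illustrative shapes; not part of the commit). The block axis d appears in both operands and in the output, so each block of the input is multiplied only by its own kernel slice:

    import numpy as np

    batch, seq, blocks, block_in, block_out = 2, 4, 3, 5, 7
    x = np.random.rand(batch, seq, blocks, block_in)
    w = np.random.rand(blocks, block_in, block_out)  # one kernel per block

    y = np.einsum("abde,deo->abdo", x, w)

    # Equivalent per-block matmuls: block d never sees another block's input,
    # which is exactly a block diagonal weight matrix on the flattened features.
    for d in range(blocks):
        np.testing.assert_allclose(y[:, :, d], x[:, :, d] @ w[d], rtol=1e-6)

    # The optional "abdo,de->abeo" mixing einsum then recombines the block
    # outputs with a tiny (blocks, blocks) kernel, restoring cross-block
    # information flow at negligible cost.
    m = np.random.rand(blocks, blocks)
    mixed = np.einsum("abdo,de->abeo", y, m)
    assert mixed.shape == y.shape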
official/nlp/modeling/layers/block_diag_feedforward_test.py

Lines changed: 119 additions & 0 deletions

@@ -0,0 +1,119 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Keras-based block diagonal feedforward layer."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.layers import block_diag_feedforward


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BlockDiagFeedforwardTest(keras_parameterized.TestCase):

  def tearDown(self):
    super(BlockDiagFeedforwardTest, self).tearDown()
    tf.keras.mixed_precision.set_global_policy("float32")

  @parameterized.parameters(
      (1, True, "float32"),
      (1, True, "mixed_float16"),
      (1, False, "float32"),
      (1, False, "mixed_float16"),
      (2, True, "float32"),
      (2, True, "mixed_float16"),
      (2, False, "float32"),
      (2, False, "mixed_float16"),
  )
  def test_layer_creation(self, num_blocks, apply_mixing, dtype):
    tf.keras.mixed_precision.set_global_policy(dtype)
    kwargs = dict(
        intermediate_size=128,
        intermediate_activation="relu",
        dropout=0.1,
        num_blocks=num_blocks,
        apply_mixing=apply_mixing,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)

    sequence_length = 64
    width = 128
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The default output of a transformer layer should be the same as the
    # input.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  @parameterized.parameters(
      (1, True, "float32"),
      (1, True, "mixed_float16"),
      (1, False, "float32"),
      (1, False, "mixed_float16"),
      (2, True, "float32"),
      (2, True, "mixed_float16"),
      (2, False, "float32"),
      (2, False, "mixed_float16"),
  )
  def test_layer_invocation(self, num_blocks, apply_mixing, dtype):
    tf.keras.mixed_precision.set_global_policy(dtype)
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        num_blocks=num_blocks,
        apply_mixing=apply_mixing,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)

    sequence_length = 16
    width = 32
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(data_tensor, output_tensor)

    # Invoke the model on test data.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    output_data = model.predict(input_data)
    self.assertEqual(output_data.shape, (batch_size, sequence_length, width))

  def test_get_config(self):
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        num_blocks=2,
        apply_mixing=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = block_diag_feedforward.BlockDiagFeedforward(**kwargs)
    new_layer = block_diag_feedforward.BlockDiagFeedforward.from_config(
        test_layer.get_config())

    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())


if __name__ == "__main__":
  tf.test.main()
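
A minimal usage sketch of the new layer via the public export added to __init__.py above (this assumes the TF Model Garden package, e.g. tf-models-official, is importable; all sizes are illustrative):

    import tensorflow as tf
    from official.nlp.modeling import layers

    # Both the hidden size (768) and intermediate_size (3072) are chosen to
    # be divisible by num_blocks.
    ffn = layers.BlockDiagFeedforward(
        intermediate_size=3072,
        intermediate_activation="gelu",
        dropout=0.1,
        num_blocks=4,
        apply_mixing=True)

    x = tf.random.uniform((2, 128, 768))  # (batch, seq_len, hidden)
    y = ffn(x)  # output shape matches the input: (2, 128, 768)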
