Skip to content

Commit c5ae411

Browse files
tensorflower-gardenersaberkun
authored andcommitted
Internal change
PiperOrigin-RevId: 398593113
1 parent 6ca5ac9 commit c5ae411

22 files changed

+736
-325
lines changed

official/core/input_reader.py

Lines changed: 235 additions & 180 deletions
Large diffs are not rendered by default.

official/nlp/configs/encoders.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
204204
bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
205205
kernel: KernelEncoderConfig = KernelEncoderConfig()
206206
mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
207+
teams: BertEncoderConfig = BertEncoderConfig()
207208
xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
208209

209210

@@ -436,6 +437,40 @@ def build_encoder(config: EncoderConfig,
436437
initializer=tf.keras.initializers.RandomNormal(
437438
stddev=encoder_cfg.initializer_range))
438439

440+
if encoder_type == "teams":
441+
embedding_cfg = dict(
442+
vocab_size=encoder_cfg.vocab_size,
443+
type_vocab_size=encoder_cfg.type_vocab_size,
444+
hidden_size=encoder_cfg.hidden_size,
445+
embedding_width=encoder_cfg.embedding_size,
446+
max_seq_length=encoder_cfg.max_position_embeddings,
447+
initializer=tf.keras.initializers.TruncatedNormal(
448+
stddev=encoder_cfg.initializer_range),
449+
dropout_rate=encoder_cfg.dropout_rate,
450+
)
451+
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
452+
hidden_cfg = dict(
453+
num_attention_heads=encoder_cfg.num_attention_heads,
454+
intermediate_size=encoder_cfg.intermediate_size,
455+
intermediate_activation=tf_utils.get_activation(
456+
encoder_cfg.hidden_activation),
457+
dropout_rate=encoder_cfg.dropout_rate,
458+
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
459+
kernel_initializer=tf.keras.initializers.TruncatedNormal(
460+
stddev=encoder_cfg.initializer_range),
461+
)
462+
kwargs = dict(
463+
embedding_cfg=embedding_cfg,
464+
embedding_cls=embedding_network,
465+
hidden_cfg=hidden_cfg,
466+
num_hidden_instances=encoder_cfg.num_layers,
467+
pooled_output_dim=encoder_cfg.hidden_size,
468+
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
469+
stddev=encoder_cfg.initializer_range),
470+
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
471+
dict_outputs=True)
472+
return networks.EncoderScaffold(**kwargs)
473+
439474
# Uses the default BERTEncoder configuration schema to create the encoder.
440475
# If it does not match, please add a switch branch by the encoder type.
441476
return networks.BertEncoder(

official/nlp/configs/finetuning_experiments.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def bert_sentence_prediction() -> cfg.ExperimentConfig:
6161
'task.train_data.is_training != None',
6262
'task.validation_data.is_training != None'
6363
])
64-
config.task.model.encoder.type = 'bert'
6564
return config
6665

6766

@@ -98,7 +97,6 @@ def bert_squad() -> cfg.ExperimentConfig:
9897
'task.train_data.is_training != None',
9998
'task.validation_data.is_training != None'
10099
])
101-
config.task.model.encoder.type = 'bert'
102100
return config
103101

104102

official/nlp/modeling/networks/funnel_transformer.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,24 @@
1414

1515
"""Funnel Transformer network."""
1616
# pylint: disable=g-classes-have-attributes
17-
from typing import Union, Collection
17+
from typing import Union, Sequence
1818
from absl import logging
1919
import tensorflow as tf
2020

2121
from official.nlp import keras_nlp
2222

2323

24-
def _pool_and_concat(data, unpool_length: int, stride: int,
25-
axes: Union[Collection[int], int]):
24+
def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
25+
int],
26+
axes: Union[Sequence[int], int]):
2627
"""Pools the data along a given axis with stride.
2728
2829
It also skips first unpool_length elements.
2930
3031
Args:
3132
data: Tensor to be pooled.
3233
unpool_length: Leading elements to be skipped.
33-
stride: Stride for the given axis.
34+
strides: Strides for the given axes.
3435
axes: Axes to pool the Tensor.
3536
3637
Returns:
@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
3940
# Wraps the axes as a list.
4041
if isinstance(axes, int):
4142
axes = [axes]
43+
if isinstance(strides, int):
44+
strides = [strides] * len(axes)
45+
else:
46+
if len(strides) != len(axes):
47+
raise ValueError('The lengths of strides and axes need to match.')
4248

43-
for axis in axes:
49+
for axis, stride in zip(axes, strides):
4450
# Skips first `unpool_length` tokens.
4551
unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
4652
unpool_tensor = data[unpool_tensor_shape]
@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
8086
dropout.
8187
attention_dropout: The dropout rate to use for the attention layers within
8288
the transformer layers.
83-
pool_stride: Pooling stride to compress the sequence length.
89+
pool_stride: An int or a list of ints. Pooling stride(s) to compress the
90+
sequence length. If set to int, each layer will have the same stride size.
91+
If set to list, the number of elements needs to match num_layers.
8492
unpool_length: Leading n tokens to be skipped from pooling.
8593
initializer: The initialzer to use for all weights in this encoder.
8694
output_range: The sequence output range, [0, output_range), by slicing the
@@ -185,12 +193,23 @@ def __init__(
185193
activation='tanh',
186194
kernel_initializer=initializer,
187195
name='pooler_transform')
188-
self._att_input_pool_layer = tf.keras.layers.MaxPooling1D(
189-
pool_size=pool_stride,
190-
strides=pool_stride,
191-
padding='same',
192-
name='att_input_pool_layer')
193-
self._pool_stride = pool_stride
196+
if isinstance(pool_stride, int):
197+
# TODO(b/197133196): Pooling layer can be shared.
198+
pool_strides = [pool_stride] * num_layers
199+
else:
200+
if len(pool_stride) != num_layers:
201+
raise ValueError('Lengths of pool_stride and num_layers are not equal.')
202+
pool_strides = pool_stride
203+
self._att_input_pool_layers = []
204+
for layer_pool_stride in pool_strides:
205+
att_input_pool_layer = tf.keras.layers.MaxPooling1D(
206+
pool_size=layer_pool_stride,
207+
strides=layer_pool_stride,
208+
padding='same',
209+
name='att_input_pool_layer')
210+
self._att_input_pool_layers.append(att_input_pool_layer)
211+
212+
self._pool_strides = pool_strides # This is a list here.
194213
self._unpool_length = unpool_length
195214

196215
self._config = {
@@ -250,23 +269,25 @@ def call(self, inputs):
250269
attention_mask = _pool_and_concat(
251270
attention_mask,
252271
unpool_length=self._unpool_length,
253-
stride=self._pool_stride,
272+
strides=self._pool_strides[0],
254273
axes=[1])
255-
for layer in self._transformer_layers:
274+
for i, layer in enumerate(self._transformer_layers):
256275
# Pools layer for compressing the query length.
257-
pooled_inputs = self._att_input_pool_layer(x[:, self._unpool_length:, :])
276+
pooled_inputs = self._att_input_pool_layers[i](
277+
x[:, self._unpool_length:, :])
258278
query_inputs = tf.concat(
259279
values=(tf.cast(
260280
x[:, :self._unpool_length, :],
261281
dtype=pooled_inputs.dtype), pooled_inputs),
262282
axis=1)
263283
x = layer([query_inputs, x, attention_mask])
264284
# Pools the corresponding attention_mask.
265-
attention_mask = _pool_and_concat(
266-
attention_mask,
267-
unpool_length=self._unpool_length,
268-
stride=self._pool_stride,
269-
axes=[1, 2])
285+
if i < len(self._transformer_layers) - 1:
286+
attention_mask = _pool_and_concat(
287+
attention_mask,
288+
unpool_length=self._unpool_length,
289+
strides=[self._pool_strides[i+1], self._pool_strides[i]],
290+
axes=[1, 2])
270291
encoder_outputs.append(x)
271292

272293
last_encoder_output = encoder_outputs[-1]

official/nlp/modeling/networks/funnel_transformer_test.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,24 @@ def test_network_creation(self, policy, pooled_dtype):
8080
self.assertAllEqual(tf.float32, data.dtype)
8181
self.assertAllEqual(pooled_dtype, pooled.dtype)
8282

83+
def test_invalid_stride_and_num_layers(self):
84+
hidden_size = 32
85+
num_layers = 3
86+
pool_stride = [2, 2]
87+
unpool_length = 1
88+
with self.assertRaisesRegex(ValueError,
89+
"pool_stride and num_layers are not equal"):
90+
_ = funnel_transformer.FunnelTransformerEncoder(
91+
vocab_size=100,
92+
hidden_size=hidden_size,
93+
num_attention_heads=2,
94+
num_layers=num_layers,
95+
pool_stride=pool_stride,
96+
unpool_length=unpool_length)
97+
8398
@parameterized.named_parameters(
8499
("no_stride_no_unpool", 1, 0),
100+
("stride_list_with_unpool", [2, 3, 4], 1),
85101
("large_stride_with_unpool", 3, 1),
86102
("large_stride_with_large_unpool", 5, 10),
87103
("no_stride_with_unpool", 1, 1),
@@ -110,11 +126,12 @@ def test_all_encoder_outputs_network_creation(self, pool_stride,
110126
expected_data_shape = [None, sequence_length, hidden_size]
111127
expected_pooled_shape = [None, hidden_size]
112128
self.assertLen(all_encoder_outputs, num_layers)
113-
for data in all_encoder_outputs:
114-
expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
115-
pool_stride - 1 -
116-
unpool_length) // pool_stride
117-
print("shapes:", expected_data_shape, data.shape.as_list())
129+
if isinstance(pool_stride, int):
130+
pool_stride = [pool_stride] * num_layers
131+
for layer_pool_stride, data in zip(pool_stride, all_encoder_outputs):
132+
expected_data_shape[1] = unpool_length + (
133+
expected_data_shape[1] + layer_pool_stride - 1 -
134+
unpool_length) // layer_pool_stride
118135
self.assertAllEqual(expected_data_shape, data.shape.as_list())
119136
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
120137

official/nlp/modeling/networks/packed_sequence_embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def __init__(self,
6262
pack_multiple_sequences=False,
6363
**kwargs):
6464
initializer = tf.keras.initializers.get(initializer)
65+
if embedding_width is None:
66+
embedding_width = hidden_size
6567
config_dict = {
6668
'vocab_size': vocab_size,
6769
'type_vocab_size': type_vocab_size,

official/nlp/projects/teams/experiments/teams_en_uncased_base.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
task:
22
model:
33
encoder:
4-
bert:
4+
teams:
55
attention_dropout_rate: 0.1
66
dropout_rate: 0.1
77
embedding_size: 768
@@ -14,3 +14,4 @@ task:
1414
num_layers: 12
1515
type_vocab_size: 2
1616
vocab_size: 30522
17+
type: teams

official/nlp/projects/teams/experiments/teams_en_uncased_small.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
task:
22
model:
33
encoder:
4-
bert:
4+
teams:
55
attention_dropout_rate: 0.1
66
dropout_rate: 0.1
77
embedding_size: 128
@@ -14,3 +14,4 @@ task:
1414
num_layers: 12
1515
type_vocab_size: 2
1616
vocab_size: 30522
17+
type: teams

official/nlp/projects/teams/teams.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,6 @@ def get_encoder(bert_config,
6464
Returns:
6565
A encoder object.
6666
"""
67-
# embedding_size is required for PackedSequenceEmbedding.
68-
if bert_config.embedding_size is None:
69-
bert_config.embedding_size = bert_config.hidden_size
7067
embedding_cfg = dict(
7168
vocab_size=bert_config.vocab_size,
7269
type_vocab_size=bert_config.type_vocab_size,

official/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pyyaml>=5.1
2121
opencv-python-headless
2222
Pillow
2323
pycocotools
24+
waymo-open-dataset-tf-2-6-0
2425
# NLP related dependencies
2526
seqeval
2627
sentencepiece

0 commit comments

Comments
 (0)