
Commit 854febe

No public description
PiperOrigin-RevId: 591264030
1 parent 817674b commit 854febe

File tree

3 files changed: +49 -0 lines changed

official/nlp/configs/encoders.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ class BertEncoderConfig(hyperparams.Config):
   output_range: Optional[int] = None
   return_all_encoder_outputs: bool = False
   return_attention_scores: bool = False
+  return_word_embeddings: bool = False
   # Pre/Post-LN Transformer
   norm_first: bool = False

@@ -769,5 +770,6 @@ def build_encoder(config: EncoderConfig,
       embedding_layer=embedding_layer,
       return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
       return_attention_scores=encoder_cfg.return_attention_scores,
+      return_word_embeddings=encoder_cfg.return_word_embeddings,
       dict_outputs=True,
       norm_first=encoder_cfg.norm_first)
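
For context, a minimal sketch of how the new config flag could flow end to end, assuming the usual EncoderConfig wrapper and build_encoder() in official/nlp/configs/encoders; field values here are illustrative, not part of this commit:

    from official.nlp.configs import encoders

    # Hedged sketch: enable the flag on the BERT config (values illustrative).
    encoder_config = encoders.EncoderConfig(
        type='bert',
        bert=encoders.BertEncoderConfig(
            num_layers=4,
            return_word_embeddings=True))

    # build_encoder forwards encoder_cfg.return_word_embeddings to the encoder,
    # so the built BertEncoderV2 adds a 'word_embeddings' key to its output dict.
    encoder = encoders.build_encoder(encoder_config)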

official/nlp/modeling/networks/bert_encoder.py

Lines changed: 15 additions & 0 deletions
@@ -79,6 +79,8 @@ class BertEncoderV2(tf_keras.layers.Layer):
       attention scores of all transformer layers. This will be a list of length
       `num_layers`, and each element will be in the shape [batch_size,
       num_attention_heads, seq_dim, seq_dim].
+    return_word_embeddings: If true, also return the input word embedding
+      sequence in the bert inference output.
   """

   def __init__(
@@ -101,6 +103,7 @@ def __init__(
       norm_first: bool = False,
       with_dense_inputs: bool = False,
       return_attention_scores: bool = False,
+      return_word_embeddings: bool = False,
       **kwargs):
     # Pops kwargs that are used in V1 implementation.
     if 'dict_outputs' in kwargs:
@@ -208,6 +211,7 @@ def __init__(
         'norm_first': norm_first,
         'with_dense_inputs': with_dense_inputs,
         'return_attention_scores': return_attention_scores,
+        'return_word_embeddings': return_word_embeddings,
     }
     if with_dense_inputs:
       self.inputs = dict(
@@ -278,6 +282,10 @@ def call(self, inputs):
         encoder_outputs=encoder_outputs)
     if self._config['return_attention_scores']:
       output['attention_scores'] = attention_outputs
+
+    if self._config['return_word_embeddings']:
+      output['word_embeddings'] = embeddings
+
     return output

   def get_embedding_table(self):
@@ -390,6 +398,8 @@ class BertEncoder(tf_keras.Model):
       attention scores of all transformer layers. This will be a list of length
       `num_layers`, and each element will be in the shape [batch_size,
       num_attention_heads, seq_dim, seq_dim].
+    return_word_embeddings: If true, also return the input word embedding
+      sequence in the bert inference output.
   """

   def __init__(
@@ -412,6 +422,7 @@ def __init__(
       dict_outputs=False,
       return_all_encoder_outputs=False,
       return_attention_scores: bool = False,
+      return_word_embeddings: bool = False,
       **kwargs):
     if 'sequence_length' in kwargs:
       kwargs.pop('sequence_length')
@@ -538,6 +549,9 @@ def __init__(
     if return_attention_scores:
       outputs['attention_scores'] = attention_outputs

+    if return_word_embeddings:
+      outputs['word_embeddings'] = embeddings
+
     if dict_outputs:
       super().__init__(
           inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
@@ -587,6 +601,7 @@ def __init__(
         'norm_first': norm_first,
         'dict_outputs': dict_outputs,
         'return_attention_scores': return_attention_scores,
+        'return_word_embeddings': return_word_embeddings,
     }
     # pylint: disable=protected-access
     self._setattr_tracking = False
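
A hedged usage sketch of the new output key on BertEncoderV2, called directly on dummy inputs; the model sizes and batch shape below are illustrative, not part of this commit:

    import tensorflow as tf
    from official.nlp.modeling.networks import bert_encoder

    # Hedged sketch: a deliberately tiny encoder (all sizes illustrative).
    network = bert_encoder.BertEncoderV2(
        vocab_size=100,
        hidden_size=32,
        num_layers=2,
        num_attention_heads=2,
        return_word_embeddings=True)

    batch_size, seq_len = 4, 21
    outputs = network(dict(
        input_word_ids=tf.ones((batch_size, seq_len), dtype=tf.int32),
        input_mask=tf.ones((batch_size, seq_len), dtype=tf.int32),
        input_type_ids=tf.zeros((batch_size, seq_len), dtype=tf.int32)))

    # With the flag enabled, the output dict also carries the input embedding
    # sequence, shaped [batch_size, seq_len, hidden_size].
    print(outputs['word_embeddings'].shape)  # (4, 21, 32)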

official/nlp/modeling/networks/bert_encoder_test.py

Lines changed: 32 additions & 0 deletions
@@ -138,6 +138,38 @@ def test_dict_outputs_network_creation_return_attention_scores(
     # The default output dtype is float32.
     self.assertAllEqual(tf.float32, all_attention_outputs[-1].dtype)

+  @parameterized.named_parameters(
+      ("encoder_v2", bert_encoder.BertEncoderV2),
+      ("encoder_v1", bert_encoder.BertEncoder),
+  )
+  def test_dict_outputs_network_creation_return_word_embeddings(
+      self, encoder_cls):
+    hidden_size = 32
+    sequence_length = 21
+    num_attention_heads = 5
+    num_layers = 3
+    # Create a small BertEncoder for testing.
+    test_network = encoder_cls(
+        vocab_size=100,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_layers=num_layers,
+        return_word_embeddings=True,
+        dict_outputs=True)
+    # Create the inputs (note that the first dimension is implicit).
+    word_ids = tf_keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    mask = tf_keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    type_ids = tf_keras.Input(shape=(sequence_length,), dtype=tf.int32)
+    dict_outputs = test_network(
+        dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids))
+    word_embeddings = dict_outputs["word_embeddings"]
+
+    expected_data_shape = [None, sequence_length, hidden_size]
+    self.assertAllEqual(expected_data_shape, word_embeddings.shape)
+
+    # The default output dtype is float32.
+    self.assertAllEqual(tf.float32, word_embeddings[-1].dtype)
+
   @parameterized.named_parameters(
       ("encoder_v2", bert_encoder.BertEncoderV2),
       ("encoder_v1", bert_encoder.BertEncoder),
