@@ -79,6 +79,8 @@ class BertEncoderV2(tf_keras.layers.Layer):
       attention scores of all transformer layers. This will be a list of length
       `num_layers`, and each element will be in the shape [batch_size,
       num_attention_heads, seq_dim, seq_dim].
+    return_word_embeddings: If `True`, also return the input word embedding
+      sequence in the BERT inference output.
   """

   def __init__(
@@ -101,6 +103,7 @@ def __init__(
       norm_first: bool = False,
       with_dense_inputs: bool = False,
       return_attention_scores: bool = False,
+      return_word_embeddings: bool = False,
       **kwargs):
     # Pops kwargs that are used in V1 implementation.
     if 'dict_outputs' in kwargs:
@@ -208,6 +211,7 @@ def __init__(
         'norm_first': norm_first,
         'with_dense_inputs': with_dense_inputs,
         'return_attention_scores': return_attention_scores,
+        'return_word_embeddings': return_word_embeddings,
     }
     if with_dense_inputs:
       self.inputs = dict(
@@ -278,6 +282,10 @@ def call(self, inputs):
         encoder_outputs=encoder_outputs)
     if self._config['return_attention_scores']:
       output['attention_scores'] = attention_outputs
+
+    if self._config['return_word_embeddings']:
+      output['word_embeddings'] = embeddings
+
     return output

   def get_embedding_table(self):
@@ -390,6 +398,8 @@ class BertEncoder(tf_keras.Model):
       attention scores of all transformer layers. This will be a list of length
       `num_layers`, and each element will be in the shape [batch_size,
       num_attention_heads, seq_dim, seq_dim].
+    return_word_embeddings: If `True`, also return the input word embedding
+      sequence in the BERT inference output.
   """

   def __init__(
@@ -412,6 +422,7 @@ def __init__(
       dict_outputs=False,
       return_all_encoder_outputs=False,
       return_attention_scores: bool = False,
+      return_word_embeddings: bool = False,
       **kwargs):
     if 'sequence_length' in kwargs:
       kwargs.pop('sequence_length')
@@ -538,6 +549,9 @@ def __init__(
     if return_attention_scores:
       outputs['attention_scores'] = attention_outputs

+    if return_word_embeddings:
+      outputs['word_embeddings'] = embeddings
+
     if dict_outputs:
       super().__init__(
           inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)
@@ -587,6 +601,7 @@ def __init__(
         'norm_first': norm_first,
         'dict_outputs': dict_outputs,
         'return_attention_scores': return_attention_scores,
+        'return_word_embeddings': return_word_embeddings,
     }
     # pylint: disable=protected-access
     self._setattr_tracking = False
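A minimal usage sketch, not part of the diff above: assuming this change is applied, that the encoder is exported under `official.nlp.modeling.networks`, and using small illustrative hyperparameters, the new flag would surface the embedding tensor in the output dictionary like this.

```python
import numpy as np
from official.nlp.modeling import networks

# Hypothetical small configuration; any valid BertEncoderV2 config would do.
encoder = networks.BertEncoderV2(
    vocab_size=30522,
    num_layers=2,
    hidden_size=128,
    num_attention_heads=2,
    return_word_embeddings=True)

batch_size, seq_len = 2, 16
inputs = dict(
    input_word_ids=np.ones((batch_size, seq_len), dtype=np.int32),
    input_mask=np.ones((batch_size, seq_len), dtype=np.int32),
    input_type_ids=np.zeros((batch_size, seq_len), dtype=np.int32))

outputs = encoder(inputs)
# In addition to 'sequence_output', 'pooled_output', etc., the output dict
# should now carry the embedding-layer output recorded by this change under
# 'word_embeddings'; expected shape: (2, 16, 128).
print(outputs['word_embeddings'].shape)
```

The V1 `BertEncoder` path in the second half of the diff exposes the same key, so the same lookup should work for models built with `dict_outputs=True`.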