Skip to content

Commit 326930f

Browse files
Jiayu Ye and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 531529187
1 parent 471b5c1 commit 326930f

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

official/nlp/modeling/networks/funnel_transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def __init__(
452452

453453
def call(self, inputs, output_range: Optional[tf.Tensor] = None):
454454
# inputs are [word_ids, mask, type_ids]
455+
word_embeddings = None
455456
if isinstance(inputs, (list, tuple)):
456457
logging.warning('List inputs to %s are discouraged.', self.__class__)
457458
if len(inputs) == 3:
@@ -472,14 +473,16 @@ def call(self, inputs, output_range: Optional[tf.Tensor] = None):
472473
word_ids = inputs.get('input_word_ids')
473474
mask = inputs.get('input_mask')
474475
type_ids = inputs.get('input_type_ids')
476+
word_embeddings = inputs.get('input_word_embeddings', None)
475477

476478
dense_inputs = inputs.get('dense_inputs', None)
477479
dense_mask = inputs.get('dense_mask', None)
478480
dense_type_ids = inputs.get('dense_type_ids', None)
479481
else:
480482
raise ValueError('Unexpected inputs type to %s.' % self.__class__)
481483

482-
word_embeddings = self._embedding_layer(word_ids)
484+
if word_embeddings is None:
485+
word_embeddings = self._embedding_layer(word_ids)
483486

484487
if dense_inputs is not None:
485488
# Concat the dense embeddings at sequence begin so unpool_len can control

official/nlp/modeling/networks/funnel_transformer_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,43 @@ def test_network_invocation(self, output_range, out_seq_len, unpool_length):
320320
self.assertEqual(outputs[0].shape[-1], hidden_size)
321321
self.assertTrue(hasattr(test_network, "_embedding_projection"))
322322

323+
def test_embeddings_as_inputs(self):
  """Verifies the encoder accepts precomputed `input_word_embeddings`.

  Per the paired `call` change in this commit, the encoder only runs its
  own embedding layer when no embeddings are supplied, so invoking the
  network with `input_word_embeddings` (and no `input_word_ids`) must
  still produce outputs of the expected shape and dtype.
  """
  hidden_size = 32
  sequence_length = 21
  # Create a small FunnelTransformerEncoder to keep the test fast.
  test_network = funnel_transformer.FunnelTransformerEncoder(
      vocab_size=100,
      hidden_size=hidden_size,
      num_attention_heads=2,
      num_layers=3,
      pool_stride=2,
  )
  # Create the inputs (note that the first dimension is implicit).
  # Fix: the original wrote `shape=(sequence_length)`, which is a
  # parenthesized int, not a 1-tuple; use `(sequence_length,)` for
  # correctness and consistency with the mask/type_ids lines below.
  word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
  # Build with word ids so the embedding layer's variables exist before
  # we call it directly below.
  test_network.build(
      dict(input_word_ids=word_ids, input_mask=mask, input_type_ids=type_ids)
  )
  embeddings = test_network.get_embedding_layer()(word_ids)
  # Call the network with the precomputed embeddings instead of word ids.
  dict_outputs = test_network(
      dict(
          input_word_embeddings=embeddings,
          input_mask=mask,
          input_type_ids=type_ids,
      )
  )
  all_encoder_outputs = dict_outputs["encoder_outputs"]
  pooled = dict_outputs["pooled_output"]

  # Batch dimension is unknown at graph-construction time.
  expected_pooled_shape = [None, hidden_size]
  self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())

  # The default output dtype is float32.
  self.assertAllEqual(tf.float32, all_encoder_outputs[-1].dtype)
  self.assertAllEqual(tf.float32, pooled.dtype)
359+
323360
def test_serialize_deserialize(self):
324361
# Create a network object that sets all of its config options.
325362
kwargs = dict(

0 commit comments

Comments (0)