@@ -320,6 +320,43 @@ def test_network_invocation(self, output_range, out_seq_len, unpool_length):
320
320
self .assertEqual (outputs [0 ].shape [- 1 ], hidden_size )
321
321
self .assertTrue (hasattr (test_network , "_embedding_projection" ))
322
322
323
+ def test_embeddings_as_inputs (self ):
324
+ hidden_size = 32
325
+ sequence_length = 21
326
+ # Create a small BertEncoder for testing.
327
+ test_network = funnel_transformer .FunnelTransformerEncoder (
328
+ vocab_size = 100 ,
329
+ hidden_size = hidden_size ,
330
+ num_attention_heads = 2 ,
331
+ num_layers = 3 ,
332
+ pool_stride = 2 ,
333
+ )
334
+ # Create the inputs (note that the first dimension is implicit).
335
+ word_ids = tf .keras .Input (shape = (sequence_length ), dtype = tf .int32 )
336
+ mask = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
337
+ type_ids = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
338
+ test_network .build (
339
+ dict (input_word_ids = word_ids , input_mask = mask , input_type_ids = type_ids )
340
+ )
341
+ embeddings = test_network .get_embedding_layer ()(word_ids )
342
+ # Calls with the embeddings.
343
+ dict_outputs = test_network (
344
+ dict (
345
+ input_word_embeddings = embeddings ,
346
+ input_mask = mask ,
347
+ input_type_ids = type_ids ,
348
+ )
349
+ )
350
+ all_encoder_outputs = dict_outputs ["encoder_outputs" ]
351
+ pooled = dict_outputs ["pooled_output" ]
352
+
353
+ expected_pooled_shape = [None , hidden_size ]
354
+ self .assertAllEqual (expected_pooled_shape , pooled .shape .as_list ())
355
+
356
+ # The default output dtype is float32.
357
+ self .assertAllEqual (tf .float32 , all_encoder_outputs [- 1 ].dtype )
358
+ self .assertAllEqual (tf .float32 , pooled .dtype )
359
+
323
360
def test_serialize_deserialize (self ):
324
361
# Create a network object that sets all of its config options.
325
362
kwargs = dict (
0 commit comments