18
18
import tensorflow as tf
19
19
20
20
from tensorflow .python .keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
21
+ from official .nlp .modeling import layers
21
22
from official .nlp .modeling import networks
22
23
from official .nlp .modeling .models import bert_classifier
23
24
@@ -53,16 +54,22 @@ def test_bert_trainer(self, num_classes, dict_outputs):
53
54
expected_classification_shape = [None , num_classes ]
54
55
self .assertAllEqual (expected_classification_shape , cls_outs .shape .as_list ())
55
56
56
- @parameterized .parameters (1 , 2 )
57
- def test_bert_trainer_tensor_call (self , num_classes ):
57
+ @parameterized .named_parameters (
58
+ ('single_cls' , 1 , False ),
59
+ ('2_cls' , 2 , False ),
60
+ ('single_cls_custom_head' , 1 , True ),
61
+ ('2_cls_custom_head' , 2 , True ))
62
+ def test_bert_trainer_tensor_call (self , num_classes , use_custom_head ):
58
63
"""Validate that the Keras object can be invoked."""
59
64
# Build a transformer network to use within the BERT trainer. (Here, we use
60
65
# a short sequence_length for convenience.)
61
66
test_network = networks .BertEncoder (vocab_size = 100 , num_layers = 2 )
67
+ cls_head = layers .GaussianProcessClassificationHead (
68
+ inner_dim = 0 , num_classes = num_classes ) if use_custom_head else None
62
69
63
70
# Create a BERT trainer with the created network.
64
71
bert_trainer_model = bert_classifier .BertClassifier (
65
- test_network , num_classes = num_classes )
72
+ test_network , num_classes = num_classes , cls_head = cls_head )
66
73
67
74
# Create a set of 2-dimensional data tensors to feed into the model.
68
75
word_ids = tf .constant ([[1 , 1 ], [2 , 2 ]], dtype = tf .int32 )
@@ -74,7 +81,11 @@ def test_bert_trainer_tensor_call(self, num_classes):
74
81
# too complex: this simply ensures we're not hitting runtime errors.)
75
82
_ = bert_trainer_model ([word_ids , mask , type_ids ])
76
83
77
- def test_serialize_deserialize (self ):
84
+ @parameterized .named_parameters (
85
+ ('default_cls_head' , None ),
86
+ ('sngp_cls_head' , layers .GaussianProcessClassificationHead (
87
+ inner_dim = 0 , num_classes = 4 )))
88
+ def test_serialize_deserialize (self , cls_head ):
78
89
"""Validate that the BERT trainer can be serialized and deserialized."""
79
90
# Build a transformer network to use within the BERT trainer. (Here, we use
80
91
# a short sequence_length for convenience.)
@@ -84,7 +95,7 @@ def test_serialize_deserialize(self):
84
95
# Create a BERT trainer with the created network. (Note that all the args
85
96
# are different, so we can catch any serialization mismatches.)
86
97
bert_trainer_model = bert_classifier .BertClassifier (
87
- test_network , num_classes = 4 , initializer = 'zeros' )
98
+ test_network , num_classes = 4 , initializer = 'zeros' , cls_head = cls_head )
88
99
89
100
# Create another BERT trainer via serialization and deserialization.
90
101
config = bert_trainer_model .get_config ()
0 commit comments