
Commit a49c733

jereliu authored and tensorflower-gardener committed
Adds unit tests for using cls_head in BertClassifier.
PiperOrigin-RevId: 364952108
1 parent 44c2b99 commit a49c733

File tree

2 files changed: +21 -9 lines changed


official/nlp/modeling/models/bert_classifier.py

Lines changed: 5 additions & 4 deletions
@@ -45,8 +45,8 @@ class BertClassifier(tf.keras.Model):
     dropout_rate: The dropout probability of the cls head.
     use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
      encoder.
-    cls_head: (Optional) The layer instance to use for the classifier head
-      . It should take in the output from network and produce the final logits.
+    cls_head: (Optional) The layer instance to use for the classifier head.
+      It should take in the output from network and produce the final logits.
       If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
       'use_encoder_pooler') will be ignored.
   """
@@ -62,7 +62,6 @@ def __init__(self,
     self.num_classes = num_classes
     self.initializer = initializer
     self.use_encoder_pooler = use_encoder_pooler
-    self.cls_head = cls_head
 
     # We want to use the inputs of the passed network as the inputs to this
     # Model. To do this, we need to keep a handle to the network inputs for use
@@ -107,6 +106,8 @@ def __init__(self,
     super(BertClassifier, self).__init__(
         inputs=inputs, outputs=predictions, **kwargs)
     self._network = network
+    self._cls_head = cls_head
+
     config_dict = self._make_config_dict()
     # We are storing the config dict as a namedtuple here to ensure checkpoint
     # compatibility with an earlier version of this model which did not track
@@ -138,5 +139,5 @@ def _make_config_dict(self):
         'num_classes': self.num_classes,
         'initializer': self.initializer,
         'use_encoder_pooler': self.use_encoder_pooler,
-        'cls_head': self.cls_head,
+        'cls_head': self._cls_head,
     }
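
For reference, a minimal usage sketch of the cls_head argument described in the updated docstring. It mirrors the new unit tests below; the GaussianProcessClassificationHead arguments (inner_dim=0) and the tiny encoder are illustrative only, not a recommended configuration.

# Sketch: building a BertClassifier with a custom classification head.
import tensorflow as tf

from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

# A small encoder, matching the one used in the tests below.
encoder = networks.BertEncoder(vocab_size=100, num_layers=2)

# Per the updated docstring, when cls_head is set, the 'num_classes',
# 'initializer', 'dropout_rate', and 'use_encoder_pooler' arguments are ignored.
cls_head = layers.GaussianProcessClassificationHead(inner_dim=0, num_classes=2)
model = bert_classifier.BertClassifier(
    encoder, num_classes=2, cls_head=cls_head)

# Invoke the model on toy 2-dimensional inputs, as the tensor-call test does.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 1]], dtype=tf.int32)
type_ids = tf.constant([[0, 0], [0, 0]], dtype=tf.int32)
logits = model([word_ids, mask, type_ids])  # expected shape: [batch_size, num_classes]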

official/nlp/modeling/models/bert_classifier_test.py

Lines changed: 16 additions & 5 deletions
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 from official.nlp.modeling.models import bert_classifier
 
@@ -53,16 +54,22 @@ def test_bert_trainer(self, num_classes, dict_outputs):
     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
 
-  @parameterized.parameters(1, 2)
-  def test_bert_trainer_tensor_call(self, num_classes):
+  @parameterized.named_parameters(
+      ('single_cls', 1, False),
+      ('2_cls', 2, False),
+      ('single_cls_custom_head', 1, True),
+      ('2_cls_custom_head', 2, True))
+  def test_bert_trainer_tensor_call(self, num_classes, use_custom_head):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
     test_network = networks.BertEncoder(vocab_size=100, num_layers=2)
+    cls_head = layers.GaussianProcessClassificationHead(
+        inner_dim=0, num_classes=num_classes) if use_custom_head else None
 
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=num_classes)
+        test_network, num_classes=num_classes, cls_head=cls_head)
 
     # Create a set of 2-dimensional data tensors to feed into the model.
     word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
@@ -74,7 +81,11 @@ def test_bert_trainer_tensor_call(self, num_classes):
     # too complex: this simply ensures we're not hitting runtime errors.)
     _ = bert_trainer_model([word_ids, mask, type_ids])
 
-  def test_serialize_deserialize(self):
+  @parameterized.named_parameters(
+      ('default_cls_head', None),
+      ('sngp_cls_head', layers.GaussianProcessClassificationHead(
+          inner_dim=0, num_classes=4)))
+  def test_serialize_deserialize(self, cls_head):
     """Validate that the BERT trainer can be serialized and deserialized."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
@@ -84,7 +95,7 @@ def test_serialize_deserialize(self):
     # Create a BERT trainer with the created network. (Note that all the args
     # are different, so we can catch any serialization mismatches.)
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=4, initializer='zeros')
+        test_network, num_classes=4, initializer='zeros', cls_head=cls_head)
 
     # Create another BERT trainer via serialization and deserialization.
     config = bert_trainer_model.get_config()
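
The serialization test relies on the standard Keras config round trip. A minimal sketch of that pattern, assuming the model built in the earlier sketch; from_config here is the inherited tf.keras.Model classmethod, not something added by this commit.

# Sketch: the get_config / from_config round trip that
# test_serialize_deserialize exercises. _make_config_dict now reads the head
# from self._cls_head, so the config dict carries the custom cls_head instance.
config = model.get_config()
restored = bert_classifier.BertClassifier.from_config(config)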
