Skip to content

Commit f87d669

Browse files
jereliutensorflower-gardener
authored andcommitted
Deprecate network.Classification from BertClassifier.
PiperOrigin-RevId: 364722225
1 parent 0032b25 commit f87d669

File tree

3 files changed

+111
-49
lines changed

3 files changed

+111
-49
lines changed

official/nlp/modeling/layers/cls_head.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def __init__(self,
3636
"""Initializes the `ClassificationHead`.
3737
3838
Args:
39-
inner_dim: The dimensionality of inner projection layer.
39+
inner_dim: The dimensionality of inner projection layer. If 0 or `None`
40+
then only the output projection layer is created.
4041
num_classes: Number of output classes.
4142
cls_token_idx: The index inside the sequence to pool.
4243
activation: Dense layer activation.
@@ -52,19 +53,25 @@ def __init__(self,
5253
self.initializer = tf.keras.initializers.get(initializer)
5354
self.cls_token_idx = cls_token_idx
5455

55-
self.dense = tf.keras.layers.Dense(
56-
units=inner_dim,
57-
activation=self.activation,
58-
kernel_initializer=self.initializer,
59-
name="pooler_dense")
60-
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
56+
if self.inner_dim:
57+
self.dense = tf.keras.layers.Dense(
58+
units=self.inner_dim,
59+
activation=self.activation,
60+
kernel_initializer=self.initializer,
61+
name="pooler_dense")
62+
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
63+
6164
self.out_proj = tf.keras.layers.Dense(
6265
units=num_classes, kernel_initializer=self.initializer, name="logits")
6366

6467
def call(self, features):
65-
x = features[:, self.cls_token_idx, :] # take <CLS> token.
66-
x = self.dense(x)
67-
x = self.dropout(x)
68+
if not self.inner_dim:
69+
x = features
70+
else:
71+
x = features[:, self.cls_token_idx, :] # take <CLS> token.
72+
x = self.dense(x)
73+
x = self.dropout(x)
74+
6875
x = self.out_proj(x)
6976
return x
7077

@@ -103,7 +110,8 @@ def __init__(self,
103110
"""Initializes the `MultiClsHeads`.
104111
105112
Args:
106-
inner_dim: The dimensionality of inner projection layer.
113+
inner_dim: The dimensionality of inner projection layer. If 0 or `None`
114+
then only the output projection layer is created.
107115
cls_list: a list of pairs of (classification problem name and the numbers
108116
of classes.
109117
cls_token_idx: The index inside the sequence to pool.
@@ -120,12 +128,13 @@ def __init__(self,
120128
self.initializer = tf.keras.initializers.get(initializer)
121129
self.cls_token_idx = cls_token_idx
122130

123-
self.dense = tf.keras.layers.Dense(
124-
units=inner_dim,
125-
activation=self.activation,
126-
kernel_initializer=self.initializer,
127-
name="pooler_dense")
128-
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
131+
if self.inner_dim:
132+
self.dense = tf.keras.layers.Dense(
133+
units=inner_dim,
134+
activation=self.activation,
135+
kernel_initializer=self.initializer,
136+
name="pooler_dense")
137+
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
129138
self.out_projs = []
130139
for name, num_classes in cls_list:
131140
self.out_projs.append(
@@ -134,9 +143,13 @@ def __init__(self,
134143
name=name))
135144

136145
def call(self, features):
137-
x = features[:, self.cls_token_idx, :] # take <CLS> token.
138-
x = self.dense(x)
139-
x = self.dropout(x)
146+
if not self.inner_dim:
147+
x = features
148+
else:
149+
x = features[:, self.cls_token_idx, :] # take <CLS> token.
150+
x = self.dense(x)
151+
x = self.dropout(x)
152+
140153
outputs = {}
141154
for proj_layer in self.out_projs:
142155
outputs[proj_layer.name] = proj_layer(x)
@@ -195,7 +208,8 @@ def __init__(self,
195208
"""Initializes the `GaussianProcessClassificationHead`.
196209
197210
Args:
198-
inner_dim: The dimensionality of inner projection layer.
211+
inner_dim: The dimensionality of inner projection layer. If 0 or `None`
212+
then only the output projection layer is created.
199213
num_classes: Number of output classes.
200214
cls_token_idx: The index inside the sequence to pool.
201215
activation: Dense layer activation.
@@ -220,8 +234,8 @@ def __init__(self,
220234
initializer=initializer,
221235
**kwargs)
222236

223-
# Applies spectral normalization to the pooler layer.
224-
if use_spec_norm:
237+
# Applies spectral normalization to the dense pooler layer.
238+
if self.use_spec_norm and hasattr(self, "dense"):
225239
self.dense = spectral_normalization.SpectralNormalization(
226240
self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)
227241

official/nlp/modeling/layers/cls_head_test.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,24 @@
1313
# limitations under the License.
1414

1515
"""Tests for cls_head."""
16+
from absl.testing import parameterized
1617

1718
import tensorflow as tf
1819

1920
from official.nlp.modeling.layers import cls_head
2021

2122

22-
class ClassificationHeadTest(tf.test.TestCase):
23+
class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
24+
25+
@parameterized.named_parameters(("no_pooler_layer", 0, 2),
26+
("has_pooler_layer", 5, 4))
27+
def test_pooler_layer(self, inner_dim, num_weights_expected):
28+
test_layer = cls_head.ClassificationHead(inner_dim=inner_dim, num_classes=2)
29+
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
30+
_ = test_layer(features)
31+
32+
num_weights_observed = len(test_layer.get_weights())
33+
self.assertEqual(num_weights_observed, num_weights_expected)
2334

2435
def test_layer_invocation(self):
2536
test_layer = cls_head.ClassificationHead(inner_dim=5, num_classes=2)
@@ -37,7 +48,18 @@ def test_layer_serialization(self):
3748
self.assertAllEqual(layer.get_config(), new_layer.get_config())
3849

3950

40-
class MultiClsHeadsTest(tf.test.TestCase):
51+
class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase):
52+
53+
@parameterized.named_parameters(("no_pooler_layer", 0, 4),
54+
("has_pooler_layer", 5, 6))
55+
def test_pooler_layer(self, inner_dim, num_weights_expected):
56+
cls_list = [("foo", 2), ("bar", 3)]
57+
test_layer = cls_head.MultiClsHeads(inner_dim=inner_dim, cls_list=cls_list)
58+
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
59+
_ = test_layer(features)
60+
61+
num_weights_observed = len(test_layer.get_weights())
62+
self.assertEqual(num_weights_observed, num_weights_expected)
4163

4264
def test_layer_invocation(self):
4365
cls_list = [("foo", 2), ("bar", 3)]
@@ -58,13 +80,31 @@ def test_layer_serialization(self):
5880
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
5981

6082

61-
class GaussianProcessClassificationHead(tf.test.TestCase):
83+
class GaussianProcessClassificationHead(tf.test.TestCase,
84+
parameterized.TestCase):
6285

6386
def setUp(self):
6487
super().setUp()
6588
self.spec_norm_kwargs = dict(norm_multiplier=1.,)
6689
self.gp_layer_kwargs = dict(num_inducing=512)
6790

91+
@parameterized.named_parameters(("no_pooler_layer", 0, 7),
92+
("has_pooler_layer", 5, 11))
93+
def test_pooler_layer(self, inner_dim, num_weights_expected):
94+
test_layer = cls_head.GaussianProcessClassificationHead(
95+
inner_dim=inner_dim,
96+
num_classes=2,
97+
use_spec_norm=True,
98+
use_gp_layer=True,
99+
initializer="zeros",
100+
**self.spec_norm_kwargs,
101+
**self.gp_layer_kwargs)
102+
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
103+
_ = test_layer(features)
104+
105+
num_weights_observed = len(test_layer.get_weights())
106+
self.assertEqual(num_weights_observed, num_weights_expected)
107+
68108
def test_layer_invocation(self):
69109
test_layer = cls_head.GaussianProcessClassificationHead(
70110
inner_dim=5,

official/nlp/modeling/models/bert_classifier.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import tensorflow as tf
1919

2020
from official.nlp.modeling import layers
21-
from official.nlp.modeling import networks
2221

2322

2423
@tf.keras.utils.register_keras_serializable(package='Text')
@@ -46,6 +45,10 @@ class BertClassifier(tf.keras.Model):
4645
dropout_rate: The dropout probability of the cls head.
4746
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
4847
encoder.
48+
cls_head: (Optional) The layer instance to use for the classifier head
49+
. It should take in the output from network and produce the final logits.
50+
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
51+
'use_encoder_pooler') will be ignored.
4952
"""
5053

5154
def __init__(self,
@@ -54,7 +57,12 @@ def __init__(self,
5457
initializer='glorot_uniform',
5558
dropout_rate=0.1,
5659
use_encoder_pooler=True,
60+
cls_head=None,
5761
**kwargs):
62+
self.num_classes = num_classes
63+
self.initializer = initializer
64+
self.use_encoder_pooler = use_encoder_pooler
65+
self.cls_head = cls_head
5866

5967
# We want to use the inputs of the passed network as the inputs to this
6068
# Model. To do this, we need to keep a handle to the network inputs for use
@@ -66,31 +74,28 @@ def __init__(self,
6674
# invoke the Network object with its own input tensors to start the Model.
6775
outputs = network(inputs)
6876
if isinstance(outputs, list):
69-
cls_output = outputs[1]
77+
cls_inputs = outputs[1]
7078
else:
71-
cls_output = outputs['pooled_output']
72-
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
73-
74-
classifier = networks.Classification(
75-
input_width=cls_output.shape[-1],
76-
num_classes=num_classes,
77-
initializer=initializer,
78-
output='logits',
79-
name='sentence_prediction')
80-
predictions = classifier(cls_output)
79+
cls_inputs = outputs['pooled_output']
80+
cls_inputs = tf.keras.layers.Dropout(rate=dropout_rate)(cls_inputs)
8181
else:
8282
outputs = network(inputs)
8383
if isinstance(outputs, list):
84-
sequence_output = outputs[0]
84+
cls_inputs = outputs[0]
8585
else:
86-
sequence_output = outputs['sequence_output']
86+
cls_inputs = outputs['sequence_output']
87+
88+
if cls_head:
89+
classifier = cls_head
90+
else:
8791
classifier = layers.ClassificationHead(
88-
inner_dim=sequence_output.shape[-1],
92+
inner_dim=0 if use_encoder_pooler else cls_inputs.shape[-1],
8993
num_classes=num_classes,
9094
initializer=initializer,
9195
dropout_rate=dropout_rate,
9296
name='sentence_prediction')
93-
predictions = classifier(sequence_output)
97+
98+
predictions = classifier(cls_inputs)
9499

95100
# b/164516224
96101
# Once we've created the network using the Functional API, we call
@@ -102,13 +107,7 @@ def __init__(self,
102107
super(BertClassifier, self).__init__(
103108
inputs=inputs, outputs=predictions, **kwargs)
104109
self._network = network
105-
config_dict = {
106-
'network': network,
107-
'num_classes': num_classes,
108-
'initializer': initializer,
109-
'use_encoder_pooler': use_encoder_pooler,
110-
}
111-
110+
config_dict = self._make_config_dict()
112111
# We are storing the config dict as a namedtuple here to ensure checkpoint
113112
# compatibility with an earlier version of this model which did not track
114113
# the config dict attribute. TF does not track immutable attrs which
@@ -132,3 +131,12 @@ def get_config(self):
132131
@classmethod
133132
def from_config(cls, config, custom_objects=None):
134133
return cls(**config)
134+
135+
def _make_config_dict(self):
136+
return {
137+
'network': self._network,
138+
'num_classes': self.num_classes,
139+
'initializer': self.initializer,
140+
'use_encoder_pooler': self.use_encoder_pooler,
141+
'cls_head': self.cls_head,
142+
}

0 commit comments

Comments
 (0)