Commit be4e155

[nlp][progressive] Opensource progressive tasks.
PiperOrigin-RevId: 365242134
1 parent cc44fd8 commit be4e155

4 files changed: +750 -0 lines changed
progressive_masked_lm.py: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked language task with progressive training."""

from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.tasks import masked_lm

@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
  num_layers: int = 0
  num_steps: int = 0
  warmup_steps: int = 10000
  initial_learning_rate: float = 1e-4
  end_learning_rate: float = 0.0
  decay_steps: int = 1000000


@dataclasses.dataclass
class ProgMaskedLMConfig(masked_lm.MaskedLMConfig):
  """The progressive model config."""
  optimizer_config: optimization.OptimizationConfig = (
      optimization.OptimizationConfig(
          optimizer=optimization.OptimizerConfig(type='adamw'),
          learning_rate=optimization.LrConfig(type='polynomial'),
          warmup=optimization.WarmupConfig(type='polynomial'),
      )
  )
  stage_list: List[StackingStageConfig] = dataclasses.field(
      default_factory=lambda: [  # pylint: disable=g-long-lambda
          StackingStageConfig(num_layers=3,
                              num_steps=112500,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=1e-4,
                              decay_steps=112500),
          StackingStageConfig(num_layers=6,
                              num_steps=112500,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=1e-4,
                              decay_steps=112500),
          StackingStageConfig(num_layers=12,
                              num_steps=450000,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=0.0,
                              decay_steps=450000)])

@task_factory.register_task_cls(ProgMaskedLMConfig)
class ProgressiveMaskedLM(policies.ProgressivePolicy, masked_lm.MaskedLMTask):
  """Masked Language Model that supports progressive training.

  Inherits from the MaskedLMTask class to build the model, datasets, etc.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    masked_lm.MaskedLMTask.__init__(
        self, params=params, logging_dir=logging_dir)
    self._model_config = params.model
    self._optimizer_config = params.optimizer_config
    self._the_only_train_dataset = None
    self._the_only_eval_dataset = None
    policies.ProgressivePolicy.__init__(self)

  # Override
  def num_stages(self):
    return len(self.task_config.stage_list)

  # Override
  def num_steps(self, stage_id):
    return self.task_config.stage_list[stage_id].num_steps

  # Override
  def get_model(self, stage_id, old_model=None):
    """Build model for each stage."""
    num_layers = self.task_config.stage_list[stage_id].num_layers
    encoder_type = self._model_config.encoder.type
    params = self._model_config.replace(
        encoder={encoder_type: {
            'num_layers': num_layers
        }})
    model = self.build_model(params)

    # Run the model once, to make sure that all layers are built.
    # Otherwise, not all weights will be copied.
    _ = model(model.inputs)

    if stage_id > 0 and old_model is not None:
      logging.info('Stage %d copying weights.', stage_id)
      self._copy_weights_to_new_model(old_model=old_model,
                                      new_model=model)
    return model

  # Override
  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    params = self._optimizer_config.replace(
        learning_rate={
            'polynomial': {
                'decay_steps':
                    self.task_config.stage_list[stage_id].decay_steps,
                'initial_learning_rate':
                    self.task_config.stage_list[stage_id].initial_learning_rate,
                'end_learning_rate':
                    self.task_config.stage_list[stage_id].end_learning_rate,
                'power': 1,
                'cycle': False,
            }
        },
        warmup={
            'polynomial': {
                'warmup_steps':
                    self.task_config.stage_list[stage_id].warmup_steps,
                'power': 1,
            }
        })
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    return optimizer

  # overrides policies.ProgressivePolicy
  def get_train_dataset(self, stage_id):
    del stage_id
    if self._the_only_train_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
          strategy,
          self.build_inputs,
          self.task_config.train_data)
    return self._the_only_train_dataset

  # overrides policies.ProgressivePolicy
  def get_eval_dataset(self, stage_id):
    del stage_id
    if self._the_only_eval_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
          strategy,
          self.build_inputs,
          self.task_config.validation_data)
    return self._the_only_eval_dataset

  def _copy_weights_to_new_model(self, old_model, new_model):
    """Copy model weights from the previous stage to the next.

    Args:
      old_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
        the previous stage.
      new_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
        the next stage.
    """
    # Copy weights of the embedding layers.
    # pylint: disable=protected-access
    # When using `encoder_scaffold`, there may be `_embedding_network`.
    if hasattr(new_model.encoder_network, '_embedding_network') and hasattr(
        old_model.encoder_network, '_embedding_network') and (
            new_model.encoder_network._embedding_network is not None):
      new_model.encoder_network._embedding_network.set_weights(
          old_model.encoder_network._embedding_network.get_weights())
    else:
      new_model.encoder_network._embedding_layer.set_weights(
          old_model.encoder_network._embedding_layer.get_weights())
      new_model.encoder_network._position_embedding_layer.set_weights(
          old_model.encoder_network._position_embedding_layer.get_weights())
      new_model.encoder_network._type_embedding_layer.set_weights(
          old_model.encoder_network._type_embedding_layer.get_weights())
      new_model.encoder_network._embedding_norm_layer.set_weights(
          old_model.encoder_network._embedding_norm_layer.get_weights())
    if hasattr(new_model.encoder_network, '_embedding_projection') and hasattr(
        old_model.encoder_network, '_embedding_projection'):
      if old_model.encoder_network._embedding_projection is not None:
        new_model.encoder_network._embedding_projection.set_weights(
            old_model.encoder_network._embedding_projection.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the transformer layers.
    # The model can be EncoderScaffold or TransformerEncoder.
    if hasattr(old_model.encoder_network, 'hidden_layers'):
      old_layer_group = old_model.encoder_network.hidden_layers
    elif hasattr(old_model.encoder_network, 'transformer_layers'):
      old_layer_group = old_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          old_model.encoder_network))
    if hasattr(new_model.encoder_network, 'hidden_layers'):
      new_layer_group = new_model.encoder_network.hidden_layers
    elif hasattr(new_model.encoder_network, 'transformer_layers'):
      new_layer_group = new_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          new_model.encoder_network))
    for new_layer_idx in range(len(new_layer_group)):
      old_layer_idx = new_layer_idx % len(old_layer_group)
      new_layer_group[new_layer_idx].set_weights(
          old_layer_group[old_layer_idx].get_weights())
      if old_layer_idx != new_layer_idx:
        if hasattr(new_layer_group[new_layer_idx], 'reset_rezero'):
          # Reset ReZero's alpha to 0.
          new_layer_group[new_layer_idx].reset_rezero()

    # Copy weights of the final layer norm (if needed).
    # pylint: disable=protected-access
    if hasattr(new_model.encoder_network, '_output_layer_norm') and hasattr(
        old_model.encoder_network, '_output_layer_norm'):
      new_model.encoder_network._output_layer_norm.set_weights(
          old_model.encoder_network._output_layer_norm.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the pooler layer.
    new_model.encoder_network.pooler_layer.set_weights(
        old_model.encoder_network.pooler_layer.get_weights())

    # Copy weights of the classification head.
    for idx in range(len(new_model.classification_heads)):
      new_model.classification_heads[idx].set_weights(
          old_model.classification_heads[idx].get_weights())

    # Copy weights of the masked_lm layer.
    new_model.masked_lm.set_weights(old_model.masked_lm.get_weights())
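
The core of the file above is `_copy_weights_to_new_model`: when a new stage adds encoder depth, each transformer layer of the deeper model is seeded from the old layer at index `new_layer_idx % len(old_layer_group)`, so the shallower stack is repeated cyclically; duplicated layers additionally get their ReZero scale reset to zero so they start as near-identity residual branches. A minimal, self-contained sketch of that index mapping with ordinary Keras `Dense` layers (the names `old_layers` and `new_layers` are illustrative, not from this commit):

```python
import tensorflow as tf

# Hypothetical stand-ins for a 3-layer "old" encoder and a 6-layer "new" one.
old_layers = [tf.keras.layers.Dense(4) for _ in range(3)]
new_layers = [tf.keras.layers.Dense(4) for _ in range(6)]

# Build every layer once so its variables exist before copying, mirroring
# the `_ = model(model.inputs)` call in get_model().
for layer in old_layers + new_layers:
  layer.build(input_shape=(None, 4))

# Same duplication scheme as _copy_weights_to_new_model: layer i of the new
# stack is seeded from layer i % len(old_layers) of the old stack.
for new_idx, new_layer in enumerate(new_layers):
  old_idx = new_idx % len(old_layers)
  new_layer.set_weights(old_layers[old_idx].get_weights())
```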
progressive_masked_lm_test.py: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for google.nlp.progressive_masked_lm."""

# Import libraries
from absl.testing import parameterized
import gin
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm

def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],)


class ProgressiveMaskedLMTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ProgressiveMaskedLMTest, self).setUp()
    self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=2)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1),
        stage_list=[
            progressive_masked_lm.StackingStageConfig(
                num_layers=1, num_steps=4),
            progressive_masked_lm.StackingStageConfig(
                num_layers=2, num_steps=8),
        ],
    )
    self.exp_config = cfg.ExperimentConfig(
        task=self.task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig())

  @combinations.generate(all_strategy_combinations())
  def test_num_stages(self, distribution):
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      self.assertEqual(prog_masked_lm.num_stages(), 2)
      self.assertEqual(prog_masked_lm.num_steps(0), 4)
      self.assertEqual(prog_masked_lm.num_steps(1), 8)

  @combinations.generate(all_strategy_combinations())
  def test_weight_copying(self, distribution):
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      old_model = prog_masked_lm.get_model(stage_id=0)
      for w in old_model.trainable_weights:
        w.assign(tf.zeros_like(w) + 0.12345)
      new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
      for w in new_model.trainable_weights:
        self.assertAllClose(w, tf.zeros_like(w) + 0.12345)

    gin.parse_config_files_and_bindings(
        None, "encoders.build_encoder.encoder_cls = @EncoderScaffold")
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      old_model = prog_masked_lm.get_model(stage_id=0)
      for w in old_model.trainable_weights:
        w.assign(tf.zeros_like(w) + 0.12345)
      new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
      for w in new_model.trainable_weights:
        self.assertAllClose(w, tf.zeros_like(w) + 0.12345)


if __name__ == "__main__":
  tf.test.main()
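
As a rough usage sketch, mirroring `setUp` above rather than documenting a supported entry point: a custom `stage_list` is what distinguishes `ProgMaskedLMConfig` from the plain masked-LM config, and each `StackingStageConfig` entry carries the encoder depth, the step budget, and the knobs for that stage's polynomial learning-rate schedule (the values below are placeholders taken from the unit test, not recommended settings):

```python
from official.nlp.tasks import progressive_masked_lm

# Two stacking stages, as in the unit test: 1 encoder layer for 4 steps,
# then 2 layers for 8 steps. Warmup and decay fields feed the per-stage
# polynomial schedule assembled in get_optimizer().
config = progressive_masked_lm.ProgMaskedLMConfig(
    stage_list=[
        progressive_masked_lm.StackingStageConfig(num_layers=1, num_steps=4),
        progressive_masked_lm.StackingStageConfig(num_layers=2, num_steps=8),
    ])

# The progressive policy reads its stage plan straight from this config.
print(len(config.stage_list))             # 2 stages
print(config.stage_list[1].num_steps)     # 8 steps in the final stage
print(config.stage_list[1].warmup_steps)  # 10000, the dataclass default
```

When run under the progressive trainer (see `ProgressiveTrainerConfig` in the test), each stage builds its model and optimizer through `get_model` and `get_optimizer`, and weights are carried forward between stages via `_copy_weights_to_new_model`.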
