# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked language task with progressive training."""

from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.tasks import masked_lm


@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
  num_layers: int = 0
  num_steps: int = 0
  warmup_steps: int = 10000
  initial_learning_rate: float = 1e-4
  end_learning_rate: float = 0.0
  decay_steps: int = 1000000


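# A StackingStageConfig describes a single stage of progressive stacking: the
# encoder depth used in that stage (num_layers), how many steps the stage runs
# (num_steps), and the stage's polynomial learning-rate schedule (warmup_steps,
# initial/end learning rate, decay_steps). ProgMaskedLMConfig below chains such
# stages together in stage_list and fixes the optimizer to AdamW.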
@dataclasses.dataclass
class ProgMaskedLMConfig(masked_lm.MaskedLMConfig):
  """The progressive model config."""
  optimizer_config: optimization.OptimizationConfig = (
      optimization.OptimizationConfig(
          optimizer=optimization.OptimizerConfig(type='adamw'),
          learning_rate=optimization.LrConfig(type='polynomial'),
          warmup=optimization.WarmupConfig(type='polynomial'),
      )
  )
  stage_list: List[StackingStageConfig] = dataclasses.field(
      default_factory=lambda: [  # pylint: disable=g-long-lambda
          StackingStageConfig(num_layers=3,
                              num_steps=112500,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=1e-4,
                              decay_steps=112500),
          StackingStageConfig(num_layers=6,
                              num_steps=112500,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=1e-4,
                              decay_steps=112500),
          StackingStageConfig(num_layers=12,
                              num_steps=450000,
                              warmup_steps=10000,
                              initial_learning_rate=1e-4,
                              end_learning_rate=0.0,
                              decay_steps=450000)])

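# With the default stage_list, training starts from a 3-layer encoder for
# 112,500 steps, stacks it to 6 layers for another 112,500 steps, and finally
# stacks to the full 12 layers for 450,000 steps; the learning rate is held at
# 1e-4 in the first two stages and decays to 0 over the last one. The
# ProgressiveMaskedLM task below consumes this schedule stage by stage.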

@task_factory.register_task_cls(ProgMaskedLMConfig)
class ProgressiveMaskedLM(policies.ProgressivePolicy, masked_lm.MaskedLMTask):
  """Masked Language Model that supports progressive training.

  Inherits from the MaskedLMTask class to build the model, datasets, etc.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    masked_lm.MaskedLMTask.__init__(
        self, params=params, logging_dir=logging_dir)
    self._model_config = params.model
    self._optimizer_config = params.optimizer_config
    self._the_only_train_dataset = None
    self._the_only_eval_dataset = None
    policies.ProgressivePolicy.__init__(self)
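    # MaskedLMTask.__init__ runs first so that the task config and the
    # model/optimizer configs are in place before the ProgressivePolicy
    # machinery (initialized just above), which drives training through the
    # stage hooks overridden below: num_stages, num_steps, get_model,
    # get_optimizer and the dataset getters.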

  # Override
  def num_stages(self):
    return len(self.task_config.stage_list)

  # Override
  def num_steps(self, stage_id):
    return self.task_config.stage_list[stage_id].num_steps

  # Override
  def get_model(self, stage_id, old_model=None):
    """Build model for each stage."""
    num_layers = self.task_config.stage_list[stage_id].num_layers
    encoder_type = self._model_config.encoder.type
    params = self._model_config.replace(
        encoder={encoder_type: {
            'num_layers': num_layers
        }})
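    # Only the encoder depth differs between stages; all other encoder
    # hyperparameters (hidden size, number of attention heads, etc.) stay the
    # same, which is what allows whole layers to be copied verbatim when the
    # network is stacked deeper.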
    model = self.build_model(params)

    # Run the model once, to make sure that all layers are built.
    # Otherwise, not all weights will be copied.
    _ = model(model.inputs)

    if stage_id > 0 and old_model is not None:
      logging.info('Stage %d copying weights.', stage_id)
      self._copy_weights_to_new_model(old_model=old_model,
                                      new_model=model)
    return model

  # Override
  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    params = self._optimizer_config.replace(
        learning_rate={
            'polynomial': {
                'decay_steps':
                    self.task_config.stage_list[stage_id].decay_steps,
                'initial_learning_rate':
                    self.task_config.stage_list[stage_id].initial_learning_rate,
                'end_learning_rate':
                    self.task_config.stage_list[stage_id].end_learning_rate,
                'power': 1,
                'cycle': False,
            }
        },
        warmup={
            'polynomial': {
                'warmup_steps':
                    self.task_config.stage_list[stage_id].warmup_steps,
                'power': 1,
            }
        })
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
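    # A brand-new AdamW optimizer and a linear (power=1) polynomial decay
    # schedule with warmup are created from this stage's entry in stage_list,
    # so every stage gets its own warmup and decay rather than continuing the
    # previous stage's schedule.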

    return optimizer

  # overrides policies.ProgressivePolicy
  def get_train_dataset(self, stage_id):
    del stage_id
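    # The stage id is ignored: a single distributed dataset is built lazily on
    # first use and shared by all stages, so the input pipeline is not rebuilt
    # at stage boundaries.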
    if self._the_only_train_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
          strategy,
          self.build_inputs,
          self.task_config.train_data)
    return self._the_only_train_dataset

  # overrides policies.ProgressivePolicy
  def get_eval_dataset(self, stage_id):
    del stage_id
    if self._the_only_eval_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
          strategy,
          self.build_inputs,
          self.task_config.validation_data)
    return self._the_only_eval_dataset

  def _copy_weights_to_new_model(self, old_model, new_model):
    """Copy model weights from the previous stage to the next.

    Args:
      old_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
        the previous stage.
      new_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model of
        the next stage.
    """
    # Copy weights of the embedding layers.
    # pylint: disable=protected-access
    # When using `encoder_scaffold`, there may be `_embedding_network`.
    if hasattr(new_model.encoder_network, '_embedding_network') and hasattr(
        old_model.encoder_network, '_embedding_network') and (
            new_model.encoder_network._embedding_network is not None):
      new_model.encoder_network._embedding_network.set_weights(
          old_model.encoder_network._embedding_network.get_weights())
    else:
      new_model.encoder_network._embedding_layer.set_weights(
          old_model.encoder_network._embedding_layer.get_weights())
      new_model.encoder_network._position_embedding_layer.set_weights(
          old_model.encoder_network._position_embedding_layer.get_weights())
      new_model.encoder_network._type_embedding_layer.set_weights(
          old_model.encoder_network._type_embedding_layer.get_weights())
      new_model.encoder_network._embedding_norm_layer.set_weights(
          old_model.encoder_network._embedding_norm_layer.get_weights())
    if hasattr(new_model.encoder_network, '_embedding_projection') and hasattr(
        old_model.encoder_network, '_embedding_projection'):
      if old_model.encoder_network._embedding_projection is not None:
        new_model.encoder_network._embedding_projection.set_weights(
            old_model.encoder_network._embedding_projection.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the transformer layers.
    # The model can be EncoderScaffold or TransformerEncoder.
    if hasattr(old_model.encoder_network, 'hidden_layers'):
      old_layer_group = old_model.encoder_network.hidden_layers
    elif hasattr(old_model.encoder_network, 'transformer_layers'):
      old_layer_group = old_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          old_model.encoder_network))
    if hasattr(new_model.encoder_network, 'hidden_layers'):
      new_layer_group = new_model.encoder_network.hidden_layers
    elif hasattr(new_model.encoder_network, 'transformer_layers'):
      new_layer_group = new_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          new_model.encoder_network))
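    # Initialize every layer of the deeper network from a layer of the
    # shallower one, cycling through the old stack when the new one is deeper.
    # For example, growing from 3 to 6 layers copies old layers 0, 1, 2,
    # 0, 1, 2 into new layers 0-5.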
    for new_layer_idx in range(len(new_layer_group)):
      old_layer_idx = new_layer_idx % len(old_layer_group)
      new_layer_group[new_layer_idx].set_weights(
          old_layer_group[old_layer_idx].get_weights())
      if old_layer_idx != new_layer_idx:
        if hasattr(new_layer_group[new_layer_idx], 'reset_rezero'):
          # Reset ReZero's alpha to 0.
          new_layer_group[new_layer_idx].reset_rezero()

    # Copy weights of the final layer norm (if needed).
    # pylint: disable=protected-access
    if hasattr(new_model.encoder_network, '_output_layer_norm') and hasattr(
        old_model.encoder_network, '_output_layer_norm'):
      new_model.encoder_network._output_layer_norm.set_weights(
          old_model.encoder_network._output_layer_norm.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the pooler layer.
    new_model.encoder_network.pooler_layer.set_weights(
        old_model.encoder_network.pooler_layer.get_weights())

    # Copy weights of the classification head.
    for idx in range(len(new_model.classification_heads)):
      new_model.classification_heads[idx].set_weights(
          old_model.classification_heads[idx].get_weights())

    # Copy weights of the masked_lm layer.
    new_model.masked_lm.set_weights(old_model.masked_lm.get_weights())
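
# A minimal usage sketch (illustrative only, not part of the library API):
# a custom two-stage schedule that trains a 6-layer encoder and then stacks it
# to 12 layers, relying on the default model and data configs being filled in
# elsewhere (e.g. by an experiment definition).
#
#   config = ProgMaskedLMConfig(
#       stage_list=[
#           StackingStageConfig(num_layers=6, num_steps=100000,
#                               warmup_steps=10000,
#                               initial_learning_rate=1e-4,
#                               end_learning_rate=1e-4, decay_steps=100000),
#           StackingStageConfig(num_layers=12, num_steps=300000,
#                               warmup_steps=10000,
#                               initial_learning_rate=1e-4,
#                               end_learning_rate=0.0, decay_steps=300000),
#       ])
#   task = ProgressiveMaskedLM(params=config)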