|
19 | 19 | from official.core import config_definitions as cfg
|
20 | 20 | from official.core import exp_factory
|
21 | 21 | from official.modeling import optimization
|
22 |
| -from official.modeling.progressive import trainer as prog_trainer_lib |
23 | 22 | from official.nlp.data import wmt_dataloader
|
24 |
| -from official.nlp.tasks import progressive_translation |
25 | 23 | from official.nlp.tasks import translation
|
26 | 24 |
|
27 | 25 |
|
@@ -110,65 +108,3 @@ def wmt_transformer_large() -> cfg.ExperimentConfig:
|
110 | 108 | 'task.sentencepiece_model_path != None',
|
111 | 109 | ])
|
112 | 110 | return config
|
113 |
| - |
114 |
| - |
115 |
@exp_factory.register_config_factory('wmt_transformer/large_progressive')
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Large with progressive training.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576

  def _wmt_data(split, is_training, batch_size, max_len):
    # Train and eval read the same en->de TFDS corpus; only the split,
    # batching and sequence length differ between the two.
    return wmt_dataloader.WMTDataConfig(
        tfds_name='wmt14_translate/de-en',
        tfds_split=split,
        src_lang='en',
        tgt_lang='de',
        is_training=is_training,
        global_batch_size=batch_size,
        static_batch=True,
        max_seq_length=max_len,
    )

  # Encoder and decoder share one config: 16 heads, 4x FFN expansion.
  enc_dec = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  experiment = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=enc_dec,
              decoder=enc_dec,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=_wmt_data('train', True, token_batch_size, 64),
          validation_data=_wmt_data('test', False, 32, 100),
          # Must be overridden via --params_override; enforced by the
          # restriction below.
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return experiment
0 commit comments