
Commit cc44fd8

[nlp][translation] Opensource and register WMT transformer experiment
PiperOrigin-RevId: 365188568
1 parent: bc43944

File tree: 2 files changed, +175 −0 lines changed


official/nlp/configs/experiment_configs.py
Lines changed: 1 addition & 0 deletions

@@ -16,3 +16,4 @@
 # pylint: disable=unused-import
 from official.nlp.configs import finetuning_experiments
 from official.nlp.configs import pretraining_experiments
+from official.nlp.configs import wmt_transformer_experiments
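With this unused-import in place, simply importing experiment_configs pulls in the new module, whose @exp_factory.register_config_factory decorators register the experiment names at import time. A minimal lookup sketch (assuming a Model Garden checkout on PYTHONPATH):

# Sketch: resolve a registered experiment name; exp_factory.get_exp_config
# looks up factories registered by the decorators in the file below.
from official.core import exp_factory
from official.nlp.configs import experiment_configs  # pylint: disable=unused-import

config = exp_factory.get_exp_config('wmt_transformer/large')
print(config.task.train_data.tfds_name)  # 'wmt14_translate/de-en'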
official/nlp/configs/wmt_transformer_experiments.py (new file; name per the import added above)
Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import progressive_translation
from official.nlp.tasks import translation

@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate the sentencepiece model, and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH' to the train
  script.
  """
  learning_rate = 2.0
  hidden_size = 1024
  learning_rate *= (hidden_size**-0.5)
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
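For reference, learning_rate works out to 2.0 * 1024**-0.5 = 0.0625, and the 'power' decay combined with the 'linear' warmup reproduces the inverse-square-root ("Noam") schedule from the original Transformer recipe, scaled by 2: lr(t) = 2 * d_model**-0.5 * min(t**-0.5, t * warmup_steps**-1.5). A plain-Python sketch of the resulting curve (illustrative only, not a Model Garden API; it assumes the linear warmup ramps from 0.0 to the decay curve's value at warmup_steps, which is my reading of the combination above):

# Illustrative re-derivation of the schedule configured above.
def wmt_large_lr(step: int, hidden_size: int = 1024,
                 warmup_steps: int = 16000) -> float:
  peak = 2.0 * hidden_size**-0.5  # initial_learning_rate above = 0.0625
  if step < warmup_steps:
    # Linear ramp from 0.0; equals peak * step * warmup_steps**-1.5.
    return step / warmup_steps * peak * warmup_steps**-0.5
  return peak * step**-0.5  # 'power' decay with power = -0.5

The two branches meet at step = warmup_steps (both give peak * warmup_steps**-0.5, about 4.9e-4 here), so the schedule is continuous.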

@exp_factory.register_config_factory('wmt_transformer/large_progressive')
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Large with progressive training.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate the sentencepiece model, and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH' to the train
  script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
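Both factories deliberately leave task.sentencepiece_model_path as None, and the restrictions entries keep the config from passing validation until it is overridden, as the docstrings note. A hedged sketch of doing that override in Python rather than via --params_override (the model path is hypothetical, and validate() enforcing the restrictions is my reading of the hyperparams/ParamsDict base API):

# Sketch: programmatic override of the required sentencepiece model path.
from official.core import exp_factory
from official.nlp.configs import wmt_transformer_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('wmt_transformer/large_progressive')
config.task.sentencepiece_model_path = '/path/to/sp.model'  # hypothetical path
config.validate()  # would fail while sentencepiece_model_path is still None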
