
Commit 3fca8af

Add transformer model (#4148)
1 parent dea7ecf commit 3fca8af

24 files changed: +10638 -0 lines changed

official/transformer/README.md

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
# Transformer Translation Model

This is an implementation of the Transformer translation model as described in the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. It is based on the code provided by the authors: [Transformer code](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py) from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor).

Transformer is a neural network architecture that solves sequence-to-sequence problems using attention mechanisms. Unlike traditional neural seq2seq models, Transformer does not involve recurrent connections. The attention mechanism learns dependencies between tokens in two sequences. Since attention weights apply to all tokens in the sequences, the Transformer model is able to easily capture long-distance dependencies.

Transformer's overall structure follows the standard encoder-decoder pattern. The encoder uses self-attention to compute a representation of the input sequence. The decoder generates the output sequence one token at a time, taking the encoder output and the tokens it has already generated as inputs.

The model also applies embeddings to the input and output tokens, and adds a constant positional encoding. The positional encoding adds information about the position of each token.

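As a quick reference, the positional encoding follows the sinusoidal formulation from the paper. The snippet below is a NumPy sketch written for this README; the repository's actual implementation lives in [model_utils.py](model/model_utils.py) and its exact channel layout may differ.

```
import numpy as np

def positional_encoding(length, hidden_size, max_timescale=1.0e4):
  """Return a [length, hidden_size] array of sinusoidal position signals.

  hidden_size is assumed to be even: even channels hold
  sin(pos / max_timescale**(2i / hidden_size)) and odd channels hold the
  matching cos, so every position receives a unique pattern.
  """
  position = np.arange(length)[:, np.newaxis]        # [length, 1]
  dim = np.arange(hidden_size // 2)[np.newaxis, :]   # [1, hidden_size / 2]
  angle = position / np.power(max_timescale, 2.0 * dim / hidden_size)
  encoding = np.zeros((length, hidden_size))
  encoding[:, 0::2] = np.sin(angle)
  encoding[:, 1::2] = np.cos(angle)
  return encoding

print(positional_encoding(50, 512).shape)  # (50, 512)
```
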
## Contents
* [Contents](#contents)
* [Walkthrough](#walkthrough)
* [Benchmarks](#benchmarks)
  * [Training times](#training-times)
  * [Evaluation results](#evaluation-results)
* [Detailed instructions](#detailed-instructions)
  * [Export variables (optional)](#export-variables-optional)
  * [Download and preprocess datasets](#download-and-preprocess-datasets)
  * [Model training and evaluation](#model-training-and-evaluation)
  * [Translate using the model](#translate-using-the-model)
  * [Compute official BLEU score](#compute-official-bleu-score)
* [Implementation overview](#implementation-overview)
  * [Model Definition](#model-definition)
  * [Model Estimator](#model-estimator)
  * [Other scripts](#other-scripts)
  * [Test dataset](#test-dataset)
* [Term definitions](#term-definitions)

## Walkthrough

Below are the commands for running the Transformer model. See the [Detailed instructions](#detailed-instructions) for more details on running the model.

```
PARAMS=big
DATA_DIR=$HOME/transformer/data
MODEL_DIR=$HOME/transformer/model_$PARAMS

# Download training/evaluation datasets
python data_download.py --data_dir=$DATA_DIR

# Train the model for 10 epochs, and evaluate after every epoch.
python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --params=$PARAMS --bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de

# Run during training in a separate process to get continuous updates,
# or after training is complete.
tensorboard --logdir=$MODEL_DIR

# Translate some text using the trained model
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --params=$PARAMS --text="hello world"

# Compute model's BLEU score using the newstest2014 dataset.
python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
  --params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
```

## Benchmarks
### Training times

Currently, both the big and base parameter sets run on a single GPU. The measurements below are reported from running the model on a P100 GPU.

Params | batches/sec | batches per epoch | time per epoch
--- | --- | --- | ---
base | 4.8 | 83244 | 4 hr
big | 1.1 | 41365 | 10 hr

### Evaluation results
Below are the case-insensitive BLEU scores after 10 epochs.

Params | Score
--- | ---
base | 27.7
big | 28.9

## Detailed instructions

0. ### Export variables (optional)

   Export the following variables, or modify the values in each of the snippets below:
   ```
   PARAMS=big
   DATA_DIR=$HOME/transformer/data
   MODEL_DIR=$HOME/transformer/model_$PARAMS
   ```

1. ### Download and preprocess datasets

   [data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.

   1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.

   Command to run:
   ```
   python data_download.py --data_dir=$DATA_DIR
   ```

   Arguments:
   * `--data_dir`: Path where the preprocessed TFRecord data and vocab file will be saved.
   * Use the `--help` or `-h` flag to get a full list of possible arguments.

2. ### Model training and evaluation

   [transformer_main.py](transformer_main.py) creates a Transformer model, and trains it using the TensorFlow Estimator API.

   Command to run:
   ```
   python transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS
   ```

   Arguments:
   * `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
   * `--model_dir`: Directory to save Transformer model training checkpoints.
   * `--params`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
   * Use the `--help` or `-h` flag to get a full list of possible arguments.

   #### Customizing training schedule

   By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
   * Training with epochs (default):
     * `--train_epochs`: The total number of complete passes to make through the dataset.
     * `--epochs_between_eval`: The number of epochs to train between evaluations.
   * Training with steps:
     * `--train_steps`: The total number of training steps to run.
     * `--steps_between_eval`: The number of training steps to run between evaluations.

   Only one of `train_epochs` or `train_steps` may be set. Since the default option is to evaluate the model after training for an epoch, it may take 4 or more hours between model evaluations. To get more frequent evaluations, use the flags `--train_steps=250000 --steps_between_eval=1000`.

   Note: At the beginning of each training session, the training dataset is reloaded and shuffled. Stopping the training before completing an epoch may result in worse model quality, due to the chance that some examples may be seen more often than others. Therefore, it is recommended to use epochs when model quality is important.

   #### Compute BLEU score during model evaluation

   Use these flags to compute the BLEU score when the model evaluates:
   * `--bleu_source`: Path to file containing text to translate.
   * `--bleu_ref`: Path to file containing the reference translation.
   * `--bleu_threshold`: Train until the BLEU score reaches this lower bound. This setting overrides the `--train_steps` and `--train_epochs` flags.

   The test source and reference files located in the `test_data` directory are extracted from the preprocessed dataset from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).

   When running `transformer_main.py`, use the flags: `--bleu_source=test_data/newstest2014.en --bleu_ref=test_data/newstest2014.de`

   #### TensorBoard
   Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using TensorBoard.
   ```
   tensorboard --logdir=$MODEL_DIR
   ```
   The values are displayed at [localhost:6006](http://localhost:6006).

3. ### Translate using the model
   [translate.py](translate.py) uses the trained model to translate input text or a file. Each line in the file is translated separately.

   Command to run:
   ```
   python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR --params=$PARAMS --text="hello world"
   ```

   Arguments for initializing the Subtokenizer and trained model:
   * `--data_dir`: Used to locate the vocabulary file to create a Subtokenizer, which encodes the input and decodes the model output.
   * `--model_dir` and `--params`: These parameters are used to rebuild the trained model.

   Arguments for specifying what to translate:
   * `--text`: Text to translate.
   * `--file`: Path to file containing text to translate.
   * `--file_out`: If `--file` is set, then this file will store the input file's translations.

   To translate the newstest2014 data, run:
   ```
   python translate.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
     --params=$PARAMS --file=test_data/newstest2014.en --file_out=translation.en
   ```

   Translating the file takes around 15 minutes on a GTX1080, or 5 minutes on a P100.

4. ### Compute official BLEU score
   Use [compute_bleu.py](compute_bleu.py) to compute the BLEU score by comparing the generated translations to the reference translation.

   Command to run:
   ```
   python compute_bleu.py --translation=translation.en --reference=test_data/newstest2014.de
   ```

   Arguments:
   * `--translation`: Path to file containing generated translations.
   * `--reference`: Path to file containing reference translations.
   * Use the `--help` or `-h` flag to get a full list of possible arguments.

## Implementation overview

A brief look at each component in the code:

### Model Definition
The [model](model) subdirectory contains the implementation of the Transformer model. The following files define the Transformer model and its layers:
* [transformer.py](model/transformer.py): Defines the Transformer model and its encoder/decoder layer stacks.
* [embedding_layer.py](model/embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
* [attention_layer.py](model/attention_layer.py): Defines the multi-headed attention and self-attention layers that are used in the encoder/decoder stacks (see the sketch after this list).
* [ffn_layer.py](model/ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.

Other files:
* [beam_search.py](model/beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
* [model_params.py](model/model_params.py) contains the parameters used for the big and base models.
* [model_utils.py](model/model_utils.py) defines some helper functions used in the model (calculating padding, bias, etc.).

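For intuition, the core operation inside the attention layers is scaled dot-product attention, sketched below in NumPy. This sketch was written for this README and is not the code in [attention_layer.py](model/attention_layer.py), which splits the computation across multiple heads and runs inside the TensorFlow graph.

```
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
  """Single-head attention sketch: softmax(q k^T / sqrt(depth)) applied to v.

  q, k, v: arrays of shape [batch, length, depth]. An optional mask of shape
  [batch, length, length] with 1s at disallowed positions removes those
  positions from the softmax.
  """
  depth = q.shape[-1]
  logits = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(depth)
  if mask is not None:
    logits += mask * -1e9  # a large negative value zeroes out masked weights
  weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)
  return np.matmul(weights, v)

# Self-attention: queries, keys, and values all come from the same sequence.
x = np.random.rand(1, 5, 8)
print(scaled_dot_product_attention(x, x, x).shape)  # (1, 5, 8)
```
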
### Model Estimator
[transformer_main.py](transformer_main.py) creates an `Estimator` to train and evaluate the model.

Helper functions:
* [utils/dataset.py](utils/dataset.py): contains functions for creating a `dataset` that is passed to the `Estimator`.
* [utils/metrics.py](utils/metrics.py): defines metrics functions used by the `Estimator` to evaluate the model.

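The training loop follows the standard TensorFlow 1.x `Estimator` pattern. The toy example below illustrates that pattern only; the real `model_fn` builds the Transformer graph in [transformer_main.py](transformer_main.py), and the real input functions come from [utils/dataset.py](utils/dataset.py).

```
import tensorflow as tf  # TensorFlow 1.x, as used by this commit

def toy_model_fn(features, labels, mode, params):
  """Stand-in model_fn: a single trainable scalar fit by mean squared error."""
  w = tf.get_variable("w", [], initializer=tf.zeros_initializer())
  loss = tf.reduce_mean(tf.square(features * w - labels))
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  return tf.estimator.EstimatorSpec(mode, loss=loss)

def toy_input_fn():
  """Stand-in input_fn: yields (features, labels) batches."""
  data = ([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
  return tf.data.Dataset.from_tensor_slices(data).repeat().batch(2)

estimator = tf.estimator.Estimator(
    model_fn=toy_model_fn, model_dir="/tmp/toy_estimator",
    params={"learning_rate": 0.01})
estimator.train(input_fn=toy_input_fn, steps=200)
print(estimator.evaluate(input_fn=toy_input_fn, steps=10))
```
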
### Other scripts

Aside from the main file to train the Transformer model, we provide other scripts for using the model or downloading the data:

#### Data download and preprocessing

[data_download.py](data_download.py) downloads and extracts data, then uses `Subtokenizer` to tokenize strings into arrays of int IDs. The int arrays are converted to `tf.Example` protos and saved in the TFRecord format.

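Conceptually, each encoded sentence pair becomes a `tf.train.Example` holding the input and target subtoken IDs, serialized into a TFRecord shard. The sketch below shows that conversion; the feature names and file name are illustrative assumptions, not necessarily the ones `data_download.py` uses.

```
import tensorflow as tf  # TensorFlow 1.x, as used by this commit

def dict_to_example(inputs, targets):
  """Wrap two lists of subtoken IDs in a tf.train.Example proto."""
  feature = {
      "inputs": tf.train.Feature(int64_list=tf.train.Int64List(value=inputs)),
      "targets": tf.train.Feature(int64_list=tf.train.Int64List(value=targets)),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

# Write two toy encoded sentence pairs to one TFRecord shard.
with tf.python_io.TFRecordWriter("/tmp/toy-train-00000-of-00001") as writer:
  writer.write(dict_to_example([5, 91, 7, 1], [12, 33, 1]).SerializeToString())
  writer.write(dict_to_example([8, 4, 1], [6, 2, 9, 1]).SerializeToString())
```
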
The data is downloaded from the Workshop on Machine Translation (WMT) [news translation task](http://www.statmt.org/wmt17/translation-task.html). The following datasets are used:

* Europarl v7
* Common Crawl corpus
* News Commentary v12

See the [download section](http://www.statmt.org/wmt17/translation-task.html#download) to explore the raw datasets. The parameters in this model are tuned to fit the English-German translation data, so the EN-DE texts are extracted from the downloaded compressed files.

The text is transformed into arrays of integer IDs using the `Subtokenizer` defined in [`utils/tokenizer.py`](utils/tokenizer.py). During initialization of the `Subtokenizer`, the raw training data is used to generate a vocabulary list containing common subtokens.

The target vocabulary size of the WMT dataset is 32,768. The set of subtokens is found through binary search on the minimum number of times a subtoken appears in the data. The actual vocabulary size is 33,708, and is stored in a 324kB file.

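The snippet below sketches that binary search. `build_vocab` is a hypothetical helper standing in for the subtoken generation in [`utils/tokenizer.py`](utils/tokenizer.py); it is assumed that raising the minimum count shrinks the vocabulary.

```
def search_min_count(build_vocab, target_size, low=1, high=1000):
  """Binary-search the minimum subtoken count that yields roughly target_size.

  build_vocab(min_count) is assumed to return the vocabulary produced when only
  subtokens seen at least min_count times are kept, so higher thresholds give
  smaller vocabularies.
  """
  best = build_vocab(low)
  while low < high:
    mid = (low + high) // 2
    vocab = build_vocab(mid)
    if abs(len(vocab) - target_size) < abs(len(best) - target_size):
      best = vocab
    if len(vocab) > target_size:
      low = mid + 1   # too many subtokens: demand a higher count
    else:
      high = mid      # too few subtokens: allow rarer subtokens
  return best
```
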
#### Translation
Translation is defined in [translate.py](translate.py). First, the `Subtokenizer` tokenizes the input; the vocabulary file is the same one used to tokenize the training/eval files. Next, beam search is used to find the combination of tokens that maximizes the probability output by the model decoder. The tokens are then converted back to strings with the `Subtokenizer`.

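A stripped-down version of the beam search idea is sketched below. It is illustrative only: the actual implementation in [beam_search.py](model/beam_search.py) runs batched inside the TensorFlow graph and applies length normalization. `next_scores` is an assumed callback returning the log-probability of each candidate next token given the partial sequence.

```
def beam_search(next_scores, start_id, eos_id, beam_size=4, max_len=20):
  """Keep the beam_size highest-scoring partial sequences at every step."""
  beams = [([start_id], 0.0)]  # (token IDs so far, cumulative log-probability)
  for _ in range(max_len):
    candidates = []
    for seq, score in beams:
      if seq[-1] == eos_id:          # finished sequences are carried over
        candidates.append((seq, score))
        continue
      for token, log_prob in next_scores(seq).items():
        candidates.append((seq + [token], score + log_prob))
    beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    if all(seq[-1] == eos_id for seq, _ in beams):
      break
  return beams[0][0]  # highest-scoring sequence
```
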
#### BLEU computation
[compute_bleu.py](compute_bleu.py): Implementation from [tensor2tensor/utils/bleu_hook.py](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py).

### Test dataset
The [newstest2014 files](test_data) are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data). The raw text files are converted from the SGM format of the [WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets.

## Term definitions

**Steps / Epochs**:
* Step: unit for processing a single batch of data
* Epoch: a complete run through the dataset

Example: Consider a training dataset with 100 examples that is divided into 20 batches with 5 examples per batch. A single training step trains the model on one batch. After 20 training steps, the model will have trained on every batch in the dataset, or one epoch.

**Subtoken**: Words are referred to as tokens, and parts of words are referred to as 'subtokens'. For example, the word 'inclined' may be split into `['incline', 'd_']`. The '\_' indicates the end of the token. The subtoken vocabulary list is guaranteed to contain the alphabet (including numbers and special characters), so all words can be tokenized.
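
As a toy illustration of the subtoken idea (this is not the repository's `Subtokenizer` logic, and the vocabulary below is made up), greedy longest-match splitting works like this:

```
# Single characters act as a fallback so every word can be encoded.
toy_vocab = {"incline", "d_", "trans", "form", "er_"} | set("abcdefghijklmnopqrstuvwxyz_")

def toy_subtokenize(word):
  word += "_"                               # mark the end of the token
  pieces = []
  while word:
    for end in range(len(word), 0, -1):     # longest vocabulary match first
      if word[:end] in toy_vocab:
        pieces.append(word[:end])
        word = word[end:]
        break
  return pieces

print(toy_subtokenize("inclined"))     # ['incline', 'd_']
print(toy_subtokenize("transformer"))  # ['trans', 'form', 'er_']
```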

official/transformer/__init__.py

Whitespace-only changes.
official/transformer/compute_bleu.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to compute official BLEU score.

Source:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import re
import sys
import unicodedata

# pylint: disable=g-bad-import-order
import six
import tensorflow as tf
# pylint: enable=g-bad-import-order

from official.transformer.utils import metrics


class UnicodeRegex(object):
  """Ad-hoc hack to recognize all punctuation and symbols."""

  def __init__(self):
    punctuation = self.property_chars("P")
    self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
    self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
    self.symbol_re = re.compile("([" + self.property_chars("S") + "])")

  def property_chars(self, prefix):
    return "".join(six.unichr(x) for x in range(sys.maxunicode)
                   if unicodedata.category(six.unichr(x)).startswith(prefix))


uregex = UnicodeRegex()


def bleu_tokenize(string):
  r"""Tokenize a string following the official BLEU implementation.

  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983

  In our case, the input string is expected to be just one line
  and no HTML entities de-escaping is needed.
  So we just tokenize on punctuation and symbols,
  except when a punctuation is preceded and followed by a digit
  (e.g. a comma/dot as a thousand/decimal separator).

  Note that a number (e.g. a year) followed by a dot at the end of sentence
  is NOT tokenized,
  i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
  does not match this case (unless we add a space after each sentence).
  However, this error is already in the original mteval-v14.pl
  and we want to be consistent with it.

  Args:
    string: the input string

  Returns:
    a list of tokens
  """
  string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
  string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
  string = uregex.symbol_re.sub(r" \1 ", string)
  return string.split()


def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
  """Compute BLEU for two files (reference and hypothesis translation)."""
  ref_lines = tf.gfile.Open(ref_filename).read().strip().splitlines()
  hyp_lines = tf.gfile.Open(hyp_filename).read().strip().splitlines()

  if len(ref_lines) != len(hyp_lines):
    raise ValueError("Reference and translation files have different number of "
                     "lines.")
  if not case_sensitive:
    ref_lines = [x.lower() for x in ref_lines]
    hyp_lines = [x.lower() for x in hyp_lines]
  ref_tokens = [bleu_tokenize(x) for x in ref_lines]
  hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
  return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100


def main(unused_argv):
  if FLAGS.bleu_variant is None or "uncased" in FLAGS.bleu_variant:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
    print("Case-insensitive results:", score)

  if FLAGS.bleu_variant is None or "cased" in FLAGS.bleu_variant:
    score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
    print("Case-sensitive results:", score)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--translation", "-t", type=str, default=None, required=True,
      help="[default: %(default)s] File containing translated text.",
      metavar="<T>")
  parser.add_argument(
      "--reference", "-r", type=str, default=None, required=True,
      help="[default: %(default)s] File containing reference translation.",
      metavar="<R>")
  parser.add_argument(
      "--bleu_variant", "-bv", type=str, choices=["uncased", "cased"],
      nargs="*", default=None,
      help="Specify one or more BLEU variants to calculate (both are "
           "calculated by default). Variants: \"cased\" or \"uncased\".",
      metavar="<BV>")

  FLAGS, unparsed = parser.parse_known_args()
  main(sys.argv)
