# encoding: utf-8

# Copyright 2015-present Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division, print_function, absolute_import

import itertools
import os

import numpy as np
import tensorflow as tf

import skflow

# itertools.izip is gone in Python 3; fall back to the built-in zip,
# which is already lazy there.
try:
    from itertools import izip as zip  # Python 2
except ImportError:
    pass

# Get training data

# This dataset can be downloaded from http://www.statmt.org/europarl/v6/fr-en.tgz

ENGLISH_CORPUS = "europarl-v6.fr-en.en"
FRENCH_CORPUS = "europarl-v6.fr-en.fr"

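# A minimal download sketch in case the corpus files are missing. This helper
# is not part of the original pipeline; the URL above and the archive layout
# (both corpus files at the top level of the .tgz) are assumptions. Call it
# before the train/test split below if the files are not already present.
def maybe_download_corpus(url='http://www.statmt.org/europarl/v6/fr-en.tgz'):
    import tarfile
    try:
        from urllib import urlretrieve  # Python 2
    except ImportError:
        from urllib.request import urlretrieve  # Python 3
    if not (os.path.exists(ENGLISH_CORPUS) and os.path.exists(FRENCH_CORPUS)):
        print('Downloading %s...' % url)
        archive, _ = urlretrieve(url, 'fr-en.tgz')
        with tarfile.open(archive, 'r:gz') as tar:
            tar.extractall()  # assumes the corpus files sit at the archive root
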
def read_iterator(filename, reporting=True):
    """Yield stripped lines from `filename`, reporting progress while reading."""
    with open(filename) as f:
        for line_count, line in enumerate(f, 1):
            if reporting and line_count % 100000 == 0:
                print("%d lines read from %s" % (line_count, filename))
            yield line.strip()


def repeated_read_iterator(filename):
    """Cycle over the lines of `filename` indefinitely, for multi-epoch training."""
    while True:
        with open(filename) as f:
            for line in f:
                yield line.strip()


def split_train_test(data, partition=0.2, random_seed=42):
    rnd = np.random.RandomState(random_seed)
    for item in data:
        if rnd.uniform() > partition:
            yield (0, item)
        else:
            yield (1, item)


def save_partitions(data, filenames):
    """Write (partition, line) pairs to one file per partition."""
    files = [open(filename, 'w') for filename in filenames]
    for partition, item in data:
        files[partition].write(item + '\n')
    # Close explicitly so the last lines are flushed to disk.
    for f in files:
        f.close()


def loop_iterator(data):
    while True:
        for item in data:
            yield item


if not (os.path.exists('train.data') and os.path.exists('test.data')):
    english_data = read_iterator(ENGLISH_CORPUS)
    french_data = read_iterator(FRENCH_CORPUS)
    parallel_data = ('%s;;;%s' % (eng, fr)
                     for eng, fr in zip(english_data, french_data))
    save_partitions(split_train_test(parallel_data), ['train.data', 'test.data'])

def Xy(data):
    """Split ';;;'-joined parallel lines into (English, French) iterators."""
    def split_lines(data):
        for item in data:
            yield item.split(';;;')
    X, y = itertools.tee(split_lines(data))
    return (item[0] for item in X), (item[1] for item in y)

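# Illustrative round trip for Xy on a toy pair (not part of the pipeline):
#   eng, fr = Xy(iter(['good morning;;;bonjour']))
#   next(eng)  # -> 'good morning'
#   next(fr)   # -> 'bonjour'
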
X_train, y_train = Xy(repeated_read_iterator('train.data'))
X_test, y_test = Xy(read_iterator('test.data'))

# Preprocessing

MAX_DOCUMENT_LENGTH = 10

X_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
                                                             min_frequency=5)
y_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
                                                             min_frequency=5)
# Fit the vocabularies on a fresh, single pass over the training file.
Xtrainff, ytrainff = Xy(read_iterator('train.data'))
print('Fitting dictionary for English...')
X_vocab_processor.fit(Xtrainff)
print('Fitting dictionary for French...')
y_vocab_processor.fit(ytrainff)
print('Transforming...')
X_train = X_vocab_processor.transform(X_train)
y_train = y_vocab_processor.transform(y_train)
# Keep a small, materialized held-out set for qualitative checks.
X_test = np.array(list(X_vocab_processor.transform(X_test))[:20])
y_test = list(y_test)[:20]

n_words = len(X_vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
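# What the processors do, on a toy input (illustrative only; the actual ids
# depend on the fitted vocabulary, and words under min_frequency map to the
# unknown id):
#   ids = next(X_vocab_processor.transform(['the quick brown fox']))
#   # `ids` is a length-MAX_DOCUMENT_LENGTH vector of word ids, zero-padded.
#   text = next(X_vocab_processor.reverse([ids]))  # ids back to (cleaned) text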

# Translation model

HIDDEN_SIZE = 20
EMBEDDING_SIZE = 20

def translate_model(X, y):
    # Embed source word ids into trainable EMBEDDING_SIZE-dimensional vectors.
    word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
        embedding_size=EMBEDDING_SIZE, name='words')
    # Arrange encoder inputs, decoder inputs and decoder targets.
    in_X, in_y, out_y = skflow.ops.seq2seq_inputs(
        word_vectors, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
    # GRU cell whose outputs are projected onto the target vocabulary.
    cell = tf.nn.rnn_cell.OutputProjectionWrapper(
        tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), n_words)
    decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, cell)
    return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)


PATH = '/tmp/tf_examples/ntm_words/'

if os.path.exists(PATH):
    translator = skflow.TensorFlowEstimator.restore(PATH)
else:
    translator = skflow.TensorFlowEstimator(model_fn=translate_model,
                                            n_classes=n_words,
                                            optimizer='Adam',
                                            learning_rate=0.01,
                                            batch_size=128,
                                            continue_training=True)

while True:
    translator.fit(X_train, y_train, logdir=PATH)
    translator.save(PATH)

    # Qualitative check on the 20 held-out pairs.
    xpred, ygold = X_test, y_test
    predictions = translator.predict(xpred, axis=2)
    xpred_inp = X_vocab_processor.reverse(xpred)
    text_outputs = y_vocab_processor.reverse(predictions)
    for inp_data, input_text, pred, output_text, gold in zip(
            xpred, xpred_inp, predictions, text_outputs, ygold):
        print('English: %s. French (pred): %s, French (gold): %s' %
              (input_text, output_text, gold.decode('utf-8')))
        print(inp_data, pred)
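    # A rough quantitative signal could be added here as well, e.g. per-token
    # accuracy against the transformed gold sentences (a sketch; it counts
    # padding positions too, so treat it only as a coarse indicator):
    #   ygold_ids = np.array(list(y_vocab_processor.transform(ygold)))
    #   print('Token accuracy: %.3f' % np.mean(predictions == ygold_ids))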