1616
1717from __future__ import division , print_function , absolute_import
1818
19+ import itertools
1920import os
2021import numpy as np
2122
2223import tensorflow as tf
23- from tensorflow .python .ops import rnn_cell , rnn , seq2seq
2424
2525import skflow
2626
2727# Get training data
2828
2929# This dataset can be downloaded from http://www.statmt.org/europarl/v6/fr-en.tgz
3030
31- def X_iter ():
31+ ENGLISH_CORPUS = "europarl-v6.fr-en.en"
32+ FRENCH_CORPUS = "europarl-v6.fr-en.fr"
33+
34+ def read_iterator (filename ):
35+ f = open (filename )
36+ for line in f :
37+ yield line .strip ()
38+
39+
40+ def repeated_read_iterator (filename ):
3241 while True :
33- yield "some sentence"
34- yield "some other sentence"
42+ f = open (filename )
43+ for line in f :
44+ yield line .strip ()
45+
46+
47+ def split_train_test (data , partition = 0.2 , random_seed = 42 ):
48+ rnd = np .random .RandomState (random_seed )
49+ for item in data :
50+ if rnd .uniform () > partition :
51+ yield (0 , item )
52+ else :
53+ yield (1 , item )
54+
55+
56+ def save_partitions (data , filenames ):
57+ files = [open (filename , 'w' ) for filename in filenames ]
58+ for partition , item in data :
59+ files [partition ].write (item + '\n ' )
3560
36- X_pred = ["some sentence" , "some other sentence" ]
3761
38- def y_iter ( ):
62+ def loop_iterator ( data ):
3963 while True :
40- yield "какое-то приложение"
41- yield "какое-то другое приложение"
64+ for item in data :
65+ yield item
4266
43- # Translation model
4467
45- MAX_DOCUMENT_LENGTH = 10
46- HIDDEN_SIZE = 10
47-
48- def rnn_decoder (decoder_inputs , initial_state , cell , scope = None ):
49- with tf .variable_scope (scope or "dnn_decoder" ):
50- states , sampling_states = [initial_state ], [initial_state ]
51- outputs , sampling_outputs = [], []
52- with tf .op_scope ([decoder_inputs , initial_state ], "training" ):
53- for i in xrange (len (decoder_inputs )):
54- inp = decoder_inputs [i ]
55- if i > 0 :
56- tf .get_variable_scope ().reuse_variables ()
57- output , new_state = cell (inp , states [- 1 ])
58- outputs .append (output )
59- states .append (new_state )
60- with tf .op_scope ([initial_state ], "sampling" ):
61- for i in xrange (len (decoder_inputs )):
62- if i == 0 :
63- sampling_outputs .append (outputs [i ])
64- sampling_states .append (states [i ])
65- else :
66- sampling_output , sampling_state = cell (sampling_outputs [- 1 ], sampling_states [- 1 ])
67- sampling_outputs .append (sampling_output )
68- sampling_states .append (sampling_state )
69- return outputs , states , sampling_outputs , sampling_states
70-
71-
72- def rnn_seq2seq (encoder_inputs , decoder_inputs , cell , dtype = tf .float32 , scope = None ):
73- with tf .variable_scope (scope or "rnn_seq2seq" ):
74- _ , enc_states = rnn .rnn (cell , encoder_inputs , dtype = dtype )
75- return rnn_decoder (decoder_inputs , enc_states [- 1 ], cell )
68+ if not (os .path .exists ('train.data' ) and os .path .exists ('test.data' )):
69+ english_data = read_iterator (ENGLISH_CORPUS )
70+ french_data = read_iterator (FRENCH_CORPUS )
71+ parallel_data = ('%s;;;%s' % (eng , fr ) for eng , fr in itertools .izip (english_data , french_data ))
72+ save_partitions (split_train_test (parallel_data ), ['train.data' , 'test.data' ])
73+
74+ def Xy (data ):
75+ def split_lines (data ):
76+ for item in data :
77+ yield item .split (';;;' )
78+ X , y = itertools .tee (split_lines (data ))
79+ return (item [0 ] for item in X ), (item [1 ] for item in y )
80+
81+ X_train , y_train = Xy (repeated_read_iterator ('train.data' ))
82+ X_test , y_test = Xy (read_iterator ('test.data' ))
83+
84+
85+ # Translation model
7686
87+ MAX_DOCUMENT_LENGTH = 30
88+ HIDDEN_SIZE = 100
7789
7890def translate_model (X , y ):
7991 byte_list = skflow .ops .one_hot_matrix (X , 256 )
8092 in_X , in_y , out_y = skflow .ops .seq2seq_inputs (
8193 byte_list , y , MAX_DOCUMENT_LENGTH , MAX_DOCUMENT_LENGTH )
82- cell = rnn_cell .OutputProjectionWrapper (rnn_cell .GRUCell (HIDDEN_SIZE ), 256 )
83- decoding , _ , sampling_decoding , _ = rnn_seq2seq (in_X , in_y , cell )
94+ cell = tf . nn . rnn_cell .OutputProjectionWrapper (tf . nn . rnn_cell .GRUCell (HIDDEN_SIZE ), 256 )
95+ decoding , _ , sampling_decoding , _ = skflow . ops . rnn_seq2seq (in_X , in_y , cell )
8496 return skflow .ops .sequence_classifier (decoding , out_y , sampling_decoding )
8597
8698
8799vocab_processor = skflow .preprocessing .ByteProcessor (
88100 max_document_length = MAX_DOCUMENT_LENGTH )
89101
90- x_iter = vocab_processor .transform (X_iter ())
91- y_iter = vocab_processor .transform (y_iter ())
92- xpred = np .array (list (vocab_processor .transform (X_pred )))
102+ x_iter = vocab_processor .transform (X_train )
103+ y_iter = vocab_processor .transform (y_train )
104+ xpred = np .array (list (vocab_processor .transform (X_test ))[:20 ])
105+ ygold = list (y_test )[:20 ]
93106
94107PATH = '/tmp/tf_examples/ntm/'
95108
96109if os .path .exists (PATH ):
97110 translator = skflow .TensorFlowEstimator .restore (PATH )
98111else :
99112 translator = skflow .TensorFlowEstimator (model_fn = translate_model ,
100- n_classes = 256 , continue_training = True )
113+ n_classes = 256 ,
114+ optimizer = 'Adam' , learning_rate = 0.01 , batch_size = 128 ,
115+ continue_training = True )
101116
102117while True :
103118 translator .fit (x_iter , y_iter , logdir = PATH )
@@ -106,7 +121,9 @@ def translate_model(X, y):
106121 predictions = translator .predict (xpred , axis = 2 )
107122 xpred_inp = vocab_processor .reverse (xpred )
108123 text_outputs = vocab_processor .reverse (predictions )
109- for inp_data , input_text , pred , output_text in zip (xpred , xpred_inp , predictions , text_outputs ):
110- print (input_text , output_text )
124+ for inp_data , input_text , pred , output_text , gold in zip (xpred , xpred_inp ,
125+ predictions , text_outputs , ygold ):
126+ print ('English: %s. French (pred): %s, French (gold): %s' %
127+ (input_text , output_text , gold .decode ('utf-8' )))
111128 print (inp_data , pred )
112129
0 commit comments