@@ -230,10 +230,13 @@ def inference(x, is_train, sequence_length, reuse=None):
     rnn_init = tf.random_uniform_initializer(-init_scale, init_scale)
     with tf.variable_scope("model", reuse=reuse):
         network = EmbeddingInputlayer(x, vocab_size, hidden_size, rnn_init, name='embedding')
-        network = RNNLayer(network, cell_fn=tf.contrib.rnn.BasicLSTMCell, \
-                           cell_init_args={'forget_bias': 0.0, 'state_is_tuple': True}, \
-                           n_hidden=hidden_size, initializer=rnn_init, n_steps=sequence_length, return_last=False,
-                           return_seq_2d=True, name='lstm1')
+        network = RNNLayer(
+            network, cell_fn=tf.contrib.rnn.BasicLSTMCell, cell_init_args={
+                'forget_bias': 0.0,
+                'state_is_tuple': True
+            }, n_hidden=hidden_size, initializer=rnn_init, n_steps=sequence_length, return_last=False,
+            return_seq_2d=True, name='lstm1'
+        )
         lstm1 = network
         network = DenseLayer(network, vocab_size, W_init=rnn_init, b_init=rnn_init, act=tf.identity, name='output')
     return network, lstm1
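This hunk only re-wraps the RNNLayer call; behaviour is unchanged. For context, a rough usage sketch (not part of this diff; placeholder and hyperparameter names are taken from the later hunks or assumed) of how inference() is usually called twice so the training graph and the single-step generation graph share weights through the reuse flag:

# Sketch only: batch_size and sequence_length are assumed hyperparameters.
input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
targets = tf.placeholder(tf.int32, [batch_size, sequence_length])
input_data_test = tf.placeholder(tf.int32, [1, 1])

# Build the training graph, then rebuild it with reuse=True so the generation
# graph reuses the same embedding/LSTM/output variables.
network, lstm1 = inference(input_data, is_train=True, sequence_length=sequence_length)
network_test, lstm1_test = inference(input_data_test, is_train=False, sequence_length=1, reuse=True)
# Presumably how y_soft in the last hunk is produced (assumption, not shown in the diff).
y_soft = tf.nn.softmax(network_test.outputs)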
@@ -297,14 +300,21 @@ def loss_fn(outputs, targets, batch_size, sequence_length):
         ## reset all states at the beginning of every epoch
         state1 = tl.layers.initialize_rnn_state(lstm1.initial_state)
         for step, (x, y) in enumerate(tl.iterate.ptb_iterator(train_data, batch_size, sequence_length)):
-            _cost, state1, _ = sess.run([cost, lstm1.final_state, train_op], \
-                                        feed_dict={input_data: x, targets: y, lstm1.initial_state: state1})
+            _cost, state1, _ = sess.run(
+                [cost, lstm1.final_state, train_op], feed_dict={
+                    input_data: x,
+                    targets: y,
+                    lstm1.initial_state: state1
+                }
+            )
             costs += _cost
             iters += sequence_length

             if step % (epoch_size // 10) == 1:
-                print("%.3f perplexity: %.3f speed: %.0f wps" % \
-                      (step * 1.0 / epoch_size, np.exp(costs / iters), iters * batch_size / (time.time() - start_time)))
+                print(
+                    "%.3f perplexity: %.3f speed: %.0f wps" %
+                    (step * 1.0 / epoch_size, np.exp(costs / iters), iters * batch_size / (time.time() - start_time))
+                )
         train_perplexity = np.exp(costs / iters)
         # print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
         print("Epoch: %d/%d Train Perplexity: %.3f" % (i + 1, max_max_epoch, train_perplexity))
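The loop above accumulates the per-batch cross-entropy in costs and the number of words processed in iters, so np.exp(costs / iters) is the per-word training perplexity. The body of loss_fn is not shown in this diff; a plausible definition consistent with the hunk header, sketched from the standard TF 1.x PTB recipe rather than taken from this file, would be:

def loss_fn(outputs, targets, batch_size, sequence_length):
    # Cross-entropy per time step; outputs has shape [batch * steps, vocab_size]
    # because the RNN layer was built with return_seq_2d=True.
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [outputs],                                   # logits
        [tf.reshape(targets, [-1])],                 # flattened target word ids
        [tf.ones([batch_size * sequence_length])])   # equal weight for every step
    # Sum over time steps, average over the batch.
    return tf.reduce_sum(loss) / batch_size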
@@ -319,14 +329,22 @@ def loss_fn(outputs, targets, batch_size, sequence_length):
         # feed the seed to initialize the state for generation.
         for ids in outs_id[:-1]:
             a_id = np.asarray(ids).reshape(1, 1)
-            state1 = sess.run([lstm1_test.final_state], \
-                              feed_dict={input_data_test: a_id, lstm1_test.initial_state: state1})
+            state1 = sess.run(
+                [lstm1_test.final_state], feed_dict={
+                    input_data_test: a_id,
+                    lstm1_test.initial_state: state1
+                }
+            )
         # feed the last word in seed, and start to generate sentence.
         a_id = outs_id[-1]
         for _ in range(print_length):
             a_id = np.asarray(a_id).reshape(1, 1)
-            out, state1 = sess.run([y_soft, lstm1_test.final_state], \
-                                   feed_dict={input_data_test: a_id, lstm1_test.initial_state: state1})
+            out, state1 = sess.run(
+                [y_soft, lstm1_test.final_state], feed_dict={
+                    input_data_test: a_id,
+                    lstm1_test.initial_state: state1
+                }
+            )
             ## Without sampling
             # a_id = np.argmax(out[0])
             ## Sample from all words, if vocab_size is large,
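The commented-out lines above contrast greedy decoding (np.argmax) with sampling from the softmax output. As a rough illustration, not taken from this file, sampling the next word id from the full vocabulary could look like:

# Hypothetical sampling step; out[0] is the softmax distribution over the vocabulary.
probs = out[0] / np.sum(out[0])            # renormalise to guard against rounding error
a_id = np.random.choice(vocab_size, p=probs)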