
Commit 96a08b6

Merge pull request #127 from dakuang/tutorial-fix
added batch_size parameter in loss_fn
2 parents: 586f8fc + 2a09af2

File tree: 1 file changed (+6, -6 lines)


example/tutorial_ptb_lstm_state_is_tuple.py

Lines changed: 6 additions & 6 deletions
@@ -180,7 +180,7 @@ def main(_):
     # same with MNIST example, it is the number of concurrent processes for
     # computational reasons.
 
-    # Training and Validing
+    # Training and Validation
     input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
     targets = tf.placeholder(tf.int32, [batch_size, num_steps])
     # Testing (Evaluation)
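The placeholders above fix the graph's batch geometry for training and validation. For contrast, a minimal sketch of the matching test-time placeholders; the [1, 1] shape is inferred from the cost_test call further down, so treat the exact names and values here as assumptions rather than the verbatim file:

import tensorflow as tf  # TF 1.x API, as used throughout this tutorial

batch_size, num_steps = 20, 20  # example values; the tutorial sets its own

# Training and Validation share one pair of placeholders.
input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
targets = tf.placeholder(tf.int32, [batch_size, num_steps])

# Testing (Evaluation) feeds one sequence, one step at a time, which is
# why loss_fn receives a batch_size of 1 for cost_test below (assumed
# shape, inferred from that call).
input_data_test = tf.placeholder(tf.int32, [1, 1])
targets_test = tf.placeholder(tf.int32, [1, 1])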
@@ -251,7 +251,7 @@ def inference(x, is_training, num_steps, reuse=None):
     # sess.run(tf.initialize_all_variables())
     tl.layers.initialize_global_variables(sess)
 
-    def loss_fn(outputs, targets):#, batch_size, num_steps):
+    def loss_fn(outputs, targets, batch_size):
         # See tl.cost.cross_entropy_seq()
         # Returns the cost function of Cross-entropy of two sequences, implement
         # softmax internally.
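The hunk only shows the signature change. A hedged sketch of what the updated loss_fn plausibly computes, following the tl.cost.cross_entropy_seq() behaviour the comments reference (TF 1.x legacy seq2seq loss; this is a reconstruction, not the committed body):

import tensorflow as tf

def loss_fn(outputs, targets, batch_size):
    # outputs: 2D tensor [batch_size * num_steps, vocab_size] (logits)
    # targets: 2D tensor [batch_size, num_steps], flattened for the loss
    targets_flat = tf.reshape(targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [outputs],
        [targets_flat],
        [tf.ones_like(targets_flat, dtype=tf.float32)])
    # Dividing by the new batch_size argument (rather than a hard-coded
    # constant) is the point of this commit: the same function now serves
    # training/validation (batch_size) and testing (1).
    return tf.reduce_sum(loss) / batch_size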
@@ -270,11 +270,11 @@ def loss_fn(outputs, targets):#, batch_size, num_steps):
         return cost
 
     # Cost for Training
-    cost = loss_fn(network.outputs, targets)#, batch_size, num_steps)
+    cost = loss_fn(network.outputs, targets, batch_size)
     # Cost for Validating
-    cost_val = loss_fn(network_val.outputs, targets)#, batch_size, num_steps)
+    cost_val = loss_fn(network_val.outputs, targets, batch_size)
     # Cost for Testing (Evaluation)
-    cost_test = loss_fn(network_test.outputs, targets_test)#, 1, 1)
+    cost_test = loss_fn(network_test.outputs, targets_test, 1)
 
     # Truncated Backpropagation for training
     with tf.variable_scope('learning_rate'):
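With batch_size passed explicitly, the three cost tensors differ only in their inputs. For context, a hedged sketch of the truncated-BPTT training op that the last two context lines open into; lr, max_grad_norm and the dummy cost are assumptions in the style of the PTB tutorials, not quoted from the file:

import tensorflow as tf

# Stand-in for the training cost built above; any scalar tensor works here.
x = tf.Variable([1.0, 2.0])
cost = tf.reduce_sum(tf.square(x))

with tf.variable_scope('learning_rate'):
    lr = tf.Variable(0.0, trainable=False)  # assigned per epoch during training

max_grad_norm = 5  # assumed clipping threshold
tvars = tf.trainable_variables()
# Clip the global gradient norm, then apply plain SGD at the current lr.
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), max_grad_norm)
train_op = tf.train.GradientDescentOptimizer(lr).apply_gradients(zip(grads, tvars))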
@@ -339,7 +339,7 @@ def loss_fn(outputs, targets):#, batch_size, num_steps):
         print("Epoch: %d/%d Train Perplexity: %.3f" % (i + 1, max_max_epoch,
                                                         train_perplexity))
 
-        # Validing
+        # Validation
         start_time = time.time()
         costs = 0.0; iters = 0
         # reset all states at the begining of every epoch
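The costs and iters accumulators feed the perplexity figures printed above. A self-contained illustration of that arithmetic, assuming perplexity is computed as np.exp(costs / iters) as is standard in these PTB examples (the numbers below are made up):

import numpy as np

costs = 520.0  # example: cross-entropy summed over the epoch
iters = 100    # example: number of time steps processed
# Perplexity is the exponential of the average per-step cross-entropy.
print("Perplexity: %.3f" % np.exp(costs / iters))  # -> 181.272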
