diff --git a/notebooks/logistic_regression_mnist.ipynb b/notebooks/logistic_regression_mnist.ipynb index 9a326c1..be5b16f 100755 --- a/notebooks/logistic_regression_mnist.ipynb +++ b/notebooks/logistic_regression_mnist.ipynb @@ -155,19 +155,20 @@ "# Launch the graph\n", "with tf.Session() as sess:\n", " sess.run(init)\n", + " n_train = trainimg.shape[0]\n", "\n", " # Training cycle\n", " for epoch in range(training_epochs):\n", " avg_cost = 0.\n", " num_batch = int(mnist.train.num_examples/batch_size)\n", + " randidx = np.random.permutation(n_train)\n", " # Loop over all batches\n", " for i in range(num_batch): \n", " if 0: # Using tensorflow API\n", " batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n", " else: # Random batch sampling \n", - " randidx = np.random.randint(trainimg.shape[0], size=batch_size)\n", - " batch_xs = trainimg[randidx, :]\n", - " batch_ys = trainlabel[randidx, :] \n", + " batch_xs = trainimg[randidx[i*batch_size:(i+1)*batch_size], :]\n", + " batch_ys = trainlabel[randidx[i*batch_size:(i+1)*batch_size], :]\n", " \n", " # Fit training using batch data\n", " sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})\n",