Skip to content

Commit 7ef2bde

Browse files
authored
Merge pull request #16 from pskrunner14/master
Input dimensions bug and variable sample size error
2 parents 8296290 + 60fdf7a commit 7ef2bde

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

main.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
FLAGS = flags.FLAGS
4848

4949
def main(_):
50+
assert np.sqrt(FLAGS.sample_size) % 1 == 0., 'Flag `sample_size` needs to be a perfect square'
51+
num_tiles = int(np.sqrt(FLAGS.sample_size))
52+
5053
# Print flags
5154
for flag, _ in FLAGS.__flags.items():
5255
print('"{}": {}'.format(flag, getattr(FLAGS, flag)))
@@ -62,8 +65,8 @@ def main(_):
6265
with tf.device("/gpu:0"):
6366

6467
""" Define Models """
65-
z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
66-
real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')
68+
z = tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
69+
real_images = tf.placeholder(tf.float32, [None, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')
6770

6871
# Input noise into generator for training
6972
net_g = generator(z, is_train=True, reuse=False)
@@ -77,12 +80,11 @@ def main(_):
7780
net_g2 = generator(z, is_train=False, reuse=True)
7881

7982
""" Define Training Operations """
80-
# cost for updating discriminator and generator
8183
# discriminator: real images are labelled as 1
8284
d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
83-
8485
# discriminator: images from generator (fake) are labelled as 0
8586
d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
87+
# cost for updating discriminator
8688
d_loss = d_loss_real + d_loss_fake
8789

8890
# generator: try to make the the fake images look real (1)
@@ -112,17 +114,16 @@ def main(_):
112114

113115
data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg")))
114116
num_files = len(data_files)
115-
shuffle = True
116117

117118
# Mini-batch generator
118-
def iterate_minibatches():
119+
def iterate_minibatches(batch_size, shuffle=True):
119120
if shuffle:
120121
indices = np.random.permutation(num_files)
121-
for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size):
122+
for start_idx in range(0, num_files - batch_size + 1, batch_size):
122123
if shuffle:
123-
excerpt = indices[start_idx: start_idx + FLAGS.batch_size]
124+
excerpt = indices[start_idx: start_idx + batch_size]
124125
else:
125-
excerpt = slice(start_idx, start_idx + FLAGS.batch_size)
126+
excerpt = slice(start_idx, start_idx + batch_size)
126127
# Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
127128
yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0)
128129
for file in data_files[excerpt]]).astype(np.float32)
@@ -136,13 +137,13 @@ def iterate_minibatches():
136137
iter_counter = 0
137138
for epoch in range(FLAGS.epoch):
138139

139-
sample_images = next(iterate_minibatches())
140+
sample_images = next(iterate_minibatches(FLAGS.sample_size))
140141
print("[*] Sample images updated!")
141142

142143
steps = 0
143-
for batch_images in iterate_minibatches():
144+
for batch_images in iterate_minibatches(FLAGS.batch_size):
144145

145-
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
146+
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
146147
start_time = time.time()
147148

148149
# Updates the Discriminator(D)
@@ -162,7 +163,7 @@ def iterate_minibatches():
162163
# Generate images
163164
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
164165
# Visualize generated images
165-
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
166+
tl.visualize.save_images(img, [num_tiles, num_tiles], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
166167
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
167168

168169
if np.mod(iter_counter, FLAGS.save_step) == 0:
@@ -171,10 +172,13 @@ def iterate_minibatches():
171172
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
172173
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
173174
print("[*] Saving checkpoints SUCCESS!")
174-
175+
175176
steps += 1
176-
177+
177178
sess.close()
178179

179180
if __name__ == '__main__':
180-
tf.app.run()
181+
try:
182+
tf.app.run()
183+
except KeyboardInterrupt:
184+
print('EXIT')

0 commit comments

Comments
 (0)