Skip to content

Commit 8296290

Browse files
pskrunner14 authored and zsdonghao committed
Minor fixes and improvements (#15)
* TL compatibility and code readability
* fix versions in README
* changed w_init to glorot uni and added minibatch gen
* resolve conflict
1 parent d22f5d7 commit 8296290

File tree

3 files changed

+57
-51
lines changed

3 files changed

+57
-51
lines changed

main.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
import tensorlayer as tl
2222

2323
from glob import glob
24-
from random import shuffle
2524

26-
from model import generator_simplified_api, discriminator_simplified_api
2725
from utils import get_image
26+
from model import generator, discriminator
2827

29-
# Defile TF Flags
28+
# Define TF Flags
3029
flags = tf.app.flags
3130
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
3231
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
@@ -67,15 +66,15 @@ def main(_):
6766
real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')
6867

6968
# Input noise into generator for training
70-
net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)
69+
net_g = generator(z, is_train=True, reuse=False)
7170

7271
# Input real and generated fake images into discriminator for training
73-
net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
74-
net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)
72+
net_d, d_logits = discriminator(net_g.outputs, is_train=True, reuse=False)
73+
_, d2_logits = discriminator(real_images, is_train=True, reuse=True)
7574

7675
# Input noise into generator for evaluation
7776
# set is_train to False so that BatchNormLayer behave differently
78-
net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)
77+
net_g2 = generator(z, is_train=False, reuse=True)
7978

8079
""" Define Training Operations """
8180
# cost for updating discriminator and generator
@@ -111,51 +110,59 @@ def main(_):
111110
net_g_name = os.path.join(save_dir, 'net_g.npz')
112111
net_d_name = os.path.join(save_dir, 'net_d.npz')
113112

114-
data_files = glob(os.path.join("./data", FLAGS.dataset, "*.jpg"))
113+
data_files = np.array(glob(os.path.join("./data", FLAGS.dataset, "*.jpg")))
114+
num_files = len(data_files)
115+
shuffle = True
116+
117+
# Mini-batch generator
118+
def iterate_minibatches():
119+
if shuffle:
120+
indices = np.random.permutation(num_files)
121+
for start_idx in range(0, num_files - FLAGS.batch_size + 1, FLAGS.batch_size):
122+
if shuffle:
123+
excerpt = indices[start_idx: start_idx + FLAGS.batch_size]
124+
else:
125+
excerpt = slice(start_idx, start_idx + FLAGS.batch_size)
126+
# Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
127+
yield np.array([get_image(file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0)
128+
for file in data_files[excerpt]]).astype(np.float32)
129+
130+
batch_steps = min(num_files, FLAGS.train_size) // FLAGS.batch_size
115131

116-
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)# sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
132+
# sample noise
133+
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
117134

118135
""" Training models """
119136
iter_counter = 0
120137
for epoch in range(FLAGS.epoch):
121138

122-
# Shuffle data
123-
shuffle(data_files)
124-
125-
# Update sample files based on shuffled data
126-
sample_files = data_files[0:FLAGS.sample_size]
127-
sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for sample_file in sample_files]
128-
sample_images = np.array(sample).astype(np.float32)
139+
sample_images = next(iterate_minibatches())
129140
print("[*] Sample images updated!")
141+
142+
steps = 0
143+
for batch_images in iterate_minibatches():
130144

131-
# Load image data
132-
batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size
133-
134-
for idx in range(0, batch_idxs):
135-
batch_files = data_files[idx*FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]
136-
137-
# Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
138-
batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in batch_files]
139-
batch_images = np.array(batch).astype(np.float32)
140145
batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
141146
start_time = time.time()
142147

143148
# Updates the Discriminator(D)
144-
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images })
149+
errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})
145150

146151
# Updates the Generator(G)
147152
# run generator twice to make sure that d_loss does not go to zero (different from paper)
148153
for _ in range(2):
149154
errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
155+
156+
end_time = time.time() - start_time
150157
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
151-
% (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))
158+
% (epoch, FLAGS.epoch, steps, batch_steps, end_time, errD, errG))
152159

153160
iter_counter += 1
154161
if np.mod(iter_counter, FLAGS.sample_step) == 0:
155162
# Generate images
156-
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
163+
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
157164
# Visualize generated images
158-
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
165+
tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, steps))
159166
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
160167

161168
if np.mod(iter_counter, FLAGS.save_step) == 0:
@@ -164,6 +171,8 @@ def main(_):
164171
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
165172
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
166173
print("[*] Saving checkpoints SUCCESS!")
174+
175+
steps += 1
167176

168177
sess.close()
169178

model.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,29 @@
1313
flags = tf.app.flags
1414
FLAGS = flags.FLAGS
1515

16-
def generator_simplified_api(inputs, is_train=True, reuse=False):
16+
def generator(inputs, is_train=True, reuse=False):
1717
image_size = 64
18-
s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
19-
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
20-
c_dim = FLAGS.c_dim # n_color 3
21-
w_init = tf.random_normal_initializer(stddev=0.02)
18+
s16 = image_size // 16
19+
gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
20+
c_dim = FLAGS.c_dim # n_color 3
21+
w_init = tf.glorot_normal_initializer()
2222
gamma_init = tf.random_normal_initializer(1., 0.02)
2323

2424
with tf.variable_scope("generator", reuse=reuse):
2525

2626
net_in = InputLayer(inputs, name='g/in')
27-
net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
27+
net_h0 = DenseLayer(net_in, n_units=(gf_dim * 8 * s16 * s16), W_init=w_init,
2828
act = tf.identity, name='g/h0/lin')
2929
net_h0 = ReshapeLayer(net_h0, shape=[-1, s16, s16, gf_dim*8], name='g/h0/reshape')
3030
net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
3131
gamma_init=gamma_init, name='g/h0/batch_norm')
3232

33-
net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), strides=(2, 2),
33+
net_h1 = DeConv2d(net_h0, gf_dim * 4, (5, 5), strides=(2, 2),
3434
padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
3535
net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
3636
gamma_init=gamma_init, name='g/h1/batch_norm')
3737

38-
net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), strides=(2, 2),
38+
net_h2 = DeConv2d(net_h1, gf_dim * 2, (5, 5), strides=(2, 2),
3939
padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
4040
net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
4141
gamma_init=gamma_init, name='g/h2/batch_norm')
@@ -47,14 +47,12 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
4747

4848
net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2),
4949
padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
50-
logits = net_h4.outputs
5150
net_h4.outputs = tf.nn.tanh(net_h4.outputs)
52-
return net_h4, logits
51+
return net_h4
5352

54-
def discriminator_simplified_api(inputs, is_train=True, reuse=False):
55-
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
56-
c_dim = FLAGS.c_dim # n_color 3
57-
w_init = tf.random_normal_initializer(stddev=0.02)
53+
def discriminator(inputs, is_train=True, reuse=False):
54+
df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
55+
w_init = tf.glorot_normal_initializer()
5856
gamma_init = tf.random_normal_initializer(1., 0.02)
5957

6058
with tf.variable_scope("discriminator", reuse=reuse):

utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from random import shuffle
2-
31
import scipy.misc
2+
import imageio as io
43
import numpy as np
54

65
def center_crop(x, crop_h, crop_w=None, resize_w=64):
@@ -18,27 +17,27 @@ def merge(images, size):
1817
for idx, image in enumerate(images):
1918
i = idx % size[1]
2019
j = idx // size[1]
21-
img[j*h:j*h+h, i*w:i*w+w, :] = image
20+
img[j * h: j * h + h, i * w: i * w + w, :] = image
2221
return img
2322

2423
def transform(image, npx=64, is_crop=True, resize_w=64):
2524
if is_crop:
2625
cropped_image = center_crop(image, npx, resize_w=resize_w)
2726
else:
2827
cropped_image = image
29-
return np.array(cropped_image)/127.5 - 1.
28+
return (np.array(cropped_image) / 127.5) - 1.
3029

3130
def inverse_transform(images):
32-
return (images+1.)/2.
31+
return (images + 1.) / 2.
3332

3433
def imread(path, is_grayscale = False):
3534
if (is_grayscale):
36-
return scipy.misc.imread(path, flatten = True).astype(np.float)
35+
return io.imread(path).astype(np.float).flatten()
3736
else:
38-
return scipy.misc.imread(path).astype(np.float)
37+
return io.imread(path).astype(np.float)
3938

4039
def imsave(images, size, path):
41-
return scipy.misc.imsave(path, merge(images, size))
40+
return io.imsave(path, merge(images, size))
4241

4342
def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False):
4443
return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w)

0 commit comments

Comments
 (0)