
Commit d22f5d7

pskrunner14 authored and zsdonghao committed

TL compatibility and code readability (#14)

* TL compatibility and code readability
* fix versions in README
1 parent 14b6fd6 commit d22f5d7

114 files changed: +109 -17827 lines changed


.gitignore

Lines changed: 4 additions & 0 deletions
@@ -10,3 +10,7 @@ dist
 docs/_build
 tensorlayer.egg-info
 tensorlayer/__pacache__
+
+.vscode/*
+data/*
+samples/*

README.md

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@ Looking for Text to Image Synthesis ? [click here](https://github.com/zsdonghao/
 ## Prerequisites

 - Python 2.7 or Python 3.3+
-- [TensorFlow==1.0+](https://www.tensorflow.org/)
-- [TensorLayer==1.4+](https://github.com/zsdonghao/tensorlayer)
+- [TensorFlow==1.10.0+](https://www.tensorflow.org/)
+- [TensorLayer==1.10.1+](https://github.com/tensorlayer/tensorlayer)


 ## Usage
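As context for the version bump: the TensorLayer 1.10 API used in the diffs below (for example, DeConv2d without out_size/batch_size) is not available in TensorLayer 1.4. A minimal sketch of a runtime guard for the new minimum versions, assuming only that both packages expose __version__ (this snippet is illustrative and not part of the commit):

# Illustrative version guard for the bumped prerequisites (not part of the commit).
from distutils.version import LooseVersion

import tensorflow as tf
import tensorlayer as tl

assert LooseVersion(tf.__version__) >= LooseVersion("1.10.0"), "TensorFlow 1.10.0+ required"
assert LooseVersion(tl.__version__) >= LooseVersion("1.10.1"), "TensorLayer 1.10.1+ required"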

main.py

Lines changed: 65 additions & 38 deletions
@@ -1,25 +1,37 @@
-import os, pprint, time
+""" TensorLayer implementation of Deep Convolutional Generative Adversarial Network (DCGAN).
+Using deep convolutional generative adversarial networks (DCGAN)
+to generate face images from a noise distribution.
+References:
+- Generative Adversarial Nets.
+  Goodfellow et al. arXiv: 1406.2661.
+- Unsupervised Representation Learning with Deep Convolutional
+  Generative Adversarial Networks. A Radford, L Metz, S Chintala.
+  arXiv: 1511.06434.
+Links:
+- [GAN Paper](https://arxiv.org/pdf/1406.2661.pdf)
+- [DCGAN Paper](https://arxiv.org/abs/1511.06434)
+Usage:
+- See README.md
+"""
+import os
+import time
+
 import numpy as np
 import tensorflow as tf
 import tensorlayer as tl
-from tensorlayer.layers import *
+
 from glob import glob
 from random import shuffle
-from model import *
-from utils import *
-
-pp = pprint.PrettyPrinter()

-"""
-TensorLayer implementation of DCGAN to generate face image.
+from model import generator_simplified_api, discriminator_simplified_api
+from utils import get_image

-Usage : see README.md
-"""
+# Define TF Flags
 flags = tf.app.flags
 flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
 flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
 flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
-flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
+flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
 flags.DEFINE_integer("batch_size", 64, "The number of batch images [64]")
 flags.DEFINE_integer("image_size", 108, "The size of image to use (will be center cropped) [108]")
 flags.DEFINE_integer("output_size", 64, "The size of the output images to produce [64]")
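One substantive fix in this hunk: train_size moves from DEFINE_integer to DEFINE_float because its default, np.inf, is a Python float, which an integer flag cannot hold. A standalone sketch of why the float flag still works where train_size is consumed (illustrative, not part of the commit):

# np.inf is a float, so only DEFINE_float can carry it as a default.
import numpy as np

print(type(np.inf))       # <class 'float'>
# main.py computes: min(len(data_files), FLAGS.train_size) // FLAGS.batch_size
print(min(1000, np.inf))  # 1000 -- a float inf is a safe upper bound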
@@ -36,57 +48,65 @@
 FLAGS = flags.FLAGS

 def main(_):
-    pp.pprint(flags.FLAGS.__flags)
+    # Print flags
+    for flag, _ in FLAGS.__flags.items():
+        print('"{}": {}'.format(flag, getattr(FLAGS, flag)))
+    print("--------------------")

+    # Configure checkpoint/samples dir
     tl.files.exists_or_mkdir(FLAGS.checkpoint_dir)
     tl.files.exists_or_mkdir(FLAGS.sample_dir)

-    z_dim = 100
+    z_dim = 100 # noise dim
+
+    # Construct graph on GPU
     with tf.device("/gpu:0"):
-        ##========================= DEFINE MODEL ===========================##
+
+        """ Define Models """
         z = tf.placeholder(tf.float32, [FLAGS.batch_size, z_dim], name='z_noise')
         real_images = tf.placeholder(tf.float32, [FLAGS.batch_size, FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim], name='real_images')

-        # z --> generator for training
+        # Input noise into generator for training
         net_g, g_logits = generator_simplified_api(z, is_train=True, reuse=False)
-        # generated fake images --> discriminator
+
+        # Input real and generated fake images into discriminator for training
         net_d, d_logits = discriminator_simplified_api(net_g.outputs, is_train=True, reuse=False)
-        # real images --> discriminator
         net_d2, d2_logits = discriminator_simplified_api(real_images, is_train=True, reuse=True)
-        # sample_z --> generator for evaluation, set is_train to False
-        # so that BatchNormLayer behave differently
+
+        # Input noise into generator for evaluation
+        # set is_train to False so that BatchNormLayer behaves differently
         net_g2, g2_logits = generator_simplified_api(z, is_train=False, reuse=True)

-        ##========================= DEFINE TRAIN OPS =======================##
+        """ Define Training Operations """
         # cost for updating discriminator and generator
         # discriminator: real images are labelled as 1
         d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
+
         # discriminator: images from generator (fake) are labelled as 0
         d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
         d_loss = d_loss_real + d_loss_fake
+
         # generator: try to make the fake images look real (1)
         g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')

         g_vars = tl.layers.get_variables_with_name('generator', True, True)
         d_vars = tl.layers.get_variables_with_name('discriminator', True, True)

-        net_g.print_params(False)
-        print("---------------")
-        net_d.print_params(False)
-
-        # optimizers for updating discriminator and generator
+        # Define optimizers for updating discriminator and generator
         d_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
             .minimize(d_loss, var_list=d_vars)
         g_optim = tf.train.AdamOptimizer(FLAGS.learning_rate, beta1=FLAGS.beta1) \
             .minimize(g_loss, var_list=g_vars)

+    # Init Session
     sess = tf.InteractiveSession()
-    tl.layers.initialize_global_variables(sess)
+    sess.run(tf.global_variables_initializer())

     model_dir = "%s_%s_%s" % (FLAGS.dataset, FLAGS.batch_size, FLAGS.output_size)
     save_dir = os.path.join(FLAGS.checkpoint_dir, model_dir)
     tl.files.exists_or_mkdir(FLAGS.sample_dir)
     tl.files.exists_or_mkdir(save_dir)
+
     # load the latest checkpoints
     net_g_name = os.path.join(save_dir, 'net_g.npz')
     net_d_name = os.path.join(save_dir, 'net_d.npz')
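Two notes on this hunk. First, the flag-printing loop replaces pp.pprint because in newer TF versions FLAGS.__flags maps names to Flag objects rather than plain values, so getattr(FLAGS, flag) is used to read each parsed value. Second, for reference, tl.cost.sigmoid_cross_entropy(logits, labels) is, to the best of my knowledge, a batch mean of TensorFlow's sigmoid cross-entropy on logits, so the objectives above can be sketched in plain TF 1.x as follows (a sketch under that assumption, not the commit's code):

import tensorflow as tf

def gan_losses(d_logits_real, d_logits_fake):
    """Standard GAN losses from raw discriminator logits (TF 1.x sketch)."""
    # Discriminator: real images labelled 1, generated images labelled 0.
    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real)))
    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
    d_loss = d_loss_real + d_loss_fake
    # Generator: push the discriminator to label fakes as real (1).
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)))
    return d_loss, g_loss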
@@ -95,50 +115,57 @@ def main(_):

     sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)  # sample_seed = np.random.uniform(low=-1, high=1, size=(FLAGS.sample_size, z_dim)).astype(np.float32)

-    ##========================= TRAIN MODELS ================================##
+    """ Training models """
     iter_counter = 0
     for epoch in range(FLAGS.epoch):
-        ## shuffle data
+
+        # Shuffle data
         shuffle(data_files)

-        ## update sample files based on shuffled data
+        # Update sample files based on shuffled data
         sample_files = data_files[0:FLAGS.sample_size]
         sample = [get_image(sample_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for sample_file in sample_files]
         sample_images = np.array(sample).astype(np.float32)
         print("[*] Sample images updated!")

-        ## load image data
+        # Load image data
         batch_idxs = min(len(data_files), FLAGS.train_size) // FLAGS.batch_size

         for idx in range(0, batch_idxs):
-            batch_files = data_files[idx*FLAGS.batch_size:(idx+1)*FLAGS.batch_size]
-            ## get real images
-            # more image augmentation functions in http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html
+            batch_files = data_files[idx*FLAGS.batch_size:(idx + 1) * FLAGS.batch_size]
+
+            # Get real images (more image augmentation functions at [http://tensorlayer.readthedocs.io/en/latest/modules/prepro.html])
             batch = [get_image(batch_file, FLAGS.image_size, is_crop=FLAGS.is_crop, resize_w=FLAGS.output_size, is_grayscale = 0) for batch_file in batch_files]
             batch_images = np.array(batch).astype(np.float32)
-            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32) # batch_z = np.random.uniform(low=-1, high=1, size=(FLAGS.batch_size, z_dim)).astype(np.float32)
+            batch_z = np.random.normal(loc=0.0, scale=1.0, size=(FLAGS.sample_size, z_dim)).astype(np.float32)
             start_time = time.time()
-            # updates the discriminator
+
+            # Updates the Discriminator (D)
             errD, _ = sess.run([d_loss, d_optim], feed_dict={z: batch_z, real_images: batch_images})
-            # updates the generator, run generator twice to make sure that d_loss does not go to zero (difference from paper)
+
+            # Updates the Generator (G)
+            # run generator twice to make sure that d_loss does not go to zero (different from paper)
             for _ in range(2):
                 errG, _ = sess.run([g_loss, g_optim], feed_dict={z: batch_z})
             print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                 % (epoch, FLAGS.epoch, idx, batch_idxs, time.time() - start_time, errD, errG))

             iter_counter += 1
             if np.mod(iter_counter, FLAGS.sample_step) == 0:
-                # generate and visualize generated images
+                # Generate images
                 img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z: sample_seed, real_images: sample_images})
+                # Visualize generated images
                 tl.visualize.save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
                 print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))

             if np.mod(iter_counter, FLAGS.save_step) == 0:
-                # save current network parameters
+                # Save current network parameters
                 print("[*] Saving checkpoints...")
                 tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
                 tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
                 print("[*] Saving checkpoints SUCCESS!")
+
+    sess.close()

 if __name__ == '__main__':
     tf.app.run()
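The save logic above writes parameters with tl.files.save_npz; the matching restore call in TensorLayer 1.x is tl.files.load_and_assign_npz. A self-contained sketch of the save/restore round trip on a toy network (layer names and the file name are illustrative, not from the commit):

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import InputLayer, DenseLayer

# Build a tiny TensorLayer network to stand in for net_g / net_d.
x = tf.placeholder(tf.float32, [None, 100], name='x')
net = DenseLayer(InputLayer(x, name='in'), n_units=10, name='dense')

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

tl.files.save_npz(net.all_params, name='net.npz', sess=sess)          # save, as main.py does
tl.files.load_and_assign_npz(sess=sess, name='net.npz', network=net)  # restore on the next run
sess.close()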

model.py

Lines changed: 23 additions & 18 deletions
@@ -1,7 +1,14 @@
-
 import tensorflow as tf
 import tensorlayer as tl
-from tensorlayer.layers import *
+from tensorlayer.layers import (
+    InputLayer,
+    DenseLayer,
+    DeConv2d,
+    ReshapeLayer,
+    BatchNormLayer,
+    Conv2d,
+    FlattenLayer
+)

 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -11,11 +18,10 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
     s2, s4, s8, s16 = int(image_size/2), int(image_size/4), int(image_size/8), int(image_size/16)
     gf_dim = 64 # Dimension of gen filters in first conv layer. [64]
     c_dim = FLAGS.c_dim # n_color 3
-    batch_size = FLAGS.batch_size # 64
     w_init = tf.random_normal_initializer(stddev=0.02)
     gamma_init = tf.random_normal_initializer(1., 0.02)
+
     with tf.variable_scope("generator", reuse=reuse):
-        tl.layers.set_name_reuse(reuse)

         net_in = InputLayer(inputs, name='g/in')
         net_h0 = DenseLayer(net_in, n_units=gf_dim*8*s16*s16, W_init=w_init,
@@ -24,53 +30,52 @@ def generator_simplified_api(inputs, is_train=True, reuse=False):
         net_h0 = BatchNormLayer(net_h0, act=tf.nn.relu, is_train=is_train,
                 gamma_init=gamma_init, name='g/h0/batch_norm')

-        net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), out_size=(s8, s8), strides=(2, 2),
-                padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h1/decon2d')
+        net_h1 = DeConv2d(net_h0, gf_dim*4, (5, 5), strides=(2, 2),
+                padding='SAME', act=None, W_init=w_init, name='g/h1/decon2d')
         net_h1 = BatchNormLayer(net_h1, act=tf.nn.relu, is_train=is_train,
                 gamma_init=gamma_init, name='g/h1/batch_norm')

-        net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), out_size=(s4, s4), strides=(2, 2),
-                padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h2/decon2d')
+        net_h2 = DeConv2d(net_h1, gf_dim*2, (5, 5), strides=(2, 2),
+                padding='SAME', act=None, W_init=w_init, name='g/h2/decon2d')
         net_h2 = BatchNormLayer(net_h2, act=tf.nn.relu, is_train=is_train,
                 gamma_init=gamma_init, name='g/h2/batch_norm')

-        net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), out_size=(s2, s2), strides=(2, 2),
-                padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h3/decon2d')
+        net_h3 = DeConv2d(net_h2, gf_dim, (5, 5), strides=(2, 2),
+                padding='SAME', act=None, W_init=w_init, name='g/h3/decon2d')
         net_h3 = BatchNormLayer(net_h3, act=tf.nn.relu, is_train=is_train,
                 gamma_init=gamma_init, name='g/h3/batch_norm')

-        net_h4 = DeConv2d(net_h3, c_dim, (5, 5), out_size=(image_size, image_size), strides=(2, 2),
-                padding='SAME', batch_size=batch_size, act=None, W_init=w_init, name='g/h4/decon2d')
+        net_h4 = DeConv2d(net_h3, c_dim, (5, 5), strides=(2, 2),
+                padding='SAME', act=None, W_init=w_init, name='g/h4/decon2d')
         logits = net_h4.outputs
         net_h4.outputs = tf.nn.tanh(net_h4.outputs)
         return net_h4, logits

 def discriminator_simplified_api(inputs, is_train=True, reuse=False):
     df_dim = 64 # Dimension of discrim filters in first conv layer. [64]
     c_dim = FLAGS.c_dim # n_color 3
-    batch_size = FLAGS.batch_size # 64
     w_init = tf.random_normal_initializer(stddev=0.02)
     gamma_init = tf.random_normal_initializer(1., 0.02)
+
     with tf.variable_scope("discriminator", reuse=reuse):
-        tl.layers.set_name_reuse(reuse)

         net_in = InputLayer(inputs, name='d/in')
-        net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2),
+        net_h0 = Conv2d(net_in, df_dim, (5, 5), (2, 2), act=tf.nn.leaky_relu,
                 padding='SAME', W_init=w_init, name='d/h0/conv2d')

         net_h1 = Conv2d(net_h0, df_dim*2, (5, 5), (2, 2), act=None,
                 padding='SAME', W_init=w_init, name='d/h1/conv2d')
-        net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2),
+        net_h1 = BatchNormLayer(net_h1, act=tf.nn.leaky_relu,
                 is_train=is_train, gamma_init=gamma_init, name='d/h1/batch_norm')

         net_h2 = Conv2d(net_h1, df_dim*4, (5, 5), (2, 2), act=None,
                 padding='SAME', W_init=w_init, name='d/h2/conv2d')
-        net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2),
+        net_h2 = BatchNormLayer(net_h2, act=tf.nn.leaky_relu,
                 is_train=is_train, gamma_init=gamma_init, name='d/h2/batch_norm')

         net_h3 = Conv2d(net_h2, df_dim*8, (5, 5), (2, 2), act=None,
                 padding='SAME', W_init=w_init, name='d/h3/conv2d')
-        net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2),
+        net_h3 = BatchNormLayer(net_h3, act=tf.nn.leaky_relu,
                 is_train=is_train, gamma_init=gamma_init, name='d/h3/batch_norm')

         net_h4 = FlattenLayer(net_h3, name='d/h4/flatten')
tensorlayer/__init__.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
Five deleted binary files are not shown (-1 KB, -971 bytes, -545 bytes, -648 bytes, -3.58 KB).
