Skip to content

Commit 5e3d506

Browse files
authored
Merge pull request #20 from tensorlayer/tl2
update to TL2
2 parents 5834c50 + da778ad commit 5e3d506

File tree

126 files changed

+89
-31173
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

126 files changed

+89
-31173
lines changed

README.md

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,30 @@
11
# DCGAN in TensorLayer
22

3-
TensorLayer implementation of [Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434).
43

4+
This is the TensorLayer implementation of [Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434).
55
Looking for Text to Image Synthesis ? [click here](https://github.com/zsdonghao/text-to-image)
66

77
![alt tag](img/DCGAN.png)
88

9+
10+
- 🆕 🔥 2019 May: We just update this project to support TF2 and TL2. Enjoy!
11+
- 🆕 🔥 2019 May: This project is chosen as the default template of TL projects.
12+
13+
914
## Prerequisites
1015

11-
- Python3
12-
- TensorFlow==1.13
13-
- TensorLayer (self-contained)
16+
- Python3.5 3.6
17+
- TensorFlow==2.0.0a0 `pip3 install tensorflow-gpu==2.0.0a0`
18+
- TensorLayer=2.0.0 `pip3 install tensorlayer==2.0.0`
1419

1520
## Usage
1621

1722
First, download the aligned face images from [google](https://drive.google.com/open?id=0B7EVK8r0v71pWEZsZE9oNnFzTm8) or [baidu](https://pan.baidu.com/s/1eSNpdRG#list/path=%2F) to a `data` folder.
1823

1924
Second, train the GAN:
2025

21-
$ python main_eager_mode.py # single GPU for TF>=1.13
22-
$ python main_graph_mode.py # single GPU for TF<=1.13
23-
$ python main_eager_mode_horovod.py # multiple GPU (TODO)
24-
$ python main_eager_mode_tlmagic.py # multiple GPU (TODO)
25-
26+
$ python train.py
27+
2628
## Result on celebA
2729

2830

data.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import numpy as np
3+
import tensorflow as tf
4+
import tensorlayer as tl
5+
## enable debug logging
6+
tl.logging.set_verbosity(tl.logging.DEBUG)
7+
8+
class FLAGS(object):
9+
def __init__(self):
10+
self.n_epoch = 25 # "Epoch to train [25]"
11+
self.z_dim = 100 # "Num of noise value]"
12+
self.learning_rate = 0.0002 # "Learning rate of for adam [0.0002]")
13+
self.beta1 = 0.5 # "Momentum term of adam [0.5]")
14+
self.batch_size = 64 # "The number of batch images [64]")
15+
self.output_size = 64 # "The size of the output images to produce [64]")
16+
self.sample_size = 64 # "The number of sample images [64]")
17+
self.c_dim = 3 # "Number of image channels. [3]")
18+
self.save_step = 500 # "The interval of saveing checkpoints. [500]")
19+
# self.dataset = "celebA" # "The name of dataset [celebA, mnist, lsun]")
20+
self.checkpoint_dir = "checkpoint" # "Directory name to save the checkpoints [checkpoint]")
21+
self.sample_dir = "samples" # "Directory name to save the image samples [samples]")
22+
assert np.sqrt(self.sample_size) % 1 == 0., 'Flag `sample_size` needs to be a perfect square'
23+
flags = FLAGS()
24+
25+
tl.files.exists_or_mkdir(flags.checkpoint_dir) # save model
26+
tl.files.exists_or_mkdir(flags.sample_dir) # save generated image
27+
28+
def get_celebA(output_size, n_epoch, batch_size):
29+
# dataset API and augmentation
30+
images_path = tl.files.load_file_list(path='data', regx='.*.jpg', keep_prefix=True, printable=False)
31+
def generator_train():
32+
for image_path in images_path:
33+
yield image_path.encode('utf-8')
34+
def _map_fn(image_path):
35+
image = tf.io.read_file(image_path)
36+
image = tf.image.decode_jpeg(image, channels=3) # get RGB with 0~1
37+
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
38+
# image = tf.image.crop_central(image, [FLAGS.output_size, FLAGS.output_size, FLAGS.c_dim])
39+
# image = tf.image.resize_images(image, FLAGS.output_size])
40+
image = image[45:173, 25:153, :] # central crop
41+
image = tf.image.resize([image], (output_size, output_size))[0]
42+
# image = tf.image.crop_and_resize(image, boxes=[[]], crop_size=[64, 64])
43+
# image = tf.image.resize_image_with_crop_or_pad(image, FLAGS.output_size, FLAGS.output_size) # central crop
44+
image = tf.image.random_flip_left_right(image)
45+
image = image * 2 - 1
46+
return image
47+
train_ds = tf.data.Dataset.from_generator(generator_train, output_types=tf.string)
48+
ds = train_ds.shuffle(buffer_size=4096)
49+
# ds = ds.shard(num_shards=hvd.size(), index=hvd.rank())
50+
ds = ds.repeat(n_epoch)
51+
ds = ds.map(_map_fn, num_parallel_calls=4)
52+
ds = ds.batch(batch_size)
53+
ds = ds.prefetch(buffer_size=2)
54+
return ds, images_path
55+
# for batch_images in train_ds:
56+
# print(batch_images.shape)
57+
# value = ds.make_one_shot_iterator().get_next()

main_eager_mode2.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

main_eager_mode_horovod.py

Lines changed: 0 additions & 89 deletions
This file was deleted.

main_graph_mode.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import tensorlayer as tl
33
from tensorlayer.layers import Input, Dense, DeConv2d, Reshape, BatchNorm2d, Conv2d, Flatten, BatchNorm
44

5-
flags = tf.app.flags
6-
FLAGS = flags.FLAGS
7-
85
def get_generator(shape, gf_dim=64): # Dimension of gen filters in first conv layer. [64]
96
image_size = 64
107
s16 = image_size // 16
11-
w_init = tf.glorot_normal_initializer()
8+
# w_init = tf.glorot_normal_initializer()
9+
w_init = tf.random_normal_initializer(stddev=0.02)
1210
gamma_init = tf.random_normal_initializer(1., 0.02)
1311

1412
ni = Input(shape)
@@ -26,7 +24,8 @@ def get_generator(shape, gf_dim=64): # Dimension of gen filters in first conv la
2624
return tl.models.Model(inputs=ni, outputs=nn, name='generator')
2725

2826
def get_discriminator(shape, df_dim=64): # Dimension of discrim filters in first conv layer. [64]
29-
w_init = tf.glorot_normal_initializer()
27+
# w_init = tf.glorot_normal_initializer()
28+
w_init = tf.random_normal_initializer(stddev=0.02)
3029
gamma_init = tf.random_normal_initializer(1., 0.02)
3130
lrelu = lambda x : tf.nn.leaky_relu(x, 0.2)
3231

0 commit comments

Comments
 (0)