Skip to content

Commit 4790126

Browse files
authored
Update train.py
1 parent 83bce75 commit 4790126

File tree

1 file changed

+31
-28
lines changed

1 file changed

+31
-28
lines changed

train.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,44 @@ def train():
1616
G.train()
1717
D.train()
1818

19-
d_optimizer = tf.optimizers.Adam(flags.learning_rate, beta_1=flags.beta1)
20-
g_optimizer = tf.optimizers.Adam(flags.learning_rate, beta_1=flags.beta1)
19+
d_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
20+
g_optimizer = tf.optimizers.Adam(flags.lr, beta_1=flags.beta1)
2121

2222
n_step_epoch = int(len(images_path) // flags.batch_size)
23-
24-
for step, batch_images in enumerate(images):
25-
step_time = time.time()
26-
with tf.GradientTape(persistent=True) as tape:
27-
# z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
28-
z = np.random.normal(loc=0.0, scale=1.0, size=[flags.batch_size, flags.z_dim]).astype(np.float32)
29-
d_logits = D(G(z))
30-
d2_logits = D(batch_images)
31-
# discriminator: real images are labelled as 1
32-
d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
33-
# discriminator: images from generator (fake) are labelled as 0
34-
d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
35-
# combined loss for updating discriminator
36-
d_loss = d_loss_real + d_loss_fake
37-
# generator: try to fool discriminator to output 1
38-
g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')
39-
40-
grad = tape.gradient(g_loss, G.trainable_weights)
41-
g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
42-
grad = tape.gradient(d_loss, D.trainable_weights)
43-
d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
44-
del tape
45-
46-
print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(step//n_step_epoch, flags.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))
47-
if np.mod(step, flags.save_step) == 0:
23+
24+
for epoch in range(flags.n_epoch):
25+
for step, batch_images in enumerate(images):
26+
step_time = time.time()
27+
with tf.GradientTape(persistent=True) as tape:
28+
# z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')
29+
z = np.random.normal(loc=0.0, scale=1.0, size=[flags.batch_size, flags.z_dim]).astype(np.float32)
30+
d_logits = D(G(z))
31+
d2_logits = D(batch_images)
32+
# discriminator: real images are labelled as 1
33+
d_loss_real = tl.cost.sigmoid_cross_entropy(d2_logits, tf.ones_like(d2_logits), name='dreal')
34+
# discriminator: images from generator (fake) are labelled as 0
35+
d_loss_fake = tl.cost.sigmoid_cross_entropy(d_logits, tf.zeros_like(d_logits), name='dfake')
36+
# combined loss for updating discriminator
37+
d_loss = d_loss_real + d_loss_fake
38+
# generator: try to fool discriminator to output 1
39+
g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')
40+
41+
grad = tape.gradient(g_loss, G.trainable_weights)
42+
g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
43+
grad = tape.gradient(d_loss, D.trainable_weights)
44+
d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
45+
del tape
46+
47+
print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(epoch, \
48+
flags.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))
49+
50+
if np.mod(epoch, flags.save_every_epoch) == 0:
4851
G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
4952
D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
5053
G.eval()
5154
result = G(z)
5255
G.train()
53-
tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles], '{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, step//n_step_epoch, step))
56+
tl.visualize.save_images(result.numpy(), [num_tiles, num_tiles], '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))
5457

5558
if __name__ == '__main__':
5659
train()

0 commit comments

Comments
 (0)