Skip to content

Commit ac1925b

Browse files
authored
Update train.py
1 parent 90812b6 commit ac1925b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def train():
3737
# generator: try to fool discriminator to output 1
3838
g_loss = tl.cost.sigmoid_cross_entropy(d_logits, tf.ones_like(d_logits), name='gfake')
3939

40-
grad = tape.gradient(g_loss, G.weights)
41-
g_optimizer.apply_gradients(zip(grad, G.weights))
42-
grad = tape.gradient(d_loss, D.weights)
43-
d_optimizer.apply_gradients(zip(grad, D.weights))
40+
grad = tape.gradient(g_loss, G.trainable_weights)
41+
g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
42+
grad = tape.gradient(d_loss, D.trainable_weights)
43+
d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
4444
del tape
4545

4646
print("Epoch: [{}/{}] [{}/{}] took: {:3f}, d_loss: {:5f}, g_loss: {:5f}".format(step//n_step_epoch, flags.n_epoch, step, n_step_epoch, time.time()-step_time, d_loss, g_loss))

0 commit comments

Comments
 (0)