@@ -969,15 +969,15 @@ the data loader every epoch and then write the GAN training code:
969969 discriminator->zero_grad();
970970 torch::Tensor real_images = batch.data;
971971 torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
972- torch::Tensor real_output = discriminator->forward(real_images);
972+ torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes()) ;
973973 torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
974974 d_loss_real.backward();
975975
976976 // Train discriminator with fake images.
977977 torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
978978 torch::Tensor fake_images = generator->forward(noise);
979979 torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
980- torch::Tensor fake_output = discriminator->forward(fake_images.detach());
980+ torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes()) ;
981981 torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
982982 d_loss_fake.backward();
983983
@@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code:
987987 // Train generator.
988988 generator->zero_grad();
989989 fake_labels.fill_(1);
990- fake_output = discriminator->forward(fake_images);
990+ fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes()) ;
991991 torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
992992 g_loss.backward();
993993 generator_optimizer.step();
0 commit comments