diff --git a/advanced_source/cpp_frontend.rst b/advanced_source/cpp_frontend.rst index 901658183c7..de22fbf05a1 100644 --- a/advanced_source/cpp_frontend.rst +++ b/advanced_source/cpp_frontend.rst @@ -969,7 +969,7 @@ the data loader every epoch and then write the GAN training code: discriminator->zero_grad(); torch::Tensor real_images = batch.data; torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0); - torch::Tensor real_output = discriminator->forward(real_images); + torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes()); torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels); d_loss_real.backward(); @@ -977,7 +977,7 @@ the data loader every epoch and then write the GAN training code: torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1}); torch::Tensor fake_images = generator->forward(noise); torch::Tensor fake_labels = torch::zeros(batch.data.size(0)); - torch::Tensor fake_output = discriminator->forward(fake_images.detach()); + torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes()); torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels); d_loss_fake.backward(); @@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code: // Train generator. generator->zero_grad(); fake_labels.fill_(1); - fake_output = discriminator->forward(fake_images); + fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes()); torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels); g_loss.backward(); generator_optimizer.step();