Skip to content

Commit 33855bd

Browse files
committed
..
1 parent ff66ad6 commit 33855bd

File tree

17 files changed

+1281
-1
lines changed

17 files changed

+1281
-1
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# Ignore files generated by IDEs and editors
22
gallery
33
libmoon.egg-info/
4-
4+
libmoon/moogan/data
5+
libmoon/moogan/images
56
archive/
67
libmoon/Output
78
.git/

libmoon/moogan/VAE_z2.pth

141 KB
Binary file not shown.

libmoon/moogan/download.sh

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/bin/bash
# Download QuickDraw numpy-bitmap class files from the public Google Cloud
# Storage bucket. Requires gsutil (Google Cloud SDK); -m parallelizes the copy.
#
# BUG FIX: the original command ended with a trailing backslash and no
# destination argument, so `gsutil cp` would fail ("too few arguments").
# The files are written to ./data, which is git-ignored
# (see .gitignore: libmoon/moogan/data).
mkdir -p data
gsutil -m cp \
  "gs://quickdraw_dataset/full/numpy_bitmap/The Eiffel Tower.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/The Great Wall of China.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/The Mona Lisa.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/aircraft carrier.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/airplane.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/alarm clock.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/ambulance.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/angel.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/animal migration.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/ant.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/anvil.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/apple.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/arm.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/asparagus.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/axe.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/backpack.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/banana.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bandage.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/barn.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/baseball bat.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/baseball.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/basket.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/basketball.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bat.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bathtub.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/beach.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bear.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/beard.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bed.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bee.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/belt.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bench.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bicycle.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/binoculars.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bird.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/birthday cake.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/blackberry.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/blueberry.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/book.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/boomerang.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bottlecap.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bowtie.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bracelet.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/brain.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bread.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bridge.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/broccoli.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/broom.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bucket.npy" \
  "gs://quickdraw_dataset/full/numpy_bitmap/bulldozer.npy" \
  data/

libmoon/moogan/gan.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# DCGAN-style adversarial training loop (script fragment).
# NOTE(review): relies on names defined earlier in the file -- netG, netD,
# criterion, optimizerD, optimizerG, dataloader, device, EPOCH_NUM, Z_DIM,
# REAL_LABEL, FAKE_LABEL, viz_noise, torch, vutils -- confirm they are in scope.
img_list = []   # periodic snapshots of G's output on the fixed viz_noise
G_losses = []   # generator loss per iteration, kept for plotting later
D_losses = []   # discriminator loss per iteration, kept for plotting later
iters = 0       # total iteration count across all epochs

print("Starting Training Loop...")
for epoch in range(EPOCH_NUM):
    for i, data in enumerate(dataloader, 0):

        # (1) Update the discriminator with real data
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        # Mean D score on real data (drifts toward ~0.5 at equilibrium)
        D_x = output.mean().item()

        # (2) Update the discriminator with fake data
        # Generate batch of latent vectors
        noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(FAKE_LABEL)
        # Classify all fake batch with D; detach so this backward pass
        # does not propagate gradients into the generator
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D (single step using gradients accumulated from both batches)
        optimizerD.step()

        # (3) Update the generator with fake data
        netG.zero_grad()
        label.fill_(REAL_LABEL)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, EPOCH_NUM, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        # (every 500 iterations, and once more at the very last batch)
        if (iters % 500 == 0) or ((epoch == EPOCH_NUM-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(viz_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

libmoon/moogan/mnist_gan.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import math
2+
import pickle as pkl
3+
import numpy as np
4+
import matplotlib.pyplot as plt
5+
import torch
6+
import torch.nn as nn
7+
import torch.optim as optim
8+
from torch.utils.data import DataLoader
9+
from torchvision import datasets, transforms
10+
11+
import os
12+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
13+
14+
15+
def display_images(images, n_cols=4, figsize=(12, 6)):
    """
    Show a batch of image tensors in a grid of subplots.

    Parameters
    ----------
    images: Tensor
        tensor of shape (batch_size, channel, height, width)
        containing images to be displayed
    n_cols: int
        number of columns in the grid
    figsize: tuple
        matplotlib figure size in inches

    Returns
    -------
    None
    """
    plt.style.use('ggplot')
    n_rows = math.ceil(len(images) / n_cols)
    plt.figure(figsize=figsize)
    for pos, img in enumerate(images, start=1):
        axis = plt.subplot(n_rows, n_cols, pos)
        # Convert C x H x W to H x W x C, the layout imshow expects.
        hwc = img.permute(1, 2, 0)
        # Single-channel images render in grayscale; otherwise use viridis.
        axis.imshow(hwc, cmap='gray' if hwc.shape[2] == 1 else plt.cm.viridis)
        axis.set_xticks([])
        axis.set_yticks([])
    plt.tight_layout()
    plt.show()
46+
47+
48+
49+
50+
if __name__ == '__main__':
    # Fetch the MNIST training split (downloading it on first run) and
    # print a handful of diagnostics about its raw tensors.
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.MNIST(root='./data',
                              train=True,
                              download=True,
                              transform=transform)

    # Dataset-level shapes and class names, then pixel stats of sample 0.
    print(train_ds.data.shape)
    print(train_ds.targets.shape)
    print(train_ds.classes)
    print(train_ds.data[0])
    print(train_ds.targets[0])
    print(train_ds.data[0].max())
    print(train_ds.data[0].min())
    print(train_ds.data[0].float().mean())
    print(train_ds.data[0].float().std())

    # Build dataloader
    dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=64)

    # Pull one batch to confirm the (images, labels) structure and shapes.
    batch = next(iter(dl))
    print(len(batch), type(batch))
    print(batch[0].shape)
    print(batch[1].shape)

    # display_images(images=batch[0], n_cols=8)

libmoon/moogan/moogan.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import matplotlib.pyplot as plt
2+
import torch
3+
import torch.nn as nn
4+
import torch.optim as optim
5+
import argparse
6+
import os
7+
import numpy as np
8+
9+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
10+
11+
def plot_figure(folder_name, generated_samples, sample1, sample2, pref):
    """Scatter the generated points against both real sample sets and save
    the figure as ``res_<pref0>.pdf`` inside *folder_name*.

    :param folder_name: output directory for the PDF.
    :param generated_samples: (n, 2) array of generator outputs.
    :param sample1: (n, 2) samples from the first target distribution.
    :param sample2: (n, 2) samples from the second target distribution.
    :param pref: preference vector; pref[0] also names the output file.
    """
    plt.scatter(generated_samples[:, 0], generated_samples[:, 1], label='Generated', s=50)
    plt.scatter(sample1[:, 0], sample1[:, 1], label='Sample 1', s=25, alpha=0.5)
    plt.scatter(sample2[:, 0], sample2[:, 1], label='Sample 2', s=25, alpha=0.5)
    # Only the pref[0] == 0 figure carries the legend.
    if abs(pref[0]) < 1e-6:
        plt.legend(fontsize=20, loc='lower right')

    plt.xlabel('$X_1$', fontsize=25)
    plt.ylabel('$X_2$', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.axis('equal')
    # Reference segment from (0, 0) to (4, 4) -- the two Gaussian means.
    plt.plot([0, 4], [0, 4], linewidth=2, color='black')
    out_path = os.path.join(folder_name, 'res_{:.2f}.pdf'.format(pref[0]))
    plt.savefig(out_path, bbox_inches='tight')
    print('Save fig to {}'.format(out_path))
27+
28+
29+
# Generator: Transforms random noise into samples resembling the target distribution
class Generator(nn.Module):
    """Two-hidden-layer MLP mapping a latent vector to an output sample."""

    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent noise ``z`` of shape (batch, input_dim) to (batch, output_dim)."""
        return self.model(z)
43+
44+
45+
# Discriminator: Classifies whether samples are real (from target Gaussian) or fake (from generator)
class Discriminator(nn.Module):
    """Two-hidden-layer MLP producing a real-vs-fake probability in (0, 1)."""

    def __init__(self, input_dim):
        super(Discriminator, self).__init__()
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squash the logit to a probability
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch ``x`` of shape (batch, input_dim); returns (batch, 1)."""
        return self.model(x)
60+
61+
62+
# Function to sample from a target Gaussian distribution
def sample_multiple_gaussian(batch_size, dim, mean=0, std=1):
    """Draw one batch from each of two isotropic Gaussians.

    The first batch comes from N(mean, std), the second from N(mean + 4, std);
    both have shape (batch_size, dim). Returns the pair as a tuple.
    """
    # The two draws happen in this fixed order so results are reproducible
    # under a given torch seed.
    return tuple(
        torch.normal(mean=mean + shift, std=std, size=(batch_size, dim))
        for shift in (0, 4)
    )
67+
68+
69+
class MOGANTrainer:
    """Preference-conditioned multi-objective GAN trainer.

    A single generator is trained against one discriminator per objective;
    the per-objective generator losses are combined by linear scalarization
    with the preference vector ``pref``.
    """

    def __init__(self, lr, num_epochs, batch_size, n_obj, pref, input_dim, output_dim):
        '''
        :param lr, float: learning rate.
        :param num_epochs, int : number of epochs.
        :param batch_size, int : batch size.
        :param n_obj, int : number of objectives.
            NOTE(review): train() builds exactly two real-sample batches, so
            only n_obj == 2 is actually supported -- confirm before extending.
        :param pref, np.array : preference vector, length n_obj.
        :param input_dim, int : latent (noise) dimension fed to the generator.
        :param output_dim, int : dimension of the generated samples.
        '''
        self.lr = lr
        self.pref = pref
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.n_obj = n_obj
        self.input_dim = input_dim
        self.output_dim = output_dim

        # BUG FIX: the original read the module-level `args` namespace here
        # (args.input_dim, args.lr, ...), which raises NameError for any
        # caller that does not define a global `args`. Use the constructor
        # arguments instead.
        self.generator = Generator(input_dim, output_dim)
        self.discriminator_arr = [Discriminator(output_dim) for _ in range(n_obj)]
        self.d_optimizer_arr = [optim.Adam(discriminator.parameters(), lr=lr)
                                for discriminator in self.discriminator_arr]
        self.g_optimizer = optim.Adam(self.generator.parameters(), lr=lr)
        self.criterion = nn.BCELoss()

    def train(self):
        """Run the adversarial training loop for ``num_epochs`` epochs."""
        d_loss_arr = []
        for epoch in range(self.num_epochs):
            for _ in range(1):  # Training discriminator more than generator improves stability
                # Real samples from the two target Gaussian distributions,
                # one batch per objective.
                real_samples_1, real_samples_2 = sample_multiple_gaussian(
                    self.batch_size, self.output_dim)
                real_samples_arr = [real_samples_1, real_samples_2]
                z = torch.randn(self.batch_size, self.input_dim)  # Random noise
                fake_samples = self.generator(z)  # Fake samples from generator

                for idx, discriminator in enumerate(self.discriminator_arr):
                    real_samples = real_samples_arr[idx]
                    d_real = discriminator(real_samples)
                    # Detach to avoid backpropagating through the generator.
                    d_fake = discriminator(fake_samples.detach())

                    real_loss = self.criterion(d_real, torch.ones_like(d_real))
                    fake_loss = self.criterion(d_fake, torch.zeros_like(d_fake))
                    d_loss = (real_loss + fake_loss) / 2

                    self.d_optimizer_arr[idx].zero_grad()
                    d_loss.backward()
                    self.d_optimizer_arr[idx].step()

            # Generator training: fresh noise, scored by every discriminator.
            z = torch.randn(self.batch_size, self.input_dim)
            fake_samples = self.generator(z)

            g_loss_arr = []
            for idx, discriminator in enumerate(self.discriminator_arr):
                d_fake = discriminator(fake_samples)
                g_loss = self.criterion(d_fake, torch.ones_like(d_fake))
                g_loss_arr.append(g_loss)

            g_loss_arr = torch.stack(g_loss_arr)
            self.g_optimizer.zero_grad()
            # Linear scalarization of the per-objective generator losses.
            scalar_loss = torch.dot(g_loss_arr, torch.Tensor(self.pref))
            scalar_loss.backward()
            self.g_optimizer.step()
            # Logging (d_loss / g_loss are the last objective's losses).
            if (epoch + 1) % 100 == 0:
                print(
                    f'Epoch [{epoch + 1}/{self.num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')
            d_loss_arr.append(d_loss.item())

    def generate_samples(self, test_size):
        """Generate ``test_size`` samples plus one real batch per objective.

        :returns: (generated ndarray of shape (test_size, output_dim),
                   real_samples_1 tensor, real_samples_2 tensor).
        """
        with torch.no_grad():
            z = torch.randn(test_size, self.input_dim)
            generated_samples = self.generator(z)
            # NOTE(review): the real batches use self.batch_size, not
            # test_size -- confirm this asymmetry is intended.
            real_samples_1, real_samples_2 = sample_multiple_gaussian(
                self.batch_size, self.output_dim)
            return generated_samples.numpy(), real_samples_1, real_samples_2
145+
146+
147+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='example script')
    # Latent (noise) dimension fed to the generator.
    parser.add_argument('--input-dim', type=int, default=10)
    parser.add_argument('--output-dim', type=int, default=2)
    parser.add_argument('--n-obj', type=int, default=2)

    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--test-size', type=int, default=64)
    parser.add_argument('--num-epochs', type=int, default=500)
    parser.add_argument('--lr', type=float, default=5e-6)
    parser.add_argument('--pref0', type=float, default=0.2)

    # Hyperparameters
    args = parser.parse_args()
    # Two-objective preference vector; components sum to 1.
    pref = np.array([args.pref0, 1 - args.pref0])
    print('Preference: ', pref)
    # Model, optimizer, and loss function

    trainer = MOGANTrainer(lr=args.lr, num_epochs=args.num_epochs,
                           n_obj=args.n_obj, pref=pref, batch_size=args.batch_size,
                           input_dim=args.input_dim,
                           output_dim=args.output_dim)

    trainer.train()
    generate_samples, sample1, sample2 = trainer.generate_samples(args.test_size)

    # BUG FIX: the output folder was a hard-coded absolute Windows path
    # ('D:\\pycharm_project\\libmoon\\Output\\divergence'), which fails on
    # any other machine. Write relative to the working directory instead.
    folder_name = os.path.join('Output', 'divergence')
    os.makedirs(folder_name, exist_ok=True)
    plot_figure(folder_name, generate_samples, sample1, sample2, pref)

libmoon/moogan/moogan.rar

12.1 KB
Binary file not shown.

0 commit comments

Comments
 (0)