
Commit b43d825

1 parent b858ced commit b43d825

File tree

4 files changed: +89, -64 lines changed


libmoon/moogan/VAE_z2.pth

58.3 KB
Binary file not shown.

libmoon/moogan/mnist_gan.py

Lines changed: 11 additions & 14 deletions
@@ -6,22 +6,19 @@
 import os
 os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

-
 def display_images(images, n_cols=4, figsize=(12, 6)):
     """
-    Utility function to display a collection of images in a grid
-
-    Parameters
-    ----------
-    images: Tensor
-        tensor of shape (batch_size, channel, height, width)
-        containing images to be displayed
-    n_cols: int
-        number of columns in the grid
-
-    Returns
-    -------
-    None
+    Utility function to display a collection of images in a grid
+    Parameters
+    ----------
+    images: Tensor
+        tensor of shape (batch_size, channel, height, width)
+        containing images to be displayed
+    n_cols: int
+        number of columns in the grid
+    Returns
+    -------
+    None
     """
     plt.style.use('ggplot')
     n_images = len(images)
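
A minimal usage sketch of display_images as documented above; the import path and the random batch are illustrative assumptions, not part of this commit:

import torch
from libmoon.moogan.mnist_gan import display_images  # assumed import path

images = torch.rand(8, 1, 28, 28)   # dummy batch of single-channel 28x28 images in [0, 1]
display_images(images, n_cols=4)    # lays the batch out in a 4-column grid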

libmoon/moogan/modm_func.py

Lines changed: 1 addition & 0 deletions
@@ -17,5 +17,6 @@ def mokl(mu1, mu2, Std1, Std2, pref0):
     mu_output = torch.stack(mu_output)
     return mu_output, Sigma_output

+
 if __name__ == '__main__':
     print()

libmoon/moogan/moovae.py

Lines changed: 77 additions & 50 deletions
@@ -14,6 +14,10 @@
 from libmoon.util.general import FolderDataset
 from torch.utils.data import DataLoader

+os.environ["PYTORCH_USE_CUDA_DSA"] = "1"
+os.environ['CUDA_LAUNCH_BLOCKING']="1"
+os.environ['TORCH_USE_CUDA_DSA'] = "1"
+
 device = 'cuda' if torch.cuda.is_available() else 'cpu'

 def to_img(x):
@@ -25,11 +29,12 @@ def numel(model):


 class VAE(nn.Module):
-    def __init__(self):
+    def __init__(self, n_channels):
+        self.n_channels = n_channels
         super(VAE, self).__init__()
         # define the encoder
         self.encoder = nn.Sequential(
-            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(n_channels, 16, kernel_size=3, stride=2, padding=1),
             nn.BatchNorm2d(16),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
@@ -46,7 +51,7 @@ def __init__(self):
         self.decoder = nn.Sequential(
             nn.ConvTranspose2d(32, 16, 4, 2, 1),
             nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(16, 1, 4, 2, 1),
+            nn.ConvTranspose2d(16, self.n_channels, 4, 2, 1),
             nn.Sigmoid(),
         )

@@ -56,18 +61,32 @@ def noise_reparameterize(self, mean, logvar):
         return z

     def forward(self, x):
+        # print('x.shape', x.shape)
+        # assert False
+        # x.shape: (64,3,28,28)
         out1, out2 = self.encoder(x), self.encoder(x)
         mean = self.encoder_fc1(out1.view(out1.shape[0], -1))
         logstd = self.encoder_fc2(out2.view(out2.shape[0], -1))
         z = self.noise_reparameterize(mean, logstd)
-        out3 = self.decoder_fc(z)
-        out3 = out3.view(out3.shape[0], 32, 7, 7)
-        out3 = self.decoder(out3)
-        return out3, mean, logstd
-
-def loss_function(recon_x, x, mean, std):
-    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
-    # std is the natural log of the standard deviation; exponentiate and square it to get the variance
+        decoded_img = self.decoder_fc(z)
+        decoded_img = decoded_img.view(decoded_img.shape[0], 32, 7, 7)
+        decoded_img = self.decoder(decoded_img)
+        return decoded_img.to(device), mean.to(device), logstd.to(device)
+
+
+def vae_loss_function(recon_x, x, mean, std):
+    _, n_channels, _, _ = x.size()
+    BCE_arr = []
+    for channel_idx in range(n_channels):
+        # print('max recon_x', torch.max(recon_x[:, channel_idx, :, :]))
+        # print('min recon_x', torch.min(recon_x[:, channel_idx, :, :]))
+        # print('max x', torch.max(x[:, channel_idx, :, :]))
+        # print('min x', torch.min(x[:, channel_idx, :, :]))
+        BCE_i = F.binary_cross_entropy(recon_x[:, channel_idx, :, :], x[:, channel_idx, :, :],
+                                       reduction='sum')
+        # print('BCE_i', BCE_i)
+        BCE_arr.append(BCE_i)
+    BCE = torch.sum(torch.stack(BCE_arr))
     var = torch.pow(torch.exp(std), 2)
     KLD = -0.5 * torch.sum(1 + torch.log(var) - torch.pow(mean, 2) - var)
     return BCE + KLD
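
Since F.binary_cross_entropy with reduction='sum' already sums over every element, the per-channel loop in vae_loss_function should give the same total as a single call over the full tensor; a small sketch with dummy shapes (an illustrative assumption, not part of the commit):

import torch
import torch.nn.functional as F

recon_x = torch.rand(4, 3, 28, 28)   # dummy reconstructions in (0, 1)
x = torch.rand(4, 3, 28, 28)         # dummy targets in (0, 1)
per_channel = sum(F.binary_cross_entropy(recon_x[:, c], x[:, c], reduction='sum')
                  for c in range(x.size(1)))
full_tensor = F.binary_cross_entropy(recon_x, x, reduction='sum')
print(torch.isclose(per_channel, full_tensor))  # expected: tensor(True)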
@@ -81,45 +100,51 @@ def loss_function(recon_x, x, mean, std):
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--data-name1', type=str, default='alarm')
     parser.add_argument('--data-name2', type=str, default='circle')
-    parser.add_argument('--data-type', type=str, default='domainnet')
+    parser.add_argument('--data-type', type=str, default='domainnet')  # Category: [domainnet, 'quickdraw']
     parser.add_argument('--domain-set-data', type=str, default='airplane')
-    parser.add_argument('--domain1', type=str, default='clipart')
-    parser.add_argument('--domain2', type=str, default='infograph')
-
-    parser.add_argument('--n-epochs', type=int, default=100)
-    parser.add_argument('--z-dimension', type=int, default=2)
+    parser.add_argument('--domain1', type=str, default='real')
+    parser.add_argument('--domain2', type=str, default='quickdraw')
+    parser.add_argument('--n-epochs', type=int, default=5000)
+    parser.add_argument('--z-dimension', type=int, default=5)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--pref0', type=float, default=0.0)
     args = parser.parse_args()
     # batch_size = 64, # num_epoch = 15, # z_dimension = 2
     img_transform = transforms.Compose([
         transforms.ToTensor(),
     ])
+
     if args.data_type == 'mnist':
-        # download the MNIST dataset
         mnist = datasets.MNIST(root='./data/', train=True, transform=img_transform, download=True)
-        # data loader
         dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=args.batch_size, shuffle=True)
+        n_channels = 1
     elif args.data_type == 'domainnet':
-        # F:\code\libmoon\libmoon\moogan\data\domainnet
         path1 = os.path.join(
             root_name, 'libmoon', 'moogan', 'data', 'domainnet', args.domain1, args.domain_set_data
         )
+
         path2 = os.path.join(
             root_name, 'libmoon', 'moogan', 'data', 'domainnet', args.domain2, args.domain_set_data
         )

         dataset1 = FolderDataset(path1)
         dataset2 = FolderDataset(path2)
-
+        print(dataset1[0].shape)
+        print(dataset2[0].shape)
+        print('len dataset1', len(dataset1))
+        print('len dataset2', len(dataset2))

         dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, shuffle=True)
         dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, shuffle=True)
-    else:
+        n_channels = 3
+
+    elif args.data_type == 'quickdraw':
         path1 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw',
                              'full_numpy_bitmap_{}.npy'.format(args.data_name1))
         img1_data = np.load(path1)
-        img1_data = img1_data.reshape(-1, 1, 28, 28)
+        if args.data_type == 'quickdraw':
+            img1_data = img1_data.reshape(-1, 1, 28, 28)
+
         img1_data = img1_data / 255
         path2 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw',
                              'full_numpy_bitmap_{}.npy'.format(args.data_name2))
@@ -131,63 +156,67 @@ def loss_function(recon_x, x, mean, std):
         print('img1_data size: ', len(img1_data))
         dataloader1 = DataLoader(img1_data, batch_size=args.batch_size, shuffle=True)
         dataloader2 = DataLoader(img2_data, batch_size=args.batch_size, shuffle=True)
+        n_channels = 1
+    else:
+        assert False, 'dataset not implemented'

-    vae = VAE().to(device)
+    vae = VAE(n_channels=n_channels).to(device)
     num1 = numel(vae.encoder)
     num2 = numel(vae.decoder)
-    print()
     vae_optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr,
                                      betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-
-    for epoch in range(args.n_epochs):  # train for multiple epochs
+    for epoch in range(args.n_epochs):
         for i, (img1, img2) in tqdm(enumerate(zip(dataloader1, dataloader2))):
+            img1 = img1.to(device)
+            img2 = img2.to(device)
             num_img = img1.size(0)
-            img1 = img1.view(num_img, 1, 28, 28).to(device)  # reshape the images to 1x28x28
-            x, mean1, logstd1 = vae(img1)  # feed the real images into the model
-            loss = loss_function(x, img1, mean1, logstd1)
+            if args.data_type == 'quickdraw':
+                img1 = img1.view(num_img, 1, 28, 28).to(device)  # reshape the images to 1x28x28
+
+            decoded_img1, mean1, logstd1 = vae(img1)
+
+            loss1 = vae_loss_function(decoded_img1, img1, mean1, logstd1)
             vae_optimizer.zero_grad()  # zero the gradients before backpropagation
-            loss.backward()  # backpropagate the loss
+            loss1.backward()
             vae_optimizer.step()  # update the parameters

             num_img2 = img2.size(0)
-            img2 = img2.view(num_img2, 1, 28, 28).to(device)  # reshape the images to 1x28x28
-            x, mean2, logstd2 = vae(img2)  # feed the real images into the model
-            loss = loss_function(x, img2, mean2, logstd2)
-            vae_optimizer.zero_grad()  # zero the gradients before backpropagation
-            loss.backward()  # backpropagate the loss
+            if args.data_type == 'quickdraw':
+                img2 = img2.view(num_img2, 1, 28, 28).to(device)  # reshape the images to 1x28x28
+
+            decoded_img2, mean2, logstd2 = vae(img2)
+            loss2 = vae_loss_function(decoded_img2, img2, mean2, logstd2)
+            vae_optimizer.zero_grad()
+            loss2.backward()
             vae_optimizer.step()

             if (i + 1) % 100 == 0:
                 print('Epoch[{}/{}],vae_loss:{:.6f} '.format(
                     epoch, args.n_epochs, loss.item(),
                 ))

-            folder_name = os.path.join(root_name, 'libmoon', 'moogan', 'img_VAE',
-                                       '{}_{}'.format(args.data_name1, args.data_name2))
-            os.makedirs(folder_name, exist_ok=True)
+            if args.data_type == 'quickdraw':
+                folder_name = os.path.join(root_name, 'libmoon', 'moogan', 'img_VAE', args.data_type,
+                                           '{}_{}'.format(args.data_name1, args.data_name2))
+            else:
+                folder_name = os.path.join(root_name, 'libmoon', 'moogan', 'img_VAE', args.data_type, args.domain_set_data,
+                                           '{}_{}'.format(args.domain1, args.domain2))

+            os.makedirs(folder_name, exist_ok=True)
             if epoch == 0:
-                real_images1 = make_grid(img[:25].cpu(), nrow=5, normalize=True).detach()
+                real_images1 = make_grid(img1[:25].cpu(), nrow=5, normalize=True).detach()
                 save_image(real_images1, os.path.join(folder_name, 'real_images1.pdf'))
                 real_images2 = make_grid(img2[:25].cpu(), nrow=5, normalize=True).detach()
                 save_image(real_images2, os.path.join(folder_name, 'real_images2.pdf'))
             # sample_size = 25
             pref0_arr = np.linspace(0, 1, 5)

-
             if i == 0:
                 for pref0 in pref0_arr:
-                    # mean1 (mean2).shape: (64,2)
-                    # meanA = torch.mean(mean1)
-                    # meanB = torch.mean(mean2)
                     Std1 = torch.exp(logstd1)
                     Std2 = torch.exp(logstd2)
                     mu, std = mokl(mean1, mean2, Std1, Std2, pref0)
-                    # mu.shape: 64*2
-                    # std.shape: 64*2*2
                     std = torch.diagonal(std, dim1=1, dim2=2)
-                    # print(mu.shape)
-                    # print(std.shape)
                     sample_size = len(mu)
                     sample = torch.randn(sample_size, args.z_dimension).to(device) * std + mu
                     output = vae.decoder_fc(sample)
@@ -196,7 +225,5 @@ def loss_function(recon_x, x, mean, std):
                     fig_name = os.path.join(folder_name, 'fake_images_{}_{:.2f}.pdf'.format(epoch + 16, pref0))
                     save_image(fake_images, fig_name)
                     print('img saved in', fig_name)
-
-
     # save the model
     torch.save(vae.state_dict(), './VAE_z2.pth')
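
A minimal smoke-test sketch of the reworked interface: the decoder_fc output is reshaped to (batch, 32, 7, 7) and the two ConvTranspose2d layers upsample back to 28x28, so a 3-channel 28x28 batch should round-trip through forward() and vae_loss_function. The import path, the __main__ guard around the training script, and the dummy batch are assumptions for illustration:

import torch
from libmoon.moogan.moovae import VAE, vae_loss_function  # assumed import path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = VAE(n_channels=3).to(device)             # 3-channel variant enabled by this commit
x = torch.rand(8, 3, 28, 28, device=device)    # dummy batch of RGB 28x28 images in [0, 1]
decoded, mean, logstd = vae(x)
loss = vae_loss_function(decoded, x, mean, logstd)
loss.backward()
print(decoded.shape, loss.item())              # expected: torch.Size([8, 3, 28, 28]) and a scalar loss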
