Skip to content

Commit b858ced

Browse files
committed
X
1 parent 2f59cc1 commit b858ced

File tree

4 files changed

+58
-28
lines changed

4 files changed

+58
-28
lines changed

libmoon/moogan/mnist_gan.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
import math
2-
import pickle as pkl
3-
import numpy as np
42
import matplotlib.pyplot as plt
5-
import torch
6-
import torch.nn as nn
7-
import torch.optim as optim
83
from torch.utils.data import DataLoader
94
from torchvision import datasets, transforms
105

libmoon/moogan/modm_func.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
11
import torch
22
from torch import Tensor
33

4-
54
def mokl(mu1, mu2, Std1, Std2, pref0):
65
# step 1, convert std1 into diag.
7-
86
Std1_mtx = torch.diag_embed(Std1)
97
Std2_mtx = torch.diag_embed(Std2)
10-
118
Sigma_output = []
129
for mtx1, mtx2 in zip(Std1_mtx, Std2_mtx):
1310
mtx = torch.inverse(pref0 * torch.inverse(mtx1) + (1-pref0) * torch.inverse(mtx2))
1411
Sigma_output.append(mtx)
1512
Sigma_output = torch.stack(Sigma_output)
16-
1713
mu_output = []
1814
for mu1_i, mu2_i, Sigma_i, std1_i, std2_i in zip(mu1, mu2, Sigma_output, Std1_mtx, Std2_mtx):
1915
mu = Sigma_i @ (pref0 * torch.inverse(std1_i) @ mu1_i + (1-pref0) * torch.inverse(std2_i) @ mu2_i )
2016
mu_output.append(mu)
21-
2217
mu_output = torch.stack(mu_output)
2318
return mu_output, Sigma_output
2419

libmoon/moogan/moovae.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,18 @@
44
from torchvision import datasets
55
from torchvision.utils import save_image
66
import torch.nn.functional as F
7-
import matplotlib.pyplot as plt
87
import os
98
from torchvision.utils import make_grid
109
import argparse
1110
import numpy as np
1211
from tqdm import tqdm
1312
from libmoon.util.constant import root_name
14-
1513
from modm_func import mokl
14+
from libmoon.util.general import FolderDataset
15+
from torch.utils.data import DataLoader
1616

17-
18-
# 创建文件夹
1917
device = 'cuda' if torch.cuda.is_available() else 'cpu'
2018

21-
2219
def to_img(x):
2320
img = make_grid(x, nrow=8, normalize=True).detach()
2421
return img
@@ -62,7 +59,6 @@ def forward(self, x):
6259
out1, out2 = self.encoder(x), self.encoder(x)
6360
mean = self.encoder_fc1(out1.view(out1.shape[0], -1))
6461
logstd = self.encoder_fc2(out2.view(out2.shape[0], -1))
65-
6662
z = self.noise_reparameterize(mean, logstd)
6763
out3 = self.decoder_fc(z)
6864
out3 = out3.view(out3.shape[0], 32, 7, 7)
@@ -85,6 +81,11 @@ def loss_function(recon_x, x, mean, std):
8581
parser.add_argument('--batch-size', type=int, default=64)
8682
parser.add_argument('--data-name1', type=str, default='alarm')
8783
parser.add_argument('--data-name2', type=str, default='circle')
84+
parser.add_argument('--data-type', type=str, default='domainnet')
85+
parser.add_argument('--domain-set-data', type=str, default='airplane')
86+
parser.add_argument('--domain1', type=str, default='clipart')
87+
parser.add_argument('--domain2', type=str, default='infograph')
88+
8889
parser.add_argument('--n-epochs', type=int, default=100)
8990
parser.add_argument('--z-dimension', type=int, default=2)
9091
parser.add_argument('--lr', type=float, default=3e-4)
@@ -94,40 +95,56 @@ def loss_function(recon_x, x, mean, std):
9495
img_transform = transforms.Compose([
9596
transforms.ToTensor(),
9697
])
97-
if args.data_name1 == 'mnist':
98+
if args.data_type == 'mnist':
9899
# mnist dataset mnist数据集下载
99100
mnist = datasets.MNIST(root='./data/', train=True, transform=img_transform, download=True)
100101
# data loader 数据载入
101102
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=args.batch_size, shuffle=True)
103+
elif args.data_type == 'domainnet':
104+
# F:\code\libmoon\libmoon\moogan\data\domainnet
105+
path1 = os.path.join(
106+
root_name, 'libmoon', 'moogan', 'data', 'domainnet', args.domain1, args.domain_set_data
107+
)
108+
path2 = os.path.join(
109+
root_name, 'libmoon', 'moogan', 'data', 'domainnet', args.domain2, args.domain_set_data
110+
)
111+
112+
dataset1 = FolderDataset(path1)
113+
dataset2 = FolderDataset(path2)
114+
115+
116+
dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, shuffle=True)
117+
dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, shuffle=True)
102118
else:
103-
path1 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw','full_numpy_bitmap_{}.npy'.format(args.data_name1))
119+
path1 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw',
120+
'full_numpy_bitmap_{}.npy'.format(args.data_name1))
104121
img1_data = np.load(path1)
105122
img1_data = img1_data.reshape(-1, 1, 28, 28)
106123
img1_data = img1_data / 255
107-
path2 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw', 'full_numpy_bitmap_{}.npy'.format(args.data_name2))
124+
path2 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'quick_draw',
125+
'full_numpy_bitmap_{}.npy'.format(args.data_name2))
108126
img2_data = np.load(path2)
109127
img2_data = img2_data.reshape(-1, 1, 28, 28)
110128
img2_data = img2_data / 255
111129
img1_data = torch.from_numpy(img1_data).to(torch.float).to(device)
112130
img2_data = torch.from_numpy(img2_data).to(torch.float).to(device)
113131
print('img1_data size: ', len(img1_data))
114-
dataloader = dataloader1 = torch.utils.data.DataLoader(img1_data, batch_size=args.batch_size, shuffle=True)
115-
dataloader2 = torch.utils.data.DataLoader(img2_data, batch_size=args.batch_size, shuffle=True)
132+
dataloader1 = DataLoader(img1_data, batch_size=args.batch_size, shuffle=True)
133+
dataloader2 = DataLoader(img2_data, batch_size=args.batch_size, shuffle=True)
116134

117135
vae = VAE().to(device)
118136
num1 = numel(vae.encoder)
119137
num2 = numel(vae.decoder)
120138
print()
121139
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr,
122140
betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
123-
###########################进入训练##判别器的判断过程#####################
141+
124142
for epoch in range(args.n_epochs): # 进行多个epoch的训练
125-
for i, (img, img2) in tqdm(enumerate(zip(dataloader, dataloader2))):
126-
num_img = img.size(0)
127-
# view()函数作用把img变成[batch_size,channel_size,784]
128-
img = img.view(num_img, 1, 28, 28).to(device) # 将图片展开为28*28=784
129-
x, mean1, logstd1 = vae(img) # 将真实图片放入判别器中
130-
loss = loss_function(x, img, mean1, logstd1)
143+
for i, (img1, img2) in tqdm(enumerate(zip(dataloader1, dataloader2))):
144+
num_img = img1.size(0)
145+
img1 = img1.view(num_img, 1, 28, 28).to(device) # 将图片展开为28*28=784
146+
x, mean1, logstd1 = vae(img1) # 将真实图片放入判别器中
147+
loss = loss_function(x, img1, mean1, logstd1)
131148
vae_optimizer.zero_grad() # 在反向传播之前,先将梯度归 0.
132149
loss.backward() # 将误差反向传播
133150
vae_optimizer.step() # 更新参数

libmoon/util/general.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
import numpy as np
33
import torch
44

5+
from torch.utils.data import Dataset
6+
from torch.utils.data import DataLoader
7+
import os
8+
import cv2
9+
10+
# Read an image
11+
# image = cv2.imread("image.jpg")
12+
13+
514

615
def set_indicators_rank(Indicators, indicator_dict_dict_saved, mtd_arr):
716
for indicator in Indicators:
@@ -32,8 +41,22 @@ def get_indicator(problem_name, mtd_name, num_seed, use_save=False):
3241
return mean_indicator_dict, std_indicator_dict
3342

3443

44+
class FolderDataset(Dataset):
45+
def __init__(self, folder_name):
46+
self.folder_name = folder_name
47+
file_names = [f for f in os.listdir(self.folder_name) if os.path.isfile(os.path.join(self.folder_name, f))]
48+
image_array = []
49+
for idx, file_name in enumerate(file_names):
50+
file_path = os.path.join(self.folder_name, file_name)
51+
image = cv2.imread(file_path)
52+
image_array.append(image)
53+
self.image_array = image_array
3554

55+
def __getitem__(self, idx):
56+
return self.image_array[idx]
3657

58+
def __len__(self):
59+
return len(self.image_array)
3760

3861

3962
def random_everything(seed):

0 commit comments

Comments
 (0)