44from torchvision import datasets
55from torchvision .utils import save_image
66import torch .nn .functional as F
7- import matplotlib .pyplot as plt
87import os
98from torchvision .utils import make_grid
109import argparse
1110import numpy as np
1211from tqdm import tqdm
1312from libmoon .util .constant import root_name
14-
1513from modm_func import mokl
14+ from libmoon .util .general import FolderDataset
15+ from torch .utils .data import DataLoader
1616
17-
18- # 创建文件夹
1917device = 'cuda' if torch .cuda .is_available () else 'cpu'
2018
21-
2219def to_img (x ):
2320 img = make_grid (x , nrow = 8 , normalize = True ).detach ()
2421 return img
@@ -62,7 +59,6 @@ def forward(self, x):
6259 out1 , out2 = self .encoder (x ), self .encoder (x )
6360 mean = self .encoder_fc1 (out1 .view (out1 .shape [0 ], - 1 ))
6461 logstd = self .encoder_fc2 (out2 .view (out2 .shape [0 ], - 1 ))
65-
6662 z = self .noise_reparameterize (mean , logstd )
6763 out3 = self .decoder_fc (z )
6864 out3 = out3 .view (out3 .shape [0 ], 32 , 7 , 7 )
@@ -85,6 +81,11 @@ def loss_function(recon_x, x, mean, std):
8581 parser .add_argument ('--batch-size' , type = int , default = 64 )
8682 parser .add_argument ('--data-name1' , type = str , default = 'alarm' )
8783 parser .add_argument ('--data-name2' , type = str , default = 'circle' )
84+ parser .add_argument ('--data-type' , type = str , default = 'domainnet' )
85+ parser .add_argument ('--domain-set-data' , type = str , default = 'airplane' )
86+ parser .add_argument ('--domain1' , type = str , default = 'clipart' )
87+ parser .add_argument ('--domain2' , type = str , default = 'infograph' )
88+
8889 parser .add_argument ('--n-epochs' , type = int , default = 100 )
8990 parser .add_argument ('--z-dimension' , type = int , default = 2 )
9091 parser .add_argument ('--lr' , type = float , default = 3e-4 )
@@ -94,40 +95,56 @@ def loss_function(recon_x, x, mean, std):
9495 img_transform = transforms .Compose ([
9596 transforms .ToTensor (),
9697 ])
97- if args .data_name1 == 'mnist' :
98+ if args .data_type == 'mnist' :
9899 # mnist dataset mnist数据集下载
99100 mnist = datasets .MNIST (root = './data/' , train = True , transform = img_transform , download = True )
100101 # data loader 数据载入
101102 dataloader = torch .utils .data .DataLoader (dataset = mnist , batch_size = args .batch_size , shuffle = True )
103+ elif args .data_type == 'domainnet' :
104+ # F:\code\libmoon\libmoon\moogan\data\domainnet
105+ path1 = os .path .join (
106+ root_name , 'libmoon' , 'moogan' , 'data' , 'domainnet' , args .domain1 , args .domain_set_data
107+ )
108+ path2 = os .path .join (
109+ root_name , 'libmoon' , 'moogan' , 'data' , 'domainnet' , args .domain2 , args .domain_set_data
110+ )
111+
112+ dataset1 = FolderDataset (path1 )
113+ dataset2 = FolderDataset (path2 )
114+
115+
116+ dataloader1 = DataLoader (dataset1 , batch_size = args .batch_size , shuffle = True )
117+ dataloader2 = DataLoader (dataset2 , batch_size = args .batch_size , shuffle = True )
102118 else :
103- path1 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' ,'full_numpy_bitmap_{}.npy' .format (args .data_name1 ))
119+ path1 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' ,
120+ 'full_numpy_bitmap_{}.npy' .format (args .data_name1 ))
104121 img1_data = np .load (path1 )
105122 img1_data = img1_data .reshape (- 1 , 1 , 28 , 28 )
106123 img1_data = img1_data / 255
107- path2 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' , 'full_numpy_bitmap_{}.npy' .format (args .data_name2 ))
124+ path2 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' ,
125+ 'full_numpy_bitmap_{}.npy' .format (args .data_name2 ))
108126 img2_data = np .load (path2 )
109127 img2_data = img2_data .reshape (- 1 , 1 , 28 , 28 )
110128 img2_data = img2_data / 255
111129 img1_data = torch .from_numpy (img1_data ).to (torch .float ).to (device )
112130 img2_data = torch .from_numpy (img2_data ).to (torch .float ).to (device )
113131 print ('img1_data size: ' , len (img1_data ))
114- dataloader = dataloader1 = torch . utils . data . DataLoader (img1_data , batch_size = args .batch_size , shuffle = True )
115- dataloader2 = torch . utils . data . DataLoader (img2_data , batch_size = args .batch_size , shuffle = True )
132+ dataloader1 = DataLoader (img1_data , batch_size = args .batch_size , shuffle = True )
133+ dataloader2 = DataLoader (img2_data , batch_size = args .batch_size , shuffle = True )
116134
117135 vae = VAE ().to (device )
118136 num1 = numel (vae .encoder )
119137 num2 = numel (vae .decoder )
120138 print ()
121139 vae_optimizer = torch .optim .Adam (vae .parameters (), lr = args .lr ,
122140 betas = (0.9 , 0.999 ), eps = 1e-08 , weight_decay = 0 )
123- ###########################进入训练##判别器的判断过程#####################
141+
124142 for epoch in range (args .n_epochs ): # 进行多个epoch的训练
125- for i , (img , img2 ) in tqdm (enumerate (zip (dataloader , dataloader2 ))):
126- num_img = img .size (0 )
127- # view()函数作用把img变成[batch_size,channel_size,784]
128- img = img .view (num_img , 1 , 28 , 28 ).to (device ) # 将图片展开为28*28=784
129- x , mean1 , logstd1 = vae (img ) # 将真实图片放入判别器中
130- loss = loss_function (x , img , mean1 , logstd1 )
143+ for i , (img1 , img2 ) in tqdm (enumerate (zip (dataloader1 , dataloader2 ))):
144+ num_img = img1 .size (0 )
145+ img1 = img1 .view (num_img , 1 , 28 , 28 ).to (device ) # 将图片展开为28*28=784
146+ x , mean1 , logstd1 = vae (img1 ) # 将真实图片放入判别器中
147+ loss = loss_function (x , img1 , mean1 , logstd1 )
131148 vae_optimizer .zero_grad () # 在反向传播之前,先将梯度归 0.
132149 loss .backward () # 将误差反向传播
133150 vae_optimizer .step () # 更新参数
0 commit comments