1414from libmoon .util .general import FolderDataset
1515from torch .utils .data import DataLoader
1616
17+ os .environ ["PYTORCH_USE_CUDA_DSA" ] = "1"
18+ os .environ ['CUDA_LAUNCH_BLOCKING' ]= "1"
19+ os .environ ['TORCH_USE_CUDA_DSA' ] = "1"
20+
1721device = 'cuda' if torch .cuda .is_available () else 'cpu'
1822
1923def to_img (x ):
@@ -25,11 +29,12 @@ def numel(model):
2529
2630
2731class VAE (nn .Module ):
28- def __init__ (self ):
32+ def __init__ (self , n_channels ):
33+ self .n_channels = n_channels
2934 super (VAE , self ).__init__ ()
3035 # 定义编码器
3136 self .encoder = nn .Sequential (
32- nn .Conv2d (1 , 16 , kernel_size = 3 , stride = 2 , padding = 1 ),
37+ nn .Conv2d (n_channels , 16 , kernel_size = 3 , stride = 2 , padding = 1 ),
3338 nn .BatchNorm2d (16 ),
3439 nn .LeakyReLU (0.2 , inplace = True ),
3540 nn .Conv2d (16 , 32 , kernel_size = 3 , stride = 2 , padding = 1 ),
@@ -46,7 +51,7 @@ def __init__(self):
4651 self .decoder = nn .Sequential (
4752 nn .ConvTranspose2d (32 , 16 , 4 , 2 , 1 ),
4853 nn .ReLU (inplace = True ),
49- nn .ConvTranspose2d (16 , 1 , 4 , 2 , 1 ),
54+ nn .ConvTranspose2d (16 , self . n_channels , 4 , 2 , 1 ),
5055 nn .Sigmoid (),
5156 )
5257
@@ -56,18 +61,32 @@ def noise_reparameterize(self, mean, logvar):
5661 return z
5762
5863 def forward (self , x ):
64+ # print('x.shape', x.shape)
65+ # assert False
66+ # x.shape: (64,3,28,28)
5967 out1 , out2 = self .encoder (x ), self .encoder (x )
6068 mean = self .encoder_fc1 (out1 .view (out1 .shape [0 ], - 1 ))
6169 logstd = self .encoder_fc2 (out2 .view (out2 .shape [0 ], - 1 ))
6270 z = self .noise_reparameterize (mean , logstd )
63- out3 = self .decoder_fc (z )
64- out3 = out3 .view (out3 .shape [0 ], 32 , 7 , 7 )
65- out3 = self .decoder (out3 )
66- return out3 , mean , logstd
67-
68- def loss_function (recon_x , x , mean , std ):
69- BCE = F .binary_cross_entropy (recon_x , x , reduction = 'sum' )
70- # 因为var是标准差的自然对数,先求自然对数然后平方转换成方差
71+ decoded_img = self .decoder_fc (z )
72+ decoded_img = decoded_img .view (decoded_img .shape [0 ], 32 , 7 , 7 )
73+ decoded_img = self .decoder (decoded_img )
74+ return decoded_img .to (device ), mean .to (device ), logstd .to (device )
75+
76+
77+ def vae_loss_function (recon_x , x , mean , std ):
78+ _ , n_channels , _ , _ = x .size ()
79+ BCE_arr = []
80+ for channel_idx in range (n_channels ):
81+ # print('max recon_x', torch.max(recon_x[:, channel_idx, :, :]))
82+ # print('min recon_x', torch.min(recon_x[:, channel_idx, :, :]))
83+ # print('max x', torch.max(x[:, channel_idx, :, :]))
84+ # print('min x', torch.min(x[:, channel_idx, :, :]))
85+ BCE_i = F .binary_cross_entropy (recon_x [:, channel_idx , :, :], x [:, channel_idx , :, :],
86+ reduction = 'sum' )
87+ # print('BCE_i', BCE_i)
88+ BCE_arr .append ( BCE_i )
89+ BCE = torch .sum (torch .stack (BCE_arr ))
7190 var = torch .pow (torch .exp (std ), 2 )
7291 KLD = - 0.5 * torch .sum (1 + torch .log (var ) - torch .pow (mean , 2 ) - var )
7392 return BCE + KLD
@@ -81,45 +100,51 @@ def loss_function(recon_x, x, mean, std):
81100 parser .add_argument ('--batch-size' , type = int , default = 64 )
82101 parser .add_argument ('--data-name1' , type = str , default = 'alarm' )
83102 parser .add_argument ('--data-name2' , type = str , default = 'circle' )
84- parser .add_argument ('--data-type' , type = str , default = 'domainnet' )
103+ parser .add_argument ('--data-type' , type = str , default = 'domainnet' ) # Category: [domainnet, 'quickdraw']
85104 parser .add_argument ('--domain-set-data' , type = str , default = 'airplane' )
86- parser .add_argument ('--domain1' , type = str , default = 'clipart' )
87- parser .add_argument ('--domain2' , type = str , default = 'infograph' )
88-
89- parser .add_argument ('--n-epochs' , type = int , default = 100 )
90- parser .add_argument ('--z-dimension' , type = int , default = 2 )
105+ parser .add_argument ('--domain1' , type = str , default = 'real' )
106+ parser .add_argument ('--domain2' , type = str , default = 'quickdraw' )
107+ parser .add_argument ('--n-epochs' , type = int , default = 5000 )
108+ parser .add_argument ('--z-dimension' , type = int , default = 5 )
91109 parser .add_argument ('--lr' , type = float , default = 3e-4 )
92110 parser .add_argument ('--pref0' , type = float , default = 0.0 )
93111 args = parser .parse_args ()
94112 # batch_size = 64, # num_epoch = 15, # z_dimension = 2
95113 img_transform = transforms .Compose ([
96114 transforms .ToTensor (),
97115 ])
116+
98117 if args .data_type == 'mnist' :
99- # mnist dataset mnist数据集下载
100118 mnist = datasets .MNIST (root = './data/' , train = True , transform = img_transform , download = True )
101- # data loader 数据载入
102119 dataloader = torch .utils .data .DataLoader (dataset = mnist , batch_size = args .batch_size , shuffle = True )
120+ n_channels = 1
103121 elif args .data_type == 'domainnet' :
104- # F:\code\libmoon\libmoon\moogan\data\domainnet
105122 path1 = os .path .join (
106123 root_name , 'libmoon' , 'moogan' , 'data' , 'domainnet' , args .domain1 , args .domain_set_data
107124 )
125+
108126 path2 = os .path .join (
109127 root_name , 'libmoon' , 'moogan' , 'data' , 'domainnet' , args .domain2 , args .domain_set_data
110128 )
111129
112130 dataset1 = FolderDataset (path1 )
113131 dataset2 = FolderDataset (path2 )
114-
132+ print (dataset1 [0 ].shape )
133+ print (dataset2 [0 ].shape )
134+ print ('len dataset1' , len (dataset1 ))
135+ print ('len dataset2' , len (dataset2 ))
115136
116137 dataloader1 = DataLoader (dataset1 , batch_size = args .batch_size , shuffle = True )
117138 dataloader2 = DataLoader (dataset2 , batch_size = args .batch_size , shuffle = True )
118- else :
139+ n_channels = 3
140+
141+ elif args .data_type == 'quickdraw' :
119142 path1 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' ,
120143 'full_numpy_bitmap_{}.npy' .format (args .data_name1 ))
121144 img1_data = np .load (path1 )
122- img1_data = img1_data .reshape (- 1 , 1 , 28 , 28 )
145+ if args .data_type == 'quickdraw' :
146+ img1_data = img1_data .reshape (- 1 , 1 , 28 , 28 )
147+
123148 img1_data = img1_data / 255
124149 path2 = os .path .join (root_name , 'libmoon' , 'moogan' , 'data' , 'quick_draw' ,
125150 'full_numpy_bitmap_{}.npy' .format (args .data_name2 ))
@@ -131,63 +156,67 @@ def loss_function(recon_x, x, mean, std):
131156 print ('img1_data size: ' , len (img1_data ))
132157 dataloader1 = DataLoader (img1_data , batch_size = args .batch_size , shuffle = True )
133158 dataloader2 = DataLoader (img2_data , batch_size = args .batch_size , shuffle = True )
159+ n_channels = 1
160+ else :
161+ assert False , 'dataset not implemented'
134162
135- vae = VAE ().to (device )
163+ vae = VAE (n_channels = n_channels ).to (device )
136164 num1 = numel (vae .encoder )
137165 num2 = numel (vae .decoder )
138- print ()
139166 vae_optimizer = torch .optim .Adam (vae .parameters (), lr = args .lr ,
140167 betas = (0.9 , 0.999 ), eps = 1e-08 , weight_decay = 0 )
141-
142- for epoch in range (args .n_epochs ): # 进行多个epoch的训练
168+ for epoch in range (args .n_epochs ):
143169 for i , (img1 , img2 ) in tqdm (enumerate (zip (dataloader1 , dataloader2 ))):
170+ img1 = img1 .to (device )
171+ img2 = img2 .to (device )
144172 num_img = img1 .size (0 )
145- img1 = img1 .view (num_img , 1 , 28 , 28 ).to (device ) # 将图片展开为28*28=784
146- x , mean1 , logstd1 = vae (img1 ) # 将真实图片放入判别器中
147- loss = loss_function (x , img1 , mean1 , logstd1 )
173+ if args .data_type == 'quickdraw' :
174+ img1 = img1 .view (num_img , 1 , 28 , 28 ).to (device ) # 将图片展开为28*28=784
175+
176+ decoded_img1 , mean1 , logstd1 = vae (img1 )
177+
178+ loss1 = vae_loss_function (decoded_img1 , img1 , mean1 , logstd1 )
148179 vae_optimizer .zero_grad () # 在反向传播之前,先将梯度归 0.
149- loss .backward () # 将误差反向传播
180+ loss1 .backward ()
150181 vae_optimizer .step () # 更新参数
151182
152183 num_img2 = img2 .size (0 )
153- img2 = img2 .view (num_img2 , 1 , 28 , 28 ).to (device ) # 将图片展开为 28*28=784.
154- x , mean2 , logstd2 = vae (img2 ) # 将真实图片放入判别器中
155- loss = loss_function (x , img2 , mean2 , logstd2 )
156- vae_optimizer .zero_grad () # 在反向传播之前,先将梯度归0
157- loss .backward () # 将误差反向传播
184+ if args .data_type == 'quickdraw' :
185+ img2 = img2 .view (num_img2 , 1 , 28 , 28 ).to (device ) # 将图片展开为 28*28=784.
186+
187+ decoded_img2 , mean2 , logstd2 = vae (img2 )
188+ loss2 = vae_loss_function (decoded_img2 , img2 , mean2 , logstd2 )
189+ vae_optimizer .zero_grad ()
190+ loss2 .backward ()
158191 vae_optimizer .step ()
159192
160193 if (i + 1 ) % 100 == 0 :
161194 print ('Epoch[{}/{}],vae_loss:{:.6f} ' .format (
162195 epoch , args .n_epochs , loss .item (),
163196 ))
164197
165- folder_name = os .path .join (root_name , 'libmoon' , 'moogan' , 'img_VAE' ,
166- '{}_{}' .format (args .data_name1 , args .data_name2 ))
167- os .makedirs (folder_name , exist_ok = True )
198+ if args .data_type == 'quickdraw' :
199+ folder_name = os .path .join (root_name , 'libmoon' , 'moogan' , 'img_VAE' , args .data_type ,
200+ '{}_{}' .format (args .data_name1 , args .data_name2 ))
201+ else :
202+ folder_name = os .path .join (root_name , 'libmoon' , 'moogan' , 'img_VAE' , args .data_type , args .domain_set_data ,
203+ '{}_{}' .format (args .domain1 , args .domain2 ))
168204
205+ os .makedirs (folder_name , exist_ok = True )
169206 if epoch == 0 :
170- real_images1 = make_grid (img [:25 ].cpu (), nrow = 5 , normalize = True ).detach ()
207+ real_images1 = make_grid (img1 [:25 ].cpu (), nrow = 5 , normalize = True ).detach ()
171208 save_image (real_images1 , os .path .join (folder_name , 'real_images1.pdf' ))
172209 real_images2 = make_grid (img2 [:25 ].cpu (), nrow = 5 , normalize = True ).detach ()
173210 save_image (real_images2 , os .path .join (folder_name , 'real_images2.pdf' ))
174211 # sample_size = 25
175212 pref0_arr = np .linspace (0 , 1 , 5 )
176213
177-
178214 if i == 0 :
179215 for pref0 in pref0_arr :
180- # mean1 (mean2).shape: (64,2)
181- # meanA = torch.mean(mean1)
182- # meanB = torch.mean(mean2)
183216 Std1 = torch .exp (logstd1 )
184217 Std2 = torch .exp (logstd2 )
185218 mu , std = mokl (mean1 , mean2 , Std1 , Std2 , pref0 )
186- # mu.shape: 64*2
187- # std.shape: 64*2*2
188219 std = torch .diagonal (std , dim1 = 1 , dim2 = 2 )
189- # print(mu.shape)
190- # print(std.shape)
191220 sample_size = len (mu )
192221 sample = torch .randn (sample_size , args .z_dimension ).to (device ) * std + mu
193222 output = vae .decoder_fc (sample )
@@ -196,7 +225,5 @@ def loss_function(recon_x, x, mean, std):
196225 fig_name = os .path .join (folder_name , 'fake_images_{}_{:.2f}.pdf' .format (epoch + 16 , pref0 ))
197226 save_image (fake_images , fig_name )
198227 print ('img saved in' , fig_name )
199-
200-
201228 # 保存模型
202229 torch .save (vae .state_dict (), './VAE_z2.pth' )
0 commit comments