1212import numpy as np
1313from tqdm import tqdm
1414
15+ from libmoon .util .constant import root_name
16+
17+
1518# 创建文件夹
1619device = 'cuda' if torch .cuda .is_available () else 'cpu'
1720
@@ -86,19 +89,15 @@ def loss_function(recon_x, x, mean, std):
8689if __name__ == '__main__' :
8790 parser = argparse .ArgumentParser ()
8891 parser .add_argument ('--batch-size' , type = int , default = 64 )
89- # parser.add_argument('--dataset-name', type=str, default='mnist')
9092 parser .add_argument ('--data-name1' , type = str , default = 'alarm' )
9193 parser .add_argument ('--data-name2' , type = str , default = 'circle' )
92-
9394 parser .add_argument ('--n-epochs' , type = int , default = 100 )
9495 parser .add_argument ('--z-dimension' , type = int , default = 2 )
9596 parser .add_argument ('--lr' , type = float , default = 3e-4 )
9697 parser .add_argument ('--pref0' , type = float , default = 0.0 )
9798 args = parser .parse_args ()
98- # batch_size = 64
99- # num_epoch = 15
100- # z_dimension = 2
101- # 图形啊处理过程
99+ # batch_size = 64, # num_epoch = 15, # z_dimension = 2
100+
102101 img_transform = transforms .Compose ([
103102 transforms .ToTensor (),
104103 ])
@@ -109,12 +108,12 @@ def loss_function(recon_x, x, mean, std):
109108 # data loader 数据载入
110109 dataloader = torch .utils .data .DataLoader (dataset = mnist , batch_size = args .batch_size , shuffle = True )
111110 else :
112- path1 = 'D: \\ pycharm_project \\ libmoon\\ libmoon \\ moogan\\ data\\ full_numpy_bitmap_{}.npy' .format (args .data_name1 )
111+ path1 = os . path . join ( root_name , ' libmoon' , ' moogan' , ' data' , ' full_numpy_bitmap_{}.npy' .format (args .data_name1 ) )
113112 img1_data = np .load (path1 )
114113 img1_data = img1_data .reshape (- 1 , 1 , 28 , 28 )
115114 img1_data = img1_data / 255
116115
117- path2 = 'D: \\ pycharm_project \\ libmoon\\ libmoon \\ moogan\\ data\\ full_numpy_bitmap_{}.npy' .format (args .data_name2 )
116+ path2 = os . path . join ( root_name , ' libmoon' , ' moogan' , ' data' , ' full_numpy_bitmap_{}.npy' .format (args .data_name2 ) )
118117 img2_data = np .load (path2 )
119118 img2_data = img2_data .reshape (- 1 , 1 , 28 , 28 )
120119 img2_data = img2_data / 255
@@ -139,7 +138,6 @@ def loss_function(recon_x, x, mean, std):
139138 num_img = img .size (0 )
140139 # view()函数作用把img变成[batch_size,channel_size,784]
141140 img = img .view (num_img , 1 , 28 , 28 ).to (device ) # 将图片展开为28*28=784
142-
143141 x , mean1 , logstd1 = vae (img ) # 将真实图片放入判别器中
144142 loss = loss_function (x , img , mean1 , logstd1 )
145143 vae_optimizer .zero_grad () # 在反向传播之前,先将梯度归 0.
0 commit comments