Skip to content

Commit c15330a

Browse files
committed
X
1 parent f84a201 commit c15330a

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

libmoon/moogan/download_dataset.txt

Whitespace-only changes.

libmoon/moogan/modm_func.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
if __name__ == '__main__':
3+
print()

libmoon/moogan/moovae.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import numpy as np
1313
from tqdm import tqdm
1414

15+
from libmoon.util.constant import root_name
16+
17+
1518
# 创建文件夹
1619
device = 'cuda' if torch.cuda.is_available() else 'cpu'
1720

@@ -86,19 +89,15 @@ def loss_function(recon_x, x, mean, std):
8689
if __name__ == '__main__':
8790
parser = argparse.ArgumentParser()
8891
parser.add_argument('--batch-size', type=int, default=64)
89-
# parser.add_argument('--dataset-name', type=str, default='mnist')
9092
parser.add_argument('--data-name1', type=str, default='alarm')
9193
parser.add_argument('--data-name2', type=str, default='circle')
92-
9394
parser.add_argument('--n-epochs', type=int, default=100)
9495
parser.add_argument('--z-dimension', type=int, default=2)
9596
parser.add_argument('--lr', type=float, default=3e-4)
9697
parser.add_argument('--pref0', type=float, default=0.0)
9798
args = parser.parse_args()
98-
# batch_size = 64
99-
# num_epoch = 15
100-
# z_dimension = 2
101-
# 图形啊处理过程
99+
# batch_size = 64, # num_epoch = 15, # z_dimension = 2
100+
102101
img_transform = transforms.Compose([
103102
transforms.ToTensor(),
104103
])
@@ -109,12 +108,12 @@ def loss_function(recon_x, x, mean, std):
109108
# data loader 数据载入
110109
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=args.batch_size, shuffle=True)
111110
else:
112-
path1 = 'D:\\pycharm_project\\libmoon\\libmoon\\moogan\\data\\full_numpy_bitmap_{}.npy'.format(args.data_name1)
111+
path1 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'full_numpy_bitmap_{}.npy'.format(args.data_name1))
113112
img1_data = np.load(path1)
114113
img1_data = img1_data.reshape(-1, 1, 28, 28)
115114
img1_data = img1_data / 255
116115

117-
path2 = 'D:\\pycharm_project\\libmoon\\libmoon\\moogan\\data\\full_numpy_bitmap_{}.npy'.format(args.data_name2)
116+
path2 = os.path.join(root_name, 'libmoon', 'moogan', 'data', 'full_numpy_bitmap_{}.npy'.format(args.data_name2))
118117
img2_data = np.load(path2)
119118
img2_data = img2_data.reshape(-1, 1, 28, 28)
120119
img2_data = img2_data / 255
@@ -139,7 +138,6 @@ def loss_function(recon_x, x, mean, std):
139138
num_img = img.size(0)
140139
# view()函数作用把img变成[batch_size,channel_size,784]
141140
img = img.view(num_img, 1, 28, 28).to(device) # 将图片展开为28*28=784
142-
143141
x, mean1, logstd1 = vae(img) # 将真实图片放入判别器中
144142
loss = loss_function(x, img, mean1, logstd1)
145143
vae_optimizer.zero_grad() # 在反向传播之前,先将梯度归 0.

0 commit comments

Comments
 (0)