Skip to content

Commit 433672a

Browse files
committed
update qvae_classifier
1 parent 3f7516b commit 433672a

File tree

2 files changed

+122
-124
lines changed

2 files changed

+122
-124
lines changed

example/qvae_mnist/train_qvae.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,18 @@ def create_model(train_loader, input_dim, hidden_dim, latent_dim,
5353
decoder = Decoder(latent_dim, hidden_dim, input_dim, weight_decay)
5454

5555
# 初始化bm和sampler
56-
rbm = RestrictedBoltzmannMachine(
56+
bm = RestrictedBoltzmannMachine(
5757
num_visible=num_var1,
5858
num_hidden=num_var2,
59-
h_range=[-1, 1],
60-
j_range=[-1, 1]
6159
)
60+
# bm = BoltzmannMachine(num_nodes=num_var1 + num_var2)
6261
sampler = SimulatedAnnealingOptimizer(alpha=0.95)
6362

6463
# 创建QVAE模型(参数与训练时完全一致)
6564
model = QVAE(
6665
encoder=encoder,
6766
decoder=decoder,
68-
bm=rbm,
67+
bm=bm,
6968
sampler=sampler,
7069
dist_beta=dist_beta,
7170
mean_x=mean_x,
@@ -295,7 +294,6 @@ def train_qvae_with_tsne(
295294
torch.save(model.state_dict(), model_save_path)
296295
return model
297296

298-
299297
def train_qvae(
300298
train_loader, # 用于训练QVAE
301299
device,

0 commit comments

Comments
 (0)