Skip to content

Commit 8bfef5b

Browse files
committed
Add inference mode and more often earlystop check.
1 parent 7b901af commit 8bfef5b

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,5 @@
22
**/__pycache__
33
_archive
44

5-
results/classical_DQL_sim_quantum_full.7z
65
/scripts/_external_sv_save.py
76
/scripts/3z_entropy-calculation.ipynb
8-
/results/classical_DQL_sim_quantum_sv
9-
/results/classical_DQL_sim_quantum_sv.7z

scripts/3._Classical_DQL_sim_quant.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,10 @@ def train(self, epoch):
171171
# update state
172172
s = s1
173173
if(self.compute_entropy):
174-
self.entropies.append(entanglement_entropy(self.calc_statevector(s)))
175-
self.cl_entropies.append(classical_entropy(self.calc_statevector(s)))
176-
self.entropies_episodes[i] += 1
174+
with torch.inference_mode():
175+
self.entropies.append(entanglement_entropy(self.calc_statevector(s)))
176+
self.cl_entropies.append(classical_entropy(self.calc_statevector(s)))
177+
self.entropies_episodes[i] += 1
177178

178179
if d == True: break
179180

@@ -190,7 +191,7 @@ def train(self, epoch):
190191
self.epsilon *= self.epsilon_growth_rate
191192
self.epsilon_list.append(self.epsilon)
192193

193-
if i%10==0 and i>100:
194+
if i>100:
194195
if sum(self.success[-window:])/window>target_win_ratio:
195196
print("Network trained before epoch limit on {i} epoch".format(i=i))
196197
break

scripts/3b._Classical_DQL_sim_quant_grid_search.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,10 @@ def train(self, epoch):
181181
# update state
182182
s = s1
183183
if(self.compute_entropy):
184-
self.entropies.append(entanglement_entropy(self.calc_statevector(s)))
185-
self.cl_entropies.append(classical_entropy(self.calc_statevector(s)))
186-
self.entropies_episodes[i] += 1
184+
with torch.inference_mode():
185+
self.entropies.append(entanglement_entropy(self.calc_statevector(s)))
186+
self.cl_entropies.append(classical_entropy(self.calc_statevector(s)))
187+
self.entropies_episodes[i] += 1
187188

188189
if d == True: break
189190

@@ -200,7 +201,7 @@ def train(self, epoch):
200201
self.epsilon *= self.epsilon_growth_rate
201202
self.epsilon_list.append(self.epsilon)
202203

203-
if i%10==0 and i>100:
204+
if i>100:
204205
if sum(self.success[-window:])/window>target_win_ratio:
205206
print("Network trained before epoch limit on {i} epoch".format(i=i))
206207
break

scripts/3c._Classical_DQL_sim_quant_finetuning.ipynb

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,10 @@
336336
" # update state\n",
337337
" s = s1\n",
338338
" if(self.compute_entropy):\n",
339-
" self.entropies.append(entanglement_entropy(self.calc_statevector(s))) \n",
340-
" self.cl_entropies.append(classical_entropy(self.calc_statevector(s))) \n",
341-
" self.entropies_episodes[i] += 1\n",
339+
" with torch.inference_mode():\n",
340+
" self.entropies.append(entanglement_entropy(self.calc_statevector(s))) \n",
341+
" self.cl_entropies.append(classical_entropy(self.calc_statevector(s))) \n",
342+
" self.entropies_episodes[i] += 1\n",
342343
" \n",
343344
" if d == True: break\n",
344345
" \n",
@@ -362,7 +363,7 @@
362363
" steps = len(self.success)\n",
363364
" )\n",
364365
"\n",
365-
" if i%10==0 and i>100:\n",
366+
" if i>100:\n",
366367
" if sum(self.success[-window:])/window>target_win_ratio:\n",
367368
" #print(\"Network trained before epoch limit on {i} epoch\".format(i=i))\n",
368369
" break\n",

0 commit comments

Comments
 (0)