Skip to content

Commit 83d6df5

Browse files
Optimize
1 parent dec4bd8 commit 83d6df5

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

pyqrackising/generate_tfim_samples.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
from numba import njit
77

8+
from collections import Counter
9+
810

911
epsilon = opencl_context.epsilon
1012

@@ -151,20 +153,22 @@ def generate_tfim_samples(
151153
# First dimension: Hamming weight
152154
bias = get_tfim_hamming_distribution(J=J, h=h, z=z, theta=theta, t=t, n_qubits=n_qubits)
153155
thresholds = fix_cdf(bias)
154-
hamming_samples = sample_hamming_weight(thresholds, shots)
156+
hamming_samples = dict(Counter(sample_hamming_weight(thresholds, shots)))
155157

156-
for h_weight in hamming_samples:
158+
for h_weight, count in hamming_samples.items():
157159
if h_weight == 0:
158-
samples.append(0)
160+
samples += count * [0]
159161
continue
160162

161163
if h_weight == n_qubits:
162-
samples.append((1 << n_qubits) - 1)
164+
samples += count * [(1 << n_qubits) - 1]
163165
continue
164166

165-
p = np.random.random()
167+
rands = [np.random.random() for _ in range(count)]
168+
rands.sort()
166169
state_int = 0
167170
tot_prob = 0
171+
s = 0
168172
# How closely grouped are "like" bits to "like"?
169173
expected_closeness = expected_closeness_weight(n_rows, n_cols, h_weight)
170174
h_weight_combos = math.comb(n_qubits, h_weight)
@@ -175,9 +179,15 @@ def generate_tfim_samples(
175179
# Use a normalized weighted average that favors the (n+1)-dimensional model at later times.
176180
# The (n+1)-dimensional marginal probability is the product of a function of Hamming weight and "closeness," split among all basis states with that specific Hamming weight.
177181
tot_prob += normed_closeness / h_weight_combos
178-
if (p <= tot_prob):
182+
while (rands[s] <= tot_prob):
179183
samples.append(state_int)
184+
s += 1
185+
if s == count:
186+
break
187+
if s == count:
180188
break
189+
if s < count:
190+
samples += (count - s) * [state_int]
181191

182192

183193
np.random.shuffle(samples)

0 commit comments

Comments
 (0)