@@ -141,41 +141,42 @@ def pick_gray_seeds(best_theta, thread_count, gray_seed_multiple, G_m, n, is_spi
141141 return best_seeds , best_energies
142142
143143@njit (parallel = True )
144- def run_gray_optimization (best_theta , iterators , energies , gray_iterations , thread_count , is_spin_glass , G_m ):
144+ def run_gray_optimization (best_theta , iterators , gray_iterations , thread_count , is_spin_glass , G_m ):
145145 n = len (best_theta )
146146 thread_iterations = (gray_iterations + thread_count - 1 ) // thread_count
147147 blocks = (n + 63 ) >> 6
148+ energies = np .empty (thread_count , dtype = dtype )
148149
149150 if is_spin_glass :
150151 for i in prange (thread_count ):
151152 iterator = iterators [i ]
152- best_energy = energies [ i ]
153+ best_energy = 0.0
153154 for curr_idx in range (thread_iterations ):
154155 for block in range (blocks ):
155156 flip_bit = gray_code_next (iterator , curr_idx , block << 6 )
156- energy = compute_energy ( iterator , G_m , n )
157+ energy = compute_energy_diff ( flip_bit , iterator , G_m , n )
157158 if energy > best_energy :
158159 best_energy = energy
159160 else :
160161 # Revert iterator
161162 iterator [flip_bit ] = not iterator [flip_bit ]
162- if best_energy > energies [ i ] :
163- energies [i ] = best_energy
163+ if best_energy > 0.0 :
164+ energies [i ] + = best_energy
164165 else :
165166 for i in prange (thread_count ):
166167 iterator = iterators [i ]
167- best_energy = energies [ i ]
168+ best_energy = 0.0
168169 for curr_idx in range (thread_iterations ):
169170 for block in range (blocks ):
170171 flip_bit = gray_code_next (iterator , curr_idx , block << 6 )
171- energy = compute_cut ( iterator , G_m , n )
172+ energy = compute_cut_diff ( flip_bit , iterator , G_m , n )
172173 if energy > best_energy :
173174 best_energy = energy
174175 else :
175176 # Revert iterator
176177 iterator [flip_bit ] = not iterator [flip_bit ]
177- if best_energy > energies [ i ] :
178- energies [i ] = best_energy
178+ if best_energy > 0.0 :
179+ energies [i ] + = best_energy
179180
180181 best_index = np .argmax (energies )
181182 best_energy = energies [best_index ]
@@ -295,9 +296,7 @@ def run_gray_search_opencl(n, kernel, best_energy, theta, theta_buf, G_m_buf, is
295296
296297 if energy <= 0.0 :
297298 # No improvement: we can exit early
298- return best_energy , theta
299-
300- energy += best_energy
299+ return 0.0 , theta
301300
302301 # We need the best index
303302 queue .finish ()
@@ -457,9 +456,10 @@ def spin_glass_solver(
457456 improved = True
458457 continue
459458
460- energy , state = run_gray_optimization (best_theta , iterators , energies , gray_iterations , thread_count , is_spin_glass , G_m )
461- if energy > max_energy :
462- max_energy = energy
459+ energies = None
460+ energy , state = run_gray_optimization (best_theta , iterators , gray_iterations , thread_count , is_spin_glass , G_m )
461+ if energy > 0.0 :
462+ max_energy += energy
463463 best_theta = state
464464 improved = True
465465 continue
0 commit comments