11from .maxcut_tfim_streaming import maxcut_tfim_streaming
2- from .maxcut_tfim_util import compute_cut_streaming , compute_energy_streaming , get_cut , get_cut_base , gray_code_next , gray_mutation , heuristic_threshold , int_to_bitstring , opencl_context
2+ from .maxcut_tfim_util import compute_cut_streaming , compute_cut_diff_streaming , compute_cut_diff_2_streaming , compute_energy_streaming , compute_energy_diff_streaming , compute_energy_diff_2_streaming , get_cut , get_cut_base , gray_code_next , gray_mutation , heuristic_threshold , int_to_bitstring , opencl_context
33import networkx as nx
44import numpy as np
55from numba import njit , prange
@@ -19,12 +19,12 @@ def run_single_bit_flips(best_theta, is_spin_glass, G_func, nodes):
1919 for i in prange (n ):
2020 state = best_theta .copy ()
2121 state [i ] = not state [i ]
22- energies [i ] = compute_energy_streaming ( state , G_func , nodes , n )
22+ energies [i ] = compute_energy_diff_streaming ( i , state , G_func , nodes , n )
2323 else :
2424 for i in prange (n ):
2525 state = best_theta .copy ()
2626 state [i ] = not state [i ]
27- energies [i ] = compute_cut_streaming ( state , G_func , nodes , n )
27+ energies [i ] = compute_cut_diff_streaming ( i , state , G_func , nodes , n )
2828
2929 best_index = np .argmax (energies )
3030 best_energy = energies [best_index ]
@@ -62,7 +62,7 @@ def run_double_bit_flips(best_theta, is_spin_glass, G_func, nodes, thread_count)
6262 state [i ] = not state [i ]
6363 state [j ] = not state [j ]
6464
65- states [t ], energies [t ] = state , compute_energy_streaming ( state , G_func , nodes , n )
65+ states [t ], energies [t ] = state , compute_energy_diff_2_streaming ( i , j , state , G_func , nodes , n )
6666
6767 s += thread_batch
6868 else :
@@ -84,7 +84,7 @@ def run_double_bit_flips(best_theta, is_spin_glass, G_func, nodes, thread_count)
8484 state [i ] = not state [i ]
8585 state [j ] = not state [j ]
8686
87- states [t ], energies [t ] = state , compute_cut_streaming ( state , G_func , nodes , n )
87+ states [t ], energies [t ] = state , compute_cut_diff_2_streaming ( i , j , state , G_func , nodes , n )
8888
8989 s += thread_batch
9090
@@ -247,16 +247,16 @@ def spin_glass_solver_streaming(
247247
248248 # Single bit flips with O(n^2)
249249 energy , state = run_single_bit_flips (best_theta , is_spin_glass , G_func , nodes )
250- if energy > max_energy :
251- max_energy = energy
250+ if energy > 0.0 :
251+ max_energy + = energy
252252 best_theta = state
253253 improved = True
254254 continue
255255
256256 # Double bit flips with O(n^3)
257257 energy , state = run_double_bit_flips (best_theta , is_spin_glass , G_func , nodes , thread_count )
258- if energy > max_energy :
259- max_energy = energy
258+ if energy > 0.0 :
259+ max_energy + = energy
260260 best_theta = state
261261 improved = True
262262 continue
@@ -277,21 +277,26 @@ def spin_glass_solver_streaming(
277277 improved = True
278278 continue
279279
280+ if max_energy == float ("inf" ):
281+ # We no way to compare for improvement.
282+ break
283+
280284 # Post-reheat phase
281285 reheat_theta = state
286+ reheat_energy = energy
282287
283288 # Single bit flips with O(n^2)
284289 energy , state = run_single_bit_flips (reheat_theta , is_spin_glass , G_func , nodes )
285- if energy > max_energy :
286- max_energy = energy
290+ if energy > ( max_energy - reheat_energy ) :
291+ max_energy + = energy
287292 best_theta = state
288293 improved = True
289294 continue
290295
291296 # Double bit flips with O(n^3)
292297 energy , state = run_double_bit_flips (reheat_theta , is_spin_glass , G_func , nodes , thread_count )
293- if energy > max_energy :
294- max_energy = energy
298+ if energy > ( max_energy - reheat_energy ) :
299+ max_energy + = energy
295300 best_theta = state
296301 improved = True
297302
0 commit comments