import time
import optax
+import cotengra
import tensorcircuit as tc
from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers
from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian

# Use JAX for high performance, especially on GPU.
K = tc.set_backend("jax")
tc.set_dtype("complex64")
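+# complex64 halves memory use relative to complex128 and is typically much
+# faster on GPU; switch to complex128 if the optimization needs tighter numerics.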
-# On Windows, cotengra's multiprocessing can cause issues, use threads instead.
-tc.set_contractor("cotengra-8192-8192", parallel="threads")
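+# Build the contraction-path optimizer explicitly: cotengra samples candidate
+# paths with greedy and kahypar, tunes the samplers with CMA-ES, stops after
+# 120 s or 4096 trials, and caches the best (lowest-flops) path for reuse.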
+optimizer = cotengra.ReusableHyperOptimizer(
+    methods=["greedy", "kahypar"],
+    parallel=8,
+    optlib="cmaes",
+    minimize="flops",
+    max_time=120,
+    max_repeats=4096,
+    progbar=True,
+)
+tc.set_contractor("custom", optimizer)


def run_vqe():
    """Set up and run the VQE optimization for a 2D Heisenberg model."""
-    n, m, nlayers = 4, 4, 2
-    lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1)
+    n, m, nlayers = 4, 4, 3
+    lattice = SquareLattice(size=(n, m), pbc=False, precompute_neighbors=1)
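+    # Open boundaries: a 4x4 lattice then has 24 nearest-neighbor bonds
+    # (versus 32 with pbc=True), so each entangling sweep applies fewer gates.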
    h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8])  # Jx, Jy, Jz
    nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True)
    gate_layers = get_compatible_layers(nn_bonds)
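+    # get_compatible_layers partitions the bonds into layers of mutually
+    # disjoint pairs, so the two-qubit gates inside one layer act on
+    # distinct qubits and can be applied in parallel.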
@@ -51,12 +60,8 @@ def vqe_forward(param):
        for j, k in layer:
            c.rzz(int(j), int(k), theta=param[param_idx])
            param_idx += 1
-    for layer in gate_layers:
-        for j, k in layer:
            c.rxx(int(j), int(k), theta=param[param_idx])
            param_idx += 1
-    for layer in gate_layers:
-        for j, k in layer:
            c.ryy(int(j), int(k), theta=param[param_idx])
            param_idx += 1
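+            # The three rotations are now applied back-to-back on each bond, so
+            # param is consumed in (zz, xx, yy) order per bond; the total count
+            # per sweep (3 angles per bond) is unchanged.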

@@ -79,10 +84,11 @@ def train_step(param, opt_state):
    for i in range(1000):
        time0 = time.time()
        param, opt_state, loss = train_step(param, opt_state)
+        print(loss)  # block on the async JAX dispatch so the timing below is accurate
        time1 = time.time()
        if i % 10 == 0:
            print(
                f"Step {i:4d}: Loss = {loss:.6f}\t(Time per step: {time1 - time0:.4f}s)"
            )

    print("Optimization finished.")