2222tc .set_dtype ("complex128" )
2323
2424
25- def get_circuit (n , d , params ):
25+ def circuit2nodes (n , d , params , tc_mpo ):
2626 c = tc .Circuit (n )
2727 c .h (range (n ))
2828 for i in range (d ):
@@ -32,14 +32,15 @@ def get_circuit(n, d, params):
3232 c .rx (j , theta = params [j , i , 1 ])
3333 for j in range (n ):
3434 c .ry (j , theta = params [j , i , 2 ])
35- return c
3635
37-
38- def core (params , i , tree , n , d , tc_mpo ):
39- c = get_circuit (n , d , params )
4036 mps = c .get_quvector ()
4137 e = mps .adjoint () @ tc_mpo @ mps
42- _ , nodes = tc .cons .get_tn_info (e .nodes )
38+ return e .nodes
39+
40+
41+ def core (params , i , tree , n , d , tc_mpo ):
42+ nodes = circuit2nodes (n , d , params , tc_mpo )
43+ _ , nodes = tc .cons .get_tn_info (nodes )
4344 input_arrays = [node .tensor for node in nodes ]
4445 sliced_arrays = tree .slice_arrays (input_arrays , i )
4546 return K .real (tree .contract_core (sliced_arrays , backend = backend ))[0 , 0 ]
@@ -52,24 +53,18 @@ def core(params, i, tree, n, d, tc_mpo):
5253 nqubit = 12
5354 d = 6
5455
55- Jx = jax .numpy .array ([1.0 ] * (nqubit - 1 )) # XX coupling strength
56- Bz = jax .numpy .array ([- 1.0 ] * nqubit ) # Transverse field strength
57-
58- # Create TensorNetwork MPO
59- tn_mpo = tn .matrixproductstates .mpo .FiniteTFI (Jx , Bz , dtype = np .complex64 )
60- tc_mpo = tc .quantum .tn2qop (tn_mpo )
61-
6256 # baseline results
6357 lattice = tc .templates .graphs .Line1D (nqubit , pbc = False )
6458 h = tc .quantum .heisenberg_hamiltonian (lattice , hzz = 0 , hyy = 0 , hxx = 1.0 , hz = - 1.0 )
6559 es0 = scipy .sparse .linalg .eigsh (K .numpy (h ), k = 1 , which = "SA" )[0 ]
6660 print ("exact ground state energy: " , es0 )
6761
68- params = K .implicit_randn (stddev = 0.1 , shape = [1 , nqubit , d , 3 ], dtype = tc .rdtypestr )
69- params = K .tile (params , [num_device , 1 , 1 , 1 ])
62+ params = K .implicit_randn (stddev = 0.1 , shape = [nqubit , d , 3 ], dtype = tc .rdtypestr )
63+ replicated_params = K .reshape (params , [1 ] + list (params .shape ))
64+ replicated_params = K .tile (replicated_params , [num_device , 1 , 1 , 1 ])
7065
7166 optimizer = optax .adam (5e-2 )
72- base_opt_state = optimizer .init (params [ 0 ] )
67+ base_opt_state = optimizer .init (params )
7368 replicated_opt_state = jax .tree .map (
7469 lambda x : (
7570 jax .numpy .broadcast_to (x , (num_device ,) + x .shape )
@@ -93,28 +88,32 @@ def para_vag(params, i, tree, n, d, tc_mpo, opt_state):
9388 params = optax .apply_updates (params , updates )
9489 return params , opt_state , loss
9590
96- c = get_circuit (nqubit , d , params [0 ])
97- mps = c .get_quvector ()
98- e = mps .adjoint () @ tc_mpo @ mps
99- tn_info , nodes = tc .cons .get_tn_info (e .nodes )
91+ Jx = jax .numpy .array ([1.0 ] * (nqubit - 1 )) # XX coupling strength
92+ Bz = jax .numpy .array ([- 1.0 ] * nqubit ) # Transverse field strength
93+ # Create TensorNetwork MPO
94+ tn_mpo = tn .matrixproductstates .mpo .FiniteTFI (Jx , Bz , dtype = np .complex64 )
95+ tc_mpo = tc .quantum .tn2qop (tn_mpo )
10096
97+ nodes = circuit2nodes (nqubit , d , params , tc_mpo )
98+ tn_info , _ = tc .cons .get_tn_info (nodes )
99+
100+ # Create ReusableHyperOptimizer for finding optimal contraction paths
101101 opt = ctg .ReusableHyperOptimizer (
102- parallel = True ,
102+ parallel = True , # Enable parallel path finding
103103 slicing_opts = {
104- "target_slices" : num_device ,
105- # "target_size": 2**20, # Add memory target
104+ "target_slices" : num_device , # Split computation across available devices
105+ # "target_size": 2**20, # Optional: Set memory limit per slice
106106 },
107- max_repeats = 256 ,
108- progbar = True ,
109- minimize = "combo" ,
107+ max_repeats = 256 , # Maximum number of path finding attempts
108+ progbar = True , # Show progress bar during optimization
109+ minimize = "combo" , # Optimize for both time and memory
110110 )
111-
112111 tree = opt .search (* tn_info )
113112
114113 inds = K .arange (num_device )
115114 for j in range (100 ):
116115 print (f"training loop: { j } -step" )
117- params , replicated_opt_state , loss = para_vag (
118- params , inds , tree , nqubit , d , tc_mpo , replicated_opt_state
116+ replicated_params , replicated_opt_state , loss = para_vag (
117+ replicated_params , inds , tree , nqubit , d , tc_mpo , replicated_opt_state
119118 )
120119 print (loss [0 ])
0 commit comments