1919import argparse
2020import math
2121import sys
22- from dataclasses import dataclass
23- from typing import List , Sequence , Tuple
22+ from typing import Sequence , Tuple
2423
2524import numpy as np
2625import tensorcircuit as tc
2726from tensorcircuit .quditcircuit import QuditCircuit
2827
2928
3029# ---------- Hamiltonian helpers ----------
31-
32-
33- def number_op (d : int ) -> np .ndarray :
34- return np .diag (np .arange (d , dtype = np .float32 )).astype (np .complex64 )
35-
36-
37- def x_unitary (d : int ) -> np .ndarray :
38- X = np .zeros ((d , d ), dtype = np .complex64 )
39- for j in range (d ):
40- X [(j + 1 ) % d , j ] = 1.0
41- return X
42-
43-
44- def z_unitary (d : int ) -> np .ndarray :
45- omega = np .exp (2j * np .pi / d )
46- diag = np .array ([omega ** j for j in range (d )], dtype = np .complex64 )
47- return np .diag (diag )
48-
49-
5030def symmetrize_hermitian (U : np .ndarray ) -> np .ndarray :
5131 return 0.5 * (U + U .conj ().T )
5232
5333
54- def kron2 (a : np .ndarray , b : np .ndarray ) -> np .ndarray :
55- return np .kron (a , b ).astype (np .complex64 )
56-
57-
58- @dataclass
59- class Hamiltonian2Qudit :
60- H_local_0 : np .ndarray
61- H_local_1 : np .ndarray
62- H_couple : np .ndarray
63-
64- def as_terms (self ) -> List [Tuple [np .ndarray , Sequence [int ]]]:
65- return [
66- (self .H_local_0 , [0 ]),
67- (self .H_local_1 , [1 ]),
68- (self .H_couple , [0 , 1 ]),
69- ]
70-
71-
72- def build_2site_hamiltonian (d : int , J : float ) -> Hamiltonian2Qudit :
73- N = number_op (d )
74- Xsym = symmetrize_hermitian (x_unitary (d ))
75- Zsym = symmetrize_hermitian (z_unitary (d ))
34+ def build_2site_hamiltonian (
35+ d : int , J : float
36+ ) -> Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , float ]:
37+ N = np .diag (np .arange (d ))
38+ Xsym = symmetrize_hermitian (tc .backend .numpy (tc .quditgates ._x_matrix_func (d )))
39+ Zsym = symmetrize_hermitian (tc .backend .numpy (tc .quditgates ._z_matrix_func (d )))
7640 H0 = N .copy ()
7741 H1 = N .copy ()
78- HXX = kron2 (Xsym , Xsym )
79- HZZ = kron2 (Zsym , Zsym )
80- H01 = J * (HXX + HZZ )
81- return Hamiltonian2Qudit (H0 , H1 , H01 )
42+ return H0 , H1 , Xsym , Zsym , J
8243
8344
8445# ---------- Ansatz ----------
@@ -127,7 +88,12 @@ def build_ansatz(nlayers: int, d: int, params: Sequence) -> QuditCircuit:
12788# ---------- Energy ----------
12889
12990
130- def energy_expectation_backend (params_b , d : int , nlayers : int , ham : Hamiltonian2Qudit ):
91+ def energy_expectation_backend (
92+ params_b ,
93+ d : int ,
94+ nlayers : int ,
95+ ham : Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , float ],
96+ ):
13197 """
13298 params_b: 1D backend tensor (jax/tf) of shape [nparams].
13399 Returns backend scalar.
@@ -137,18 +103,30 @@ def energy_expectation_backend(params_b, d: int, nlayers: int, ham: Hamiltonian2
137103 plist = [params_b [i ] for i in range (params_b .shape [0 ])]
138104 c = build_ansatz (nlayers , d , plist )
139105 E = 0.0 + 0.0j
140- for op , sites in ham .as_terms ():
141- E = E + c .expectation ((tc .gates .Gate (op ), list (sites )))
106+ H0 , H1 , Xsym , Zsym , J = ham
107+ # Local number operators
108+ E = E + c .expectation ((tc .gates .Gate (H0 ), [0 ]))
109+ E = E + c .expectation ((tc .gates .Gate (H1 ), [1 ]))
110+ # Coupling terms as products on separate sites (avoids 9x9 reshaping issues)
111+ E = E + J * c .expectation ((tc .gates .Gate (Xsym ), [0 ]), (tc .gates .Gate (Xsym ), [1 ]))
112+ E = E + J * c .expectation ((tc .gates .Gate (Zsym ), [0 ]), (tc .gates .Gate (Zsym ), [1 ]))
142113 return bk .real (E )
143114
144115
145116def energy_expectation_numpy (
146- params_np : np .ndarray , d : int , nlayers : int , ham : Hamiltonian2Qudit
117+ params_np : np .ndarray ,
118+ d : int ,
119+ nlayers : int ,
120+ ham : Tuple [np .ndarray , np .ndarray , np .ndarray , np .ndarray , float ],
147121) -> float :
148122 c = build_ansatz (nlayers , d , params_np .tolist ())
149123 E = 0.0 + 0.0j
150- for op , sites in ham .as_terms ():
151- E += c .expectation ((tc .gates .Gate (op ), list (sites )))
124+ H0 , H1 , Xsym , Zsym , J = ham
125+ E += c .expectation ((tc .gates .Gate (H0 ), [0 ]))
126+ E += c .expectation ((tc .gates .Gate (H1 ), [1 ]))
127+ # Coupling terms as products on separate sites (avoids 9x9 reshaping issues)
128+ E += J * c .expectation ((tc .gates .Gate (Xsym ), [0 ]), (tc .gates .Gate (Xsym ), [1 ]))
129+ E += J * c .expectation ((tc .gates .Gate (Zsym ), [0 ]), (tc .gates .Gate (Zsym ), [1 ]))
152130 return float (np .real (E ))
153131
154132
@@ -159,7 +137,7 @@ def random_search(fun_numpy, x0_shape, iters=300, seed=42):
159137 rng = np .random .default_rng (seed )
160138 best_x , best_y = None , float ("inf" )
161139 for _ in range (iters ):
162- x = rng .uniform (- math .pi , math .pi , size = x0_shape ). astype ( np . float32 )
140+ x = rng .uniform (- math .pi , math .pi , size = x0_shape )
163141 y = fun_numpy (x )
164142 if y < best_y :
165143 best_x , best_y = x , y
@@ -172,11 +150,11 @@ def gradient_descent_ad(energy_bk, x0_np: np.ndarray, steps=200, lr=0.1, jit=Fal
172150 Simple gradient descent in numpy space with backend-gradients.
173151 """
174152 bk = tc .backend
175- if jit and hasattr ( bk , "jit" ) :
153+ if jit :
176154 energy_bk = bk .jit (energy_bk )
177155 grad_f = bk .grad (energy_bk )
178156
179- x_np = x0_np .astype ( np . float32 ). copy ()
157+ x_np = x0_np .copy ()
180158 best_x , best_y = x_np .copy (), float ("inf" )
181159
182160 def to_np (x ):
@@ -221,7 +199,7 @@ def main():
221199 ap .add_argument (
222200 "--jit" ,
223201 action = "store_true" ,
224- help = "enable backend JIT for energy/grad if available " ,
202+ help = "enable backend JIT (all backends implement .jit; numpy backend no-ops) " ,
225203 )
226204 args = ap .parse_args ()
227205
@@ -255,7 +233,7 @@ def obj_bk(theta_b):
255233 return energy_expectation_backend (theta_b , d , L , ham )
256234
257235 rng = np .random .default_rng (args .seed )
258- x0 = rng .uniform (- math .pi , math .pi , size = (nparams ,)). astype ( np . float32 )
236+ x0 = rng .uniform (- math .pi , math .pi , size = (nparams ,))
259237 x , y = gradient_descent_ad (
260238 obj_bk , x0_np = x0 , steps = args .steps , lr = args .lr , jit = args .jit
261239 )
0 commit comments