Skip to content

Commit 52f9804

Browse files
committed
optimized examples/vqe_qudit_example.py according to comments.
1 parent 53669a2 commit 52f9804

File tree

1 file changed

+36
-58
lines changed

1 file changed

+36
-58
lines changed

examples/vqe_qudit_example.py

Lines changed: 36 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,66 +19,27 @@
1919
import argparse
2020
import math
2121
import sys
22-
from dataclasses import dataclass
23-
from typing import List, Sequence, Tuple
22+
from typing import Sequence, Tuple
2423

2524
import numpy as np
2625
import tensorcircuit as tc
2726
from tensorcircuit.quditcircuit import QuditCircuit
2827

2928

3029
# ---------- Hamiltonian helpers ----------
31-
32-
33-
def number_op(d: int) -> np.ndarray:
34-
return np.diag(np.arange(d, dtype=np.float32)).astype(np.complex64)
35-
36-
37-
def x_unitary(d: int) -> np.ndarray:
38-
X = np.zeros((d, d), dtype=np.complex64)
39-
for j in range(d):
40-
X[(j + 1) % d, j] = 1.0
41-
return X
42-
43-
44-
def z_unitary(d: int) -> np.ndarray:
45-
omega = np.exp(2j * np.pi / d)
46-
diag = np.array([omega**j for j in range(d)], dtype=np.complex64)
47-
return np.diag(diag)
48-
49-
5030
def symmetrize_hermitian(U: np.ndarray) -> np.ndarray:
5131
return 0.5 * (U + U.conj().T)
5232

5333

54-
def kron2(a: np.ndarray, b: np.ndarray) -> np.ndarray:
55-
return np.kron(a, b).astype(np.complex64)
56-
57-
58-
@dataclass
59-
class Hamiltonian2Qudit:
60-
H_local_0: np.ndarray
61-
H_local_1: np.ndarray
62-
H_couple: np.ndarray
63-
64-
def as_terms(self) -> List[Tuple[np.ndarray, Sequence[int]]]:
65-
return [
66-
(self.H_local_0, [0]),
67-
(self.H_local_1, [1]),
68-
(self.H_couple, [0, 1]),
69-
]
70-
71-
72-
def build_2site_hamiltonian(d: int, J: float) -> Hamiltonian2Qudit:
73-
N = number_op(d)
74-
Xsym = symmetrize_hermitian(x_unitary(d))
75-
Zsym = symmetrize_hermitian(z_unitary(d))
34+
def build_2site_hamiltonian(
35+
d: int, J: float
36+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
37+
N = np.diag(np.arange(d))
38+
Xsym = symmetrize_hermitian(tc.backend.numpy(tc.quditgates._x_matrix_func(d)))
39+
Zsym = symmetrize_hermitian(tc.backend.numpy(tc.quditgates._z_matrix_func(d)))
7640
H0 = N.copy()
7741
H1 = N.copy()
78-
HXX = kron2(Xsym, Xsym)
79-
HZZ = kron2(Zsym, Zsym)
80-
H01 = J * (HXX + HZZ)
81-
return Hamiltonian2Qudit(H0, H1, H01)
42+
return H0, H1, Xsym, Zsym, J
8243

8344

8445
# ---------- Ansatz ----------
@@ -127,7 +88,12 @@ def build_ansatz(nlayers: int, d: int, params: Sequence) -> QuditCircuit:
12788
# ---------- Energy ----------
12889

12990

130-
def energy_expectation_backend(params_b, d: int, nlayers: int, ham: Hamiltonian2Qudit):
91+
def energy_expectation_backend(
92+
params_b,
93+
d: int,
94+
nlayers: int,
95+
ham: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float],
96+
):
13197
"""
13298
params_b: 1D backend tensor (jax/tf) of shape [nparams].
13399
Returns backend scalar.
@@ -137,18 +103,30 @@ def energy_expectation_backend(params_b, d: int, nlayers: int, ham: Hamiltonian2
137103
plist = [params_b[i] for i in range(params_b.shape[0])]
138104
c = build_ansatz(nlayers, d, plist)
139105
E = 0.0 + 0.0j
140-
for op, sites in ham.as_terms():
141-
E = E + c.expectation((tc.gates.Gate(op), list(sites)))
106+
H0, H1, Xsym, Zsym, J = ham
107+
# Local number operators
108+
E = E + c.expectation((tc.gates.Gate(H0), [0]))
109+
E = E + c.expectation((tc.gates.Gate(H1), [1]))
110+
# Coupling terms as products on separate sites (avoids 9x9 reshaping issues)
111+
E = E + J * c.expectation((tc.gates.Gate(Xsym), [0]), (tc.gates.Gate(Xsym), [1]))
112+
E = E + J * c.expectation((tc.gates.Gate(Zsym), [0]), (tc.gates.Gate(Zsym), [1]))
142113
return bk.real(E)
143114

144115

145116
def energy_expectation_numpy(
146-
params_np: np.ndarray, d: int, nlayers: int, ham: Hamiltonian2Qudit
117+
params_np: np.ndarray,
118+
d: int,
119+
nlayers: int,
120+
ham: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float],
147121
) -> float:
148122
c = build_ansatz(nlayers, d, params_np.tolist())
149123
E = 0.0 + 0.0j
150-
for op, sites in ham.as_terms():
151-
E += c.expectation((tc.gates.Gate(op), list(sites)))
124+
H0, H1, Xsym, Zsym, J = ham
125+
E += c.expectation((tc.gates.Gate(H0), [0]))
126+
E += c.expectation((tc.gates.Gate(H1), [1]))
127+
# Coupling terms as products on separate sites (avoids 9x9 reshaping issues)
128+
E += J * c.expectation((tc.gates.Gate(Xsym), [0]), (tc.gates.Gate(Xsym), [1]))
129+
E += J * c.expectation((tc.gates.Gate(Zsym), [0]), (tc.gates.Gate(Zsym), [1]))
152130
return float(np.real(E))
153131

154132

@@ -159,7 +137,7 @@ def random_search(fun_numpy, x0_shape, iters=300, seed=42):
159137
rng = np.random.default_rng(seed)
160138
best_x, best_y = None, float("inf")
161139
for _ in range(iters):
162-
x = rng.uniform(-math.pi, math.pi, size=x0_shape).astype(np.float32)
140+
x = rng.uniform(-math.pi, math.pi, size=x0_shape)
163141
y = fun_numpy(x)
164142
if y < best_y:
165143
best_x, best_y = x, y
@@ -172,11 +150,11 @@ def gradient_descent_ad(energy_bk, x0_np: np.ndarray, steps=200, lr=0.1, jit=Fal
172150
Simple gradient descent in numpy space with backend-gradients.
173151
"""
174152
bk = tc.backend
175-
if jit and hasattr(bk, "jit"):
153+
if jit:
176154
energy_bk = bk.jit(energy_bk)
177155
grad_f = bk.grad(energy_bk)
178156

179-
x_np = x0_np.astype(np.float32).copy()
157+
x_np = x0_np.copy()
180158
best_x, best_y = x_np.copy(), float("inf")
181159

182160
def to_np(x):
@@ -221,7 +199,7 @@ def main():
221199
ap.add_argument(
222200
"--jit",
223201
action="store_true",
224-
help="enable backend JIT for energy/grad if available",
202+
help="enable backend JIT (all backends implement .jit; numpy backend no-ops)",
225203
)
226204
args = ap.parse_args()
227205

@@ -255,7 +233,7 @@ def obj_bk(theta_b):
255233
return energy_expectation_backend(theta_b, d, L, ham)
256234

257235
rng = np.random.default_rng(args.seed)
258-
x0 = rng.uniform(-math.pi, math.pi, size=(nparams,)).astype(np.float32)
236+
x0 = rng.uniform(-math.pi, math.pi, size=(nparams,))
259237
x, y = gradient_descent_ad(
260238
obj_bk, x0_np=x0, steps=args.steps, lr=args.lr, jit=args.jit
261239
)

0 commit comments

Comments
 (0)