Skip to content

Commit c86b344

Browse files
add tebd example
1 parent c254d41 commit c86b344

File tree

3 files changed

+179
-1
lines changed

3 files changed

+179
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ TensorCircuit-NG is the actively maintained official version and a [fully compat
3535

3636
Please begin with [Quick Start](/docs/source/quickstart.rst) in the [full documentation](https://tensorcircuit-ng.readthedocs.io/).
3737

38-
For more information on software usage, sota algorithm implementation and engineer paradigm demonstration, please refer to 80+ [example scripts](/examples) and 30+ [tutorial notebooks](https://tensorcircuit-ng.readthedocs.io/en/latest/#tutorials). API docstrings and test cases in [tests](/tests) are also informative.
38+
For more information on software usage, sota algorithm implementation and engineer paradigm demonstration, please refer to 80+ [example scripts](/examples) and 30+ [tutorial notebooks](https://tensorcircuit-ng.readthedocs.io/en/latest/#tutorials). API docstrings and test cases in [tests](/tests) are also informative. One can also refer to tensorcircuit-ng [deepwiki](https://deepwiki.com/tensorcircuit/tensorcircuit-ng) generated by LLM.
3939

4040
For beginners, please refer to [quantum computing lectures with TC-NG](https://github.com/sxzgroup/qc_lecture) to learn both quantum computing basics and representative usage of TensorCircuit-NG.
4141

examples/vqe2d_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,4 @@ def train_step(param, opt_state):
7878
time0 = time.time()
7979
param, opt_state, losses = train_step(param, opt_state)
8080
print(K.mean(losses), time.time() - time0)
81+
# ~0.017s per iteration on A800

examples/xyzmodel_tebd.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
1D TEBD using MPSCircuit
3+
"""
4+
5+
import time
6+
import numpy as np
7+
import scipy
8+
import tensorcircuit as tc
9+
10+
K = tc.set_backend("jax")
11+
tc.set_dtype("complex128")
12+
13+
14+
def heisenberg_time_evolution_mps(
15+
nqubits: int,
16+
total_time: float,
17+
dt: float,
18+
hxx: float = 1.0,
19+
hyy: float = 1.0,
20+
hzz: float = 1.0,
21+
hz: float = 0.0,
22+
hx: float = 0.0,
23+
hy: float = 0.0,
24+
initial_state=None,
25+
split_rules=None,
26+
):
27+
# Initialize MPS circuit
28+
if initial_state is not None:
29+
mps = tc.MPSCircuit(nqubits, wavefunction=initial_state, split=split_rules)
30+
else:
31+
mps = tc.MPSCircuit(nqubits, split=split_rules)
32+
tensors = mps._mps.tensors
33+
34+
# Number of Trotter steps
35+
nsteps = int(total_time / dt)
36+
dt_step = dt
37+
38+
@K.jit
39+
def apply_trotter_step(mps_tensors):
40+
mps_circuit = tc.MPSCircuit(nqubits, tensors=mps_tensors, split=split_rules)
41+
# Apply odd bonds (1-2, 3-4, ...)
42+
43+
for i in range(0, nqubits, 2):
44+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
45+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
46+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
47+
48+
# Apply even bonds (2-3, 4-5, ...)
49+
for i in range(1, nqubits, 2):
50+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=2 * hxx * dt_step)
51+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=2 * hyy * dt_step)
52+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=2 * hzz * dt_step)
53+
54+
# mps_circuit.unitary(i, unitary=unitary)
55+
56+
for i in range(0, nqubits, 2):
57+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
58+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
59+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
60+
61+
for i in range(nqubits):
62+
mps_circuit.rx(i, theta=2 * hx * dt_step)
63+
mps_circuit.ry(i, theta=2 * hy * dt_step)
64+
mps_circuit.rz(i, theta=2 * hz * dt_step)
65+
66+
return mps_circuit._mps.tensors
67+
68+
# Perform time evolution
69+
for step in range(nsteps):
70+
tensors = apply_trotter_step(tensors)
71+
72+
return tc.MPSCircuit(nqubits, tensors=tensors, split=split_rules)
73+
74+
75+
def compare_baseline():
76+
# Parameters
77+
nqubits = 10
78+
total_time = 1
79+
dt = 0.01
80+
81+
# Heisenberg parameters
82+
hxx = 0.9
83+
hyy = 1.0
84+
hzz = 0.3
85+
hz = -0.1
86+
hy = 0.16
87+
hx = 0.43
88+
89+
split_rules = {"max_singular_values": 32}
90+
91+
c = tc.Circuit(nqubits)
92+
c.x(nqubits // 2)
93+
initial_state = c.state()
94+
95+
# TEBD evolution
96+
final_mps = heisenberg_time_evolution_mps(
97+
nqubits=nqubits,
98+
total_time=total_time,
99+
dt=dt,
100+
hxx=hxx,
101+
hyy=hyy,
102+
hzz=hzz,
103+
hz=hz,
104+
hx=hx,
105+
hy=hy,
106+
initial_state=initial_state,
107+
split_rules=split_rules,
108+
)
109+
110+
# Exact evolution
111+
g = tc.templates.graphs.Line1D(nqubits, pbc=True)
112+
H = tc.quantum.heisenberg_hamiltonian(
113+
g, hxx=hxx, hyy=hyy, hzz=hzz, hz=hz, hy=hy, hx=hx, sparse=False
114+
)
115+
U = scipy.linalg.expm(-1j * total_time * H)
116+
exact_final = K.reshape(U @ K.reshape(initial_state, [-1, 1]), [-1])
117+
118+
# Compare results
119+
mps_state = final_mps.wavefunction()
120+
fidelity = np.abs(np.vdot(exact_final, mps_state)) ** 2
121+
print(f"Fidelity between TEBD and exact evolution: {fidelity}")
122+
c_exact = tc.Circuit(nqubits, inputs=exact_final)
123+
# Measure observables
124+
z_magnetization_mps = []
125+
z_magnetization_exact = []
126+
for i in range(nqubits):
127+
mag_mps = final_mps.expectation((tc.gates.z(), [i]))
128+
z_magnetization_mps.append(mag_mps)
129+
130+
mag_exact = c_exact.expectation((tc.gates.z(), [i]))
131+
z_magnetization_exact.append(mag_exact)
132+
133+
print("MPS Z magnetization:", K.stack(z_magnetization_mps))
134+
print("Exact Z magnetization:", K.stack(z_magnetization_exact))
135+
print("Final bond dimensions:", final_mps.get_bond_dimensions())
136+
137+
return final_mps, exact_final
138+
139+
140+
def benchmark_efficiency(nqubits, bond_d):
141+
total_time = 0.1
142+
dt = 0.01
143+
hxx = 0.9
144+
hyy = 1.0
145+
hzz = 0.3
146+
split_rules = {"max_singular_values": bond_d}
147+
148+
# TEBD evolution
149+
time0 = time.time()
150+
final_mps = heisenberg_time_evolution_mps(
151+
nqubits=nqubits,
152+
total_time=total_time,
153+
dt=dt,
154+
hxx=hxx,
155+
hyy=hyy,
156+
hzz=hzz,
157+
split_rules=split_rules,
158+
)
159+
print(final_mps._mps.tensors[0])
160+
print("cold start run:", time.time() - time0)
161+
time0 = time.time()
162+
final_mps = heisenberg_time_evolution_mps(
163+
nqubits=nqubits,
164+
total_time=total_time,
165+
dt=dt,
166+
hxx=hxx,
167+
hyy=hyy,
168+
hzz=hzz,
169+
split_rules=split_rules,
170+
)
171+
print(final_mps._mps.tensors[0])
172+
print("jitted run:", time.time() - time0)
173+
174+
175+
if __name__ == "__main__":
176+
compare_baseline()
177+
benchmark_efficiency(24, 64)

0 commit comments

Comments
 (0)