Skip to content

Commit 1c2ac74

Browse files
add chebyshev time evolution
1 parent e8cdd62 commit 1c2ac74

File tree

9 files changed

+831
-6
lines changed

9 files changed

+831
-6
lines changed

examples/chebyshev_evol.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""
2+
Chebyshev time evolution example with CLI options.
3+
4+
This script demonstrates the Chebyshev time evolution method with various options
5+
for backend, matrix type, and JIT compilation.
6+
"""
7+
8+
import argparse
9+
from typing import Any, Tuple
10+
import numpy as np
11+
from scipy.linalg import expm
12+
13+
import tensorcircuit as tc
14+
15+
tc.set_dtype("complex128")
16+
17+
18+
def create_heisenberg_hamiltonian(num_sites: int, sparse: bool = True) -> Any:
19+
"""
20+
Create Heisenberg Hamiltonian for a 1D chain.
21+
22+
Args:
23+
num_sites: Number of sites in the chain.
24+
sparse: Whether to create a sparse matrix.
25+
26+
Returns:
27+
Hamiltonian matrix.
28+
"""
29+
graph = tc.templates.graphs.Line1D(num_sites)
30+
return tc.quantum.heisenberg_hamiltonian(graph, sparse=sparse)
31+
32+
33+
def create_initial_state(dim: int) -> Any:
34+
"""
35+
Create initial state as equal superposition.
36+
37+
Args:
38+
dim: Dimension of the Hilbert space.
39+
40+
Returns:
41+
Normalized initial state.
42+
"""
43+
psi0 = tc.backend.ones([dim])
44+
return psi0 / tc.backend.norm(psi0)
45+
46+
47+
def estimate_spectral_bounds(hamiltonian: Any, n_iter: int = 40) -> Tuple[float, float]:
48+
"""
49+
Estimate spectral bounds of the Hamiltonian.
50+
51+
Args:
52+
hamiltonian: Hamiltonian matrix.
53+
n_iter: Number of iterations for Lanczos algorithm.
54+
55+
Returns:
56+
Tuple of (E_max, E_min).
57+
"""
58+
print(f"Estimating spectral bounds of Hamiltonian (iterations: {n_iter})...")
59+
# Ensure the initial vector is compatible with JAX backend
60+
e_max, e_min = tc.timeevol.estimate_spectral_bounds(hamiltonian, n_iter=n_iter)
61+
print(f"Estimated result: E_max = {e_max:.4f}, E_min = {e_min:.4f}")
62+
return float(e_max), float(e_min)
63+
64+
65+
def compare_with_exact_evolution(
66+
hamiltonian: Any,
67+
initial_state: Any,
68+
chebyshev_state: Any,
69+
time: float,
70+
) -> float:
71+
"""
72+
Compare Chebyshev evolution result with exact evolution.
73+
74+
Args:
75+
hamiltonian: Hamiltonian matrix.
76+
initial_state: Initial quantum state.
77+
chebyshev_state: State evolved with Chebyshev method.
78+
time: Evolution time.
79+
80+
Returns:
81+
Fidelity between the two states.
82+
"""
83+
# Exact evolution using matrix exponential
84+
if tc.backend.is_sparse(hamiltonian):
85+
h = tc.backend.to_dense(hamiltonian)
86+
else:
87+
h = hamiltonian
88+
psi_exact = expm(-1j * np.asarray(h) * time) @ np.asarray(initial_state)
89+
90+
fidelity = np.abs(np.vdot(psi_exact, np.asarray(chebyshev_state))) ** 2
91+
return fidelity
92+
93+
94+
def run_chebyshev_evolution(
95+
num_sites: int = 8,
96+
time: float = 500.0,
97+
backend_name: str = "numpy",
98+
sparse: bool = True,
99+
use_jit: bool = False,
100+
) -> None:
101+
"""
102+
Run Chebyshev time evolution with specified parameters.
103+
104+
Args:
105+
num_sites: Number of sites in the system.
106+
time: Evolution time.
107+
backend_name: Backend to use (numpy, jax, tensorflow, pytorch).
108+
sparse: Whether to use sparse matrices.
109+
use_jit: Whether to use JIT compilation.
110+
"""
111+
# Set backend
112+
tc.set_dtype("complex128") # Ensure dtype is set after backend if needed
113+
tc.set_backend(backend_name)
114+
backend = tc.backend
115+
print(f"Using {backend_name} backend")
116+
117+
# Create system
118+
dim = 2**num_sites
119+
graph = tc.templates.graphs.Line1D(num_sites)
120+
h_matrix = tc.quantum.heisenberg_hamiltonian(graph, sparse=sparse)
121+
print(f"Created Heisenberg Hamiltonian for {num_sites} sites")
122+
print(f"Matrix is {'sparse' if sparse else 'dense'}")
123+
124+
# Create initial state
125+
psi0 = create_initial_state(dim)
126+
print("Created initial state (equal superposition)")
127+
128+
# Estimate spectral bounds
129+
e_max, e_min = estimate_spectral_bounds(h_matrix, n_iter=40)
130+
131+
# Prepare Chebyshev evolution function
132+
if use_jit:
133+
chebyshev_evol_jit = backend.jit(
134+
tc.timeevol.chebyshev_evol, static_argnums=(3, 4, 5)
135+
)
136+
chebyshev_function = chebyshev_evol_jit
137+
print("Using JIT compilation")
138+
else:
139+
chebyshev_function = tc.timeevol.chebyshev_evol
140+
print("Not using JIT compilation")
141+
142+
# Perform Chebyshev evolution
143+
print("\nPerforming Chebyshev evolution...")
144+
print("--- Testing single time evolution ---")
145+
k_estimate = tc.timeevol.estimate_k(time, (e_max, e_min))
146+
m_estimate = tc.timeevol.estimate_M(time, (e_max, e_min), k=k_estimate)
147+
print(f"Required M (estimated): {m_estimate}")
148+
print(f"Required k (estimated): {k_estimate}")
149+
150+
psi_cheby = chebyshev_function(
151+
h_matrix,
152+
psi0,
153+
t=time,
154+
spectral_bounds=(e_max + 0.1, e_min - 0.1),
155+
k=k_estimate,
156+
M=m_estimate,
157+
)
158+
159+
norm = tc.backend.norm(psi_cheby)
160+
print(f"Norm of evolved state: {norm}")
161+
162+
# Compare with exact evolution
163+
print("\nComparing with exact evolution...")
164+
fidelity = compare_with_exact_evolution(h_matrix, psi0, psi_cheby, time)
165+
print(f"Fidelity for t={time}: {fidelity:.8f}")
166+
167+
168+
def main() -> None:
169+
"""Main function with CLI argument parsing."""
170+
parser = argparse.ArgumentParser(
171+
description="Chebyshev time evolution example",
172+
formatter_class=argparse.RawDescriptionHelpFormatter,
173+
epilog="""
174+
Example usage:
175+
python chebyshev_evol.py --num_sites 8 --time 500.0
176+
python chebyshev_evol.py --backend jax --jit
177+
python chebyshev_evol.py --dense --backend jax --jit
178+
""",
179+
)
180+
181+
parser.add_argument(
182+
"--num_sites",
183+
type=int,
184+
default=8,
185+
help="Number of sites in the system (default: 8)",
186+
)
187+
188+
parser.add_argument(
189+
"--time",
190+
type=float,
191+
default=500.0,
192+
help="Evolution time (default: 500.0)",
193+
)
194+
195+
parser.add_argument(
196+
"--backend",
197+
type=str,
198+
default="numpy",
199+
choices=["numpy", "jax", "tensorflow", "pytorch"],
200+
help="Backend selection (default: numpy)",
201+
)
202+
203+
parser.add_argument(
204+
"--dense",
205+
dest="sparse",
206+
action="store_false",
207+
help="Use dense matrices instead of sparse",
208+
)
209+
210+
parser.add_argument(
211+
"--sparse",
212+
dest="sparse",
213+
action="store_true",
214+
help="Use sparse matrices (default)",
215+
)
216+
217+
parser.add_argument(
218+
"--jit",
219+
action="store_true",
220+
help="Enable JIT compilation (only works with JAX backend)",
221+
)
222+
223+
parser.set_defaults(sparse=True)
224+
225+
args = parser.parse_args()
226+
227+
run_chebyshev_evolution(
228+
num_sites=args.num_sites,
229+
time=args.time,
230+
backend_name=args.backend,
231+
sparse=args.sparse,
232+
use_jit=args.jit,
233+
)
234+
235+
236+
if __name__ == "__main__":
237+
main()

tensorcircuit/backends/abstract_backend.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,21 @@ def solve(self: Any, A: Tensor, b: Tensor, **kws: Any) -> Tensor:
808808
"Backend '{}' has not implemented `solve`.".format(self.name)
809809
)
810810

811+
def special_jv(self: Any, v: int, z: Tensor, M: int) -> Tensor:
812+
"""
813+
Special function: Bessel function of the first kind.
814+
815+
:param v: The order of the Bessel function.
816+
:type v: int
817+
:param z: The argument of the Bessel function.
818+
:type z: Tensor
819+
:return: The value of the Bessel function [J_0, ...J_{v-1}(z)].
820+
:rtype: Tensor
821+
"""
822+
raise NotImplementedError(
823+
"Backend '{}' has not implemented `special_jv`.".format(self.name)
824+
)
825+
811826
def searchsorted(self: Any, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
812827
"""
813828
Find indices where elements should be inserted to maintain order.
@@ -1400,10 +1415,37 @@ def scan(
14001415
carry = f(carry, x)
14011416

14021417
return carry
1403-
# carry = init
1404-
# for x in xs:
1405-
# carry = f(carry, x)
1406-
# return carry
1418+
1419+
def jaxy_scan(
1420+
self: Any, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
1421+
) -> Tensor:
1422+
"""
1423+
This API follows jax scan style. TF use plain for loop
1424+
1425+
:param f: _description_
1426+
:type f: Callable[[Tensor, Tensor], Tensor]
1427+
:param init: _description_
1428+
:type init: Tensor
1429+
:param xs: _description_
1430+
:type xs: Tensor
1431+
:raises ValueError: _description_
1432+
:return: _description_
1433+
:rtype: Tensor
1434+
"""
1435+
if xs is None:
1436+
raise ValueError("Either xs or length must be provided.")
1437+
if xs is not None:
1438+
length = len(xs)
1439+
carry, outputs_to_stack = init, []
1440+
for i in range(length):
1441+
if isinstance(xs, (tuple, list)):
1442+
x = [ele[i] for ele in xs]
1443+
else:
1444+
x = xs[i]
1445+
new_carry, y = f(carry, x)
1446+
carry = new_carry
1447+
outputs_to_stack.append(y)
1448+
return carry, self.stack(outputs_to_stack)
14071449

14081450
def stop_gradient(self: Any, a: Tensor) -> Tensor:
14091451
"""

tensorcircuit/backends/jax_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,11 @@ def is_tensor(self, a: Any) -> bool:
418418
def solve(self, A: Tensor, b: Tensor, assume_a: str = "gen") -> Tensor: # type: ignore
419419
return jsp.linalg.solve(A, b, assume_a=assume_a)
420420

421+
def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
422+
from .jax_ops import bessel_jv_jax_rescaled
423+
424+
return bessel_jv_jax_rescaled(v, z, M)
425+
421426
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
422427
if not self.is_tensor(a):
423428
a = self.convert_to_tensor(a)
@@ -615,6 +620,11 @@ def f_jax(*args: Any, **kws: Any) -> Any:
615620
carry, _ = libjax.lax.scan(f_jax, init, xs)
616621
return carry
617622

623+
def jaxy_scan(
624+
self, f: Callable[[Tensor, Tensor], Tensor], init: Tensor, xs: Tensor
625+
) -> Tensor:
626+
return libjax.lax.scan(f, init, xs)
627+
618628
def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
619629
# updates = jnp.reshape(updates, indices.shape)
620630
# return operand.at[indices].set(updates)

0 commit comments

Comments
 (0)