Skip to content

Commit 26f9a52

Browse files
scan implementation of krylov
1 parent b119e4c commit 26f9a52

File tree

7 files changed

+304
-49
lines changed

7 files changed

+304
-49
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66

77
- Add new module `tc.timeevol` for different types of time evolution solvers.
88

9+
### Fixed
10+
11+
- Fixed `one_hot` in numpy backend.
12+
13+
- Fixed `scan` in tensorflow backend and numpy backend.
14+
915
## v1.3.0
1016

1117
### Added

examples/krylov_time_evolution.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def run_comprehensive_analysis(
9898
verbose: bool = True,
9999
backend_name: str = "numpy",
100100
use_jit: bool = False,
101+
scan_impl: bool = False,
101102
) -> Tuple[Dict, Dict]:
102103
"""
103104
Run comprehensive analysis showing fidelity, magnetization error and runtime
@@ -111,6 +112,7 @@ def run_comprehensive_analysis(
111112
- verbose: Whether to show detailed output
112113
- backend_name: Backend name
113114
- use_jit: Whether to use JIT compilation
115+
- scan_impl: Whether to use scan implementation in krylov_evol
114116
"""
115117
if time_points is None:
116118
time_points = [1.0, 2.0, 5.0]
@@ -126,7 +128,7 @@ def run_comprehensive_analysis(
126128
if use_jit:
127129
# Create wrapper function to avoid passing backend parameter to JIT
128130
def krylov_evol_wrapper(h, psi0, tlist, m):
129-
return krylov_evol(h, psi0, tlist, m)
131+
return krylov_evol(h, psi0, tlist, m, scan_impl=scan_impl)
130132

131133
# JIT compile function (static parameter is m)
132134
krylov_evol_jit = backend.jit(krylov_evol_wrapper, static_argnums=(3,))
@@ -135,7 +137,9 @@ def krylov_evol_wrapper(h, psi0, tlist, m):
135137
jit_info = " (using JIT compilation)"
136138
else:
137139
# Use regular version of function
138-
krylov_function = krylov_evol
140+
krylov_function = lambda h, psi0, tlist, m: krylov_evol(
141+
h, psi0, tlist, m, scan_impl=scan_impl
142+
)
139143
jit_info = ""
140144

141145
if verbose:
@@ -256,50 +260,44 @@ def krylov_evol_wrapper(h, psi0, tlist, m):
256260
reference_m = None
257261

258262
for m in subspace_dims:
259-
try:
260-
start_time = time.time()
261-
# Use correct Krylov function (possibly JIT version)
262-
# Key fix: Ensure correct parameters are passed to Krylov function
263-
krylov_result = krylov_function(
264-
hamiltonian_sparse,
265-
initial_state,
266-
[t],
267-
int(m), # Ensure m is integer
268-
)
269-
krylov_time = time.time() - start_time
270-
271-
# Extract result (krylov_result is an array containing a single time point)
272-
# Fix: Properly handle result shape
273-
if hasattr(krylov_result, "shape") and len(krylov_result.shape) > 1:
274-
evolved_state = krylov_result[
275-
0
276-
] # Take the first (and only) time point result
277-
else:
278-
evolved_state = krylov_result
279-
280-
# Calculate magnetization
281-
magnetization = compute_magnetization(evolved_state, backend)
282-
283-
results[t][m] = {
284-
"state": evolved_state,
285-
"magnetization": magnetization,
286-
"time": krylov_time,
287-
}
288263

289-
if verbose:
290-
print(
291-
f" m = {m:3d}: Time {krylov_time:.4f}s, Magnetization {magnetization.real:8.6f}"
292-
)
264+
start_time = time.time()
265+
krylov_result = krylov_function(
266+
hamiltonian_sparse,
267+
initial_state,
268+
[t],
269+
int(m), # Ensure m is integer
270+
)
271+
print(krylov_result[0, 0])
272+
krylov_time = time.time() - start_time
273+
274+
# Extract result (krylov_result is an array containing a single time point)
275+
if hasattr(krylov_result, "shape") and len(krylov_result.shape) > 1:
276+
evolved_state = krylov_result[
277+
0
278+
] # Take the first (and only) time point result
279+
else:
280+
evolved_state = krylov_result
293281

294-
# Set reference result (using largest m value, only when no exact result)
295-
# Fix: Only use Krylov result as reference when no exact result available
296-
if reference_result is None and t not in exact_results:
297-
reference_result = results[t][m]
298-
reference_m = m
282+
# Calculate magnetization
283+
magnetization = compute_magnetization(evolved_state, backend)
299284

300-
except Exception as exc:
301-
if verbose:
302-
print(f" m = {m:3d}: Failed - {str(exc)}")
285+
results[t][m] = {
286+
"state": evolved_state,
287+
"magnetization": magnetization,
288+
"time": krylov_time,
289+
}
290+
291+
if verbose:
292+
print(
293+
f" m = {m:3d}: Time {krylov_time:.4f}s, Magnetization {magnetization.real:8.6f}"
294+
)
295+
296+
# Set reference result (using largest m value, only when no exact result)
297+
# Fix: Only use Krylov result as reference when no exact result available
298+
if reference_result is None and t not in exact_results:
299+
reference_result = results[t][m]
300+
reference_m = m
303301

304302
# Compare with exact results (if available)
305303
if t in exact_results and reference_result:
@@ -427,6 +425,11 @@ def main() -> None:
427425
help="Backend selection (default: numpy)",
428426
)
429427
parser.add_argument("--jit", action="store_true", help="Enable JIT compilation")
428+
parser.add_argument(
429+
"--scan_impl",
430+
action="store_true",
431+
help="Use scan implementation in krylov_evol",
432+
)
430433

431434
args = parser.parse_args()
432435

@@ -443,6 +446,7 @@ def main() -> None:
443446
verbose=not args.quiet,
444447
backend_name=args.backend,
445448
use_jit=args.jit,
449+
scan_impl=args.scan_impl,
446450
)
447451
except KeyboardInterrupt:
448452
print("\nUser interrupted program execution")

tensorcircuit/backends/abstract_backend.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,9 +1389,21 @@ def scan(
13891389
:rtype: Tensor
13901390
"""
13911391
carry = init
1392-
for x in xs:
1393-
carry = f(carry, x)
1392+
# Check if `xs` is a PyTree (tuple or list) of arrays.
1393+
if isinstance(xs, (tuple, list)):
1394+
for x_slice_tuple in zip(*xs):
1395+
# x_slice_tuple will be (k_elems[i], j_elems[i]) at each step.
1396+
carry = f(carry, x_slice_tuple)
1397+
else:
1398+
# If xs is a single array, iterate normally.
1399+
for x in xs:
1400+
carry = f(carry, x)
1401+
13941402
return carry
1403+
# carry = init
1404+
# for x in xs:
1405+
# carry = f(carry, x)
1406+
# return carry
13951407

13961408
def stop_gradient(self: Any, a: Tensor) -> Tensor:
13971409
"""

tensorcircuit/backends/numpy_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def softmax(self, a: Sequence[Tensor], axis: Optional[int] = None) -> Tensor:
200200
return softmax(a, axis=axis)
201201

202202
def onehot(self, a: Tensor, num: int) -> Tensor:
203+
a = np.asarray(a)
203204
res = np.eye(num)[a.reshape([-1])]
204205
return res.reshape(list(a.shape) + [num])
205206
# https://stackoverflow.com/questions/38592324/one-hot-encoding-using-numpy

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,10 @@ def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Ten
719719
def scan(
720720
self, f: Callable[[Tensor, Tensor], Tensor], xs: Tensor, init: Tensor
721721
) -> Tensor:
722-
return tf.scan(f, xs, init)[-1]
722+
stacked_results = tf.scan(f, xs, init)
723+
final_state = tf.nest.map_structure(lambda x: x[-1], stacked_results)
724+
return final_state
725+
# return tf.scan(f, xs, init)[-1]
723726

724727
def device(self, a: Tensor) -> str:
725728
dev = a.device

tensorcircuit/timeevol.py

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,154 @@
1111
Circuit = Any
1212

1313

14+
def lanczos_iteration_scan(
15+
hamiltonian: Any, initial_vector: Any, subspace_dimension: int
16+
) -> Tuple[Any, Any]:
17+
"""
18+
Use Lanczos algorithm to construct orthogonal basis and projected Hamiltonian
19+
of Krylov subspace, using `tc.backend.scan` for JIT compatibility.
20+
21+
:param hamiltonian: Sparse or dense Hamiltonian matrix
22+
:type hamiltonian: Tensor
23+
:param initial_vector: Initial quantum state vector
24+
:type initial_vector: Tensor
25+
:param subspace_dimension: Dimension of Krylov subspace
26+
:type subspace_dimension: int
27+
:return: Tuple containing (basis matrix, projected Hamiltonian)
28+
:rtype: Tuple[Tensor, Tensor]
29+
"""
30+
state_size = backend.shape_tuple(initial_vector)[0]
31+
32+
# Main scan body for the outer loop (iterating j)
33+
def lanczos_step(carry: Tuple[Any, ...], j: int) -> Tuple[Any, ...]:
34+
v, basis, alphas, betas = carry
35+
36+
if backend.is_sparse(hamiltonian):
37+
w = backend.sparse_dense_matmul(hamiltonian, v)
38+
else:
39+
w = backend.matvec(hamiltonian, v)
40+
41+
alpha = backend.real(backend.sum(backend.conj(v) * w))
42+
w = w - backend.cast(alpha, dtypestr) * v
43+
44+
# Inner scan for re-orthogonalization (iterating k)
45+
# def ortho_step(inner_carry: Tuple[Any, Any], k: int) -> Tuple[Any, Any]:
46+
# w_carry, j_val = inner_carry
47+
48+
# def do_projection() -> Any:
49+
# # `basis` is available here through closure
50+
# v_k = basis[:, k]
51+
# projection = backend.sum(backend.conj(v_k) * w_carry)
52+
# return w_carry - projection * v_k
53+
54+
# def do_nothing() -> Any:
55+
# return w_carry
56+
57+
# # Orthogonalize against v_0, ..., v_j
58+
# w_new = backend.cond(k <= j_val, do_projection, do_nothing)
59+
# return (w_new, j_val) # Return the new carry for the inner loop
60+
61+
# # Pass `j` into the inner scan's carry
62+
# inner_init_carry = (w, j)
63+
# final_inner_carry = backend.scan(
64+
# ortho_step, backend.arange(subspace_dimension), inner_init_carry
65+
# )
66+
# w_ortho = final_inner_carry[0]
67+
# print(")}]", final_inner_carry, w_ortho)
68+
69+
def ortho_step(w_carry: Any, elems_tuple: Tuple[Any, Any]) -> Any:
70+
k, j_from_elems = elems_tuple
71+
72+
def do_projection() -> Any:
73+
v_k = basis[:, k]
74+
projection = backend.sum(backend.conj(v_k) * w_carry)
75+
return w_carry - projection * v_k
76+
77+
def do_nothing() -> Any:
78+
return backend.cast(w_carry, dtype=dtypestr)
79+
80+
w_new = backend.cond(k <= j_from_elems, do_projection, do_nothing)
81+
return w_new
82+
83+
k_elems = backend.arange(subspace_dimension)
84+
j_elems = backend.tile(backend.reshape(j, [1]), [subspace_dimension])
85+
inner_elems = (k_elems, j_elems)
86+
w_ortho = backend.scan(ortho_step, inner_elems, w)
87+
88+
beta = backend.norm(w_ortho)
89+
beta = backend.real(beta)
90+
91+
# Update alphas and betas arrays
92+
new_alphas = backend.scatter(
93+
alphas, backend.reshape(j, [1, 1]), backend.reshape(alpha, [1])
94+
)
95+
new_betas = backend.scatter(
96+
betas, backend.reshape(j, [1, 1]), backend.reshape(beta, [1])
97+
)
98+
99+
def update_state_fn() -> Tuple[Any, Any]:
100+
epsilon = 1e-15
101+
next_v = w_ortho / backend.cast(beta + epsilon, dtypestr)
102+
103+
one_hot_update = backend.onehot(j + 1, subspace_dimension)
104+
one_hot_update = backend.cast(one_hot_update, dtype=dtypestr)
105+
106+
# Create a mask to update only the (j+1)-th column
107+
mask = 1.0 - backend.reshape(one_hot_update, [1, subspace_dimension])
108+
new_basis = basis * mask + backend.reshape(
109+
next_v, [-1, 1]
110+
) * backend.reshape(one_hot_update, [1, subspace_dimension])
111+
112+
return next_v, new_basis
113+
114+
def keep_state_fn() -> Tuple[Any, Any]:
115+
return v, basis
116+
117+
next_v_carry, new_basis = backend.cond(
118+
j < subspace_dimension - 1, update_state_fn, keep_state_fn
119+
)
120+
121+
return (next_v_carry, new_basis, new_alphas, new_betas)
122+
123+
# Prepare initial state for the main scan
124+
v0 = initial_vector / backend.norm(initial_vector)
125+
126+
init_basis = backend.zeros((state_size, subspace_dimension), dtype=dtypestr)
127+
init_alphas = backend.zeros((subspace_dimension,), dtype=rdtypestr)
128+
init_betas = backend.zeros((subspace_dimension,), dtype=rdtypestr)
129+
130+
one_hot_0 = backend.onehot(0, subspace_dimension)
131+
one_hot_0 = backend.cast(one_hot_0, dtype=dtypestr)
132+
init_basis = init_basis + backend.reshape(v0, [-1, 1]) * backend.reshape(
133+
one_hot_0, [1, subspace_dimension]
134+
)
135+
136+
init_carry = (v0, init_basis, init_alphas, init_betas)
137+
138+
# Run the main scan
139+
final_carry = backend.scan(
140+
lanczos_step, backend.arange(subspace_dimension), init_carry
141+
)
142+
basis_matrix, alphas_tensor, betas_tensor = (
143+
final_carry[1],
144+
final_carry[2],
145+
final_carry[3],
146+
)
147+
148+
betas_off_diag = betas_tensor[:-1]
149+
150+
diag_part = backend.diagflat(alphas_tensor)
151+
if backend.shape_tuple(betas_off_diag)[0] > 0:
152+
off_diag_part = backend.diagflat(betas_off_diag, k=1)
153+
projected_hamiltonian = (
154+
diag_part + off_diag_part + backend.conj(backend.transpose(off_diag_part))
155+
)
156+
else:
157+
projected_hamiltonian = diag_part
158+
159+
return basis_matrix, projected_hamiltonian
160+
161+
14162
def lanczos_iteration(
15163
hamiltonian: Tensor, initial_vector: Tensor, subspace_dimension: int
16164
) -> Tuple[Tensor, Tensor]:
@@ -116,6 +264,7 @@ def krylov_evol(
116264
time_points: Tensor,
117265
subspace_dimension: int,
118266
callback: Optional[Callable[[Any], Any]] = None,
267+
scan_impl: bool = False,
119268
) -> Any:
120269
"""
121270
Perform quantum state time evolution using Krylov subspace method.
@@ -131,14 +280,23 @@ def krylov_evol(
131280
:param callback: Optional callback function applied to quantum state at
132281
each evolution time point, return some observables
133282
:type callback: Optional[Callable[[Any], Any]], optional
283+
:param scan_impl: whether use scan implementation, suitable for jit but may be slow on numpy
284+
defaults False, True not work for tensorflow backend + jit, due to stupid issue of tensorflow
285+
context separation and the notorious inaccesibletensor error
286+
:type scan_impl: bool, optional
134287
:return: List of evolved quantum states, or list of callback function results
135288
(if callback provided)
136289
:rtype: Any
137290
"""
138291
# TODO(@refraction-ray): stable and efficient AD is to be investigated
139-
basis_matrix, projected_hamiltonian = lanczos_iteration(
140-
hamiltonian, initial_state, subspace_dimension
141-
)
292+
if not scan_impl:
293+
basis_matrix, projected_hamiltonian = lanczos_iteration(
294+
hamiltonian, initial_state, subspace_dimension
295+
)
296+
else:
297+
basis_matrix, projected_hamiltonian = lanczos_iteration_scan(
298+
hamiltonian, initial_state, subspace_dimension
299+
)
142300
initial_state = backend.cast(initial_state, dtypestr)
143301
# Project initial state to Krylov subspace: |psi_proj> = V_m^† |psi(0)>
144302
projected_state = backend.matvec(
@@ -148,6 +306,7 @@ def krylov_evol(
148306
# Perform spectral decomposition of projected Hamiltonian: T_m = U D U^†
149307
eigenvalues, eigenvectors = backend.eigh(projected_hamiltonian)
150308
eigenvalues = backend.cast(eigenvalues, dtypestr)
309+
eigenvectors = backend.cast(eigenvectors, dtypestr)
151310
time_points = backend.convert_to_tensor(time_points)
152311
time_points = backend.cast(time_points, dtypestr)
153312

0 commit comments

Comments
 (0)