1111Circuit = Any
1212
1313
14+ def lanczos_iteration_scan (
15+ hamiltonian : Any , initial_vector : Any , subspace_dimension : int
16+ ) -> Tuple [Any , Any ]:
17+ """
18+ Use Lanczos algorithm to construct orthogonal basis and projected Hamiltonian
19+ of Krylov subspace, using `tc.backend.scan` for JIT compatibility.
20+
21+ :param hamiltonian: Sparse or dense Hamiltonian matrix
22+ :type hamiltonian: Tensor
23+ :param initial_vector: Initial quantum state vector
24+ :type initial_vector: Tensor
25+ :param subspace_dimension: Dimension of Krylov subspace
26+ :type subspace_dimension: int
27+ :return: Tuple containing (basis matrix, projected Hamiltonian)
28+ :rtype: Tuple[Tensor, Tensor]
29+ """
30+ state_size = backend .shape_tuple (initial_vector )[0 ]
31+
32+ # Main scan body for the outer loop (iterating j)
33+ def lanczos_step (carry : Tuple [Any , ...], j : int ) -> Tuple [Any , ...]:
34+ v , basis , alphas , betas = carry
35+
36+ if backend .is_sparse (hamiltonian ):
37+ w = backend .sparse_dense_matmul (hamiltonian , v )
38+ else :
39+ w = backend .matvec (hamiltonian , v )
40+
41+ alpha = backend .real (backend .sum (backend .conj (v ) * w ))
42+ w = w - backend .cast (alpha , dtypestr ) * v
43+
44+ # Inner scan for re-orthogonalization (iterating k)
45+ # def ortho_step(inner_carry: Tuple[Any, Any], k: int) -> Tuple[Any, Any]:
46+ # w_carry, j_val = inner_carry
47+
48+ # def do_projection() -> Any:
49+ # # `basis` is available here through closure
50+ # v_k = basis[:, k]
51+ # projection = backend.sum(backend.conj(v_k) * w_carry)
52+ # return w_carry - projection * v_k
53+
54+ # def do_nothing() -> Any:
55+ # return w_carry
56+
57+ # # Orthogonalize against v_0, ..., v_j
58+ # w_new = backend.cond(k <= j_val, do_projection, do_nothing)
59+ # return (w_new, j_val) # Return the new carry for the inner loop
60+
61+ # # Pass `j` into the inner scan's carry
62+ # inner_init_carry = (w, j)
63+ # final_inner_carry = backend.scan(
64+ # ortho_step, backend.arange(subspace_dimension), inner_init_carry
65+ # )
66+ # w_ortho = final_inner_carry[0]
67+ # print(")}]", final_inner_carry, w_ortho)
68+
69+ def ortho_step (w_carry : Any , elems_tuple : Tuple [Any , Any ]) -> Any :
70+ k , j_from_elems = elems_tuple
71+
72+ def do_projection () -> Any :
73+ v_k = basis [:, k ]
74+ projection = backend .sum (backend .conj (v_k ) * w_carry )
75+ return w_carry - projection * v_k
76+
77+ def do_nothing () -> Any :
78+ return backend .cast (w_carry , dtype = dtypestr )
79+
80+ w_new = backend .cond (k <= j_from_elems , do_projection , do_nothing )
81+ return w_new
82+
83+ k_elems = backend .arange (subspace_dimension )
84+ j_elems = backend .tile (backend .reshape (j , [1 ]), [subspace_dimension ])
85+ inner_elems = (k_elems , j_elems )
86+ w_ortho = backend .scan (ortho_step , inner_elems , w )
87+
88+ beta = backend .norm (w_ortho )
89+ beta = backend .real (beta )
90+
91+ # Update alphas and betas arrays
92+ new_alphas = backend .scatter (
93+ alphas , backend .reshape (j , [1 , 1 ]), backend .reshape (alpha , [1 ])
94+ )
95+ new_betas = backend .scatter (
96+ betas , backend .reshape (j , [1 , 1 ]), backend .reshape (beta , [1 ])
97+ )
98+
99+ def update_state_fn () -> Tuple [Any , Any ]:
100+ epsilon = 1e-15
101+ next_v = w_ortho / backend .cast (beta + epsilon , dtypestr )
102+
103+ one_hot_update = backend .onehot (j + 1 , subspace_dimension )
104+ one_hot_update = backend .cast (one_hot_update , dtype = dtypestr )
105+
106+ # Create a mask to update only the (j+1)-th column
107+ mask = 1.0 - backend .reshape (one_hot_update , [1 , subspace_dimension ])
108+ new_basis = basis * mask + backend .reshape (
109+ next_v , [- 1 , 1 ]
110+ ) * backend .reshape (one_hot_update , [1 , subspace_dimension ])
111+
112+ return next_v , new_basis
113+
114+ def keep_state_fn () -> Tuple [Any , Any ]:
115+ return v , basis
116+
117+ next_v_carry , new_basis = backend .cond (
118+ j < subspace_dimension - 1 , update_state_fn , keep_state_fn
119+ )
120+
121+ return (next_v_carry , new_basis , new_alphas , new_betas )
122+
123+ # Prepare initial state for the main scan
124+ v0 = initial_vector / backend .norm (initial_vector )
125+
126+ init_basis = backend .zeros ((state_size , subspace_dimension ), dtype = dtypestr )
127+ init_alphas = backend .zeros ((subspace_dimension ,), dtype = rdtypestr )
128+ init_betas = backend .zeros ((subspace_dimension ,), dtype = rdtypestr )
129+
130+ one_hot_0 = backend .onehot (0 , subspace_dimension )
131+ one_hot_0 = backend .cast (one_hot_0 , dtype = dtypestr )
132+ init_basis = init_basis + backend .reshape (v0 , [- 1 , 1 ]) * backend .reshape (
133+ one_hot_0 , [1 , subspace_dimension ]
134+ )
135+
136+ init_carry = (v0 , init_basis , init_alphas , init_betas )
137+
138+ # Run the main scan
139+ final_carry = backend .scan (
140+ lanczos_step , backend .arange (subspace_dimension ), init_carry
141+ )
142+ basis_matrix , alphas_tensor , betas_tensor = (
143+ final_carry [1 ],
144+ final_carry [2 ],
145+ final_carry [3 ],
146+ )
147+
148+ betas_off_diag = betas_tensor [:- 1 ]
149+
150+ diag_part = backend .diagflat (alphas_tensor )
151+ if backend .shape_tuple (betas_off_diag )[0 ] > 0 :
152+ off_diag_part = backend .diagflat (betas_off_diag , k = 1 )
153+ projected_hamiltonian = (
154+ diag_part + off_diag_part + backend .conj (backend .transpose (off_diag_part ))
155+ )
156+ else :
157+ projected_hamiltonian = diag_part
158+
159+ return basis_matrix , projected_hamiltonian
160+
161+
14162def lanczos_iteration (
15163 hamiltonian : Tensor , initial_vector : Tensor , subspace_dimension : int
16164) -> Tuple [Tensor , Tensor ]:
@@ -116,6 +264,7 @@ def krylov_evol(
116264 time_points : Tensor ,
117265 subspace_dimension : int ,
118266 callback : Optional [Callable [[Any ], Any ]] = None ,
267+ scan_impl : bool = False ,
119268) -> Any :
120269 """
121270 Perform quantum state time evolution using Krylov subspace method.
@@ -131,14 +280,23 @@ def krylov_evol(
131280 :param callback: Optional callback function applied to quantum state at
132281 each evolution time point, return some observables
133282 :type callback: Optional[Callable[[Any], Any]], optional
283+ :param scan_impl: whether use scan implementation, suitable for jit but may be slow on numpy
284+ defaults False, True not work for tensorflow backend + jit, due to stupid issue of tensorflow
285+ context separation and the notorious inaccesibletensor error
286+ :type scan_impl: bool, optional
134287 :return: List of evolved quantum states, or list of callback function results
135288 (if callback provided)
136289 :rtype: Any
137290 """
138291 # TODO(@refraction-ray): stable and efficient AD is to be investigated
139- basis_matrix , projected_hamiltonian = lanczos_iteration (
140- hamiltonian , initial_state , subspace_dimension
141- )
292+ if not scan_impl :
293+ basis_matrix , projected_hamiltonian = lanczos_iteration (
294+ hamiltonian , initial_state , subspace_dimension
295+ )
296+ else :
297+ basis_matrix , projected_hamiltonian = lanczos_iteration_scan (
298+ hamiltonian , initial_state , subspace_dimension
299+ )
142300 initial_state = backend .cast (initial_state , dtypestr )
143301 # Project initial state to Krylov subspace: |psi_proj> = V_m^† |psi(0)>
144302 projected_state = backend .matvec (
@@ -148,6 +306,7 @@ def krylov_evol(
148306 # Perform spectral decomposition of projected Hamiltonian: T_m = U D U^†
149307 eigenvalues , eigenvectors = backend .eigh (projected_hamiltonian )
150308 eigenvalues = backend .cast (eigenvalues , dtypestr )
309+ eigenvectors = backend .cast (eigenvectors , dtypestr )
151310 time_points = backend .convert_to_tensor (time_points )
152311 time_points = backend .cast (time_points , dtypestr )
153312
0 commit comments