@@ -624,52 +624,76 @@ def chebyshev_evol(
624624 :return: Evolved state
625625 :rtype: Tensor
626626 """
627-
627+ # TODO(@refraction-ray): no support for tf backend as bessel function has no implementation
628628 E_max , E_min = spectral_bounds
629629 if E_max <= E_min :
630630 raise ValueError ("E_max must be > E_min." )
631631
632632 a = (E_max - E_min ) / 2.0
633633 b = (E_max + E_min ) / 2.0
634- tau = a * t # tau is now a scalar
634+ tau = a * t # Rescaled time parameter
635635
636636 def apply_h_norm (psi : Any ) -> Any :
637+ """Applies the normalized Hamiltonian to a state."""
637638 return ((hamiltonian @ psi ) - b * psi ) / a
638639
639- T0 = initial_state
640- if k == 1 :
641- T_k_vectors = T0 [ None , :]
642- else :
643- T1 = apply_h_norm ( T0 )
640+ # Handle edge case where no evolution is needed.
641+ if k == 0 :
642+ # The phase factor still applies even for zero evolution of the series part.
643+ phase = backend . exp ( - 1j * b * t )
644+ return phase * backend . zeros_like ( initial_state )
644645
645- def scan_body (carry , _ ): # type: ignore
646- Tk , Tkm1 = carry
647- Tkp1 = 2 * apply_h_norm (Tk ) - Tkm1
648- return (Tkp1 , Tk ), Tk
646+ # --- 2. Calculate Chebyshev Expansion Coefficients ---
647+ k_indices = backend .arange (k )
648+ bessel_vals = backend .special_jv (k , tau , M )
649649
650- # 假设 backend.jaxy_scan 已正确实现
651- _ , T_k_stack_1_onwards = backend .jaxy_scan (
652- scan_body , (T1 , T0 ), xs = backend .arange (k - 1 )
650+ # Prefactor is 1 for k=0 and 2 for k>0.
651+ prefactor = backend .ones ([k ])
652+ if k > 1 :
653+ # Using concat for backend compatibility (vs. jax's .at[1:].set(2.0))
654+ prefactor = backend .concat (
655+ [backend .ones ([1 ]), backend .ones ([k - 1 ]) * 2.0 ], axis = 0
653656 )
654- T_k_vectors = backend .concat ([T0 [None , :], T_k_stack_1_onwards ], axis = 0 )
655657
656- bessel_vals = backend .special_jv (k , tau , M )
658+ ik_powers = backend .power (0 - 1j , k_indices )
659+ coeffs = prefactor * ik_powers * bessel_vals
657660
658- k_indices = backend .arange (k )
659- first_element = backend .ones ([1 ])
661+ # --- 3. Iteratively build the result using a scan ---
660662
661- remaining_elements = backend .ones ([k - 1 ]) * 2.0
663+ # Handle the simple case of k=1 separately.
664+ if k == 1 :
665+ psi_unphased = coeffs [0 ] * initial_state
666+ else : # k >= 2, use the scan operation.
667+ # Initialize the first two Chebyshev vectors and the initial sum.
668+ T0 = initial_state
669+ T1 = apply_h_norm (T0 )
670+ initial_sum = coeffs [0 ] * T0 + coeffs [1 ] * T1
662671
663- prefactor = backend .concat ([first_element , remaining_elements ], axis = 0 )
664- ik_powers = backend .power (0 - 1j , k_indices )
672+ # The carry for the scan holds the state needed for the next iteration:
673+ # (current vector T_k, previous vector T_{k-1}, and the running sum).
674+ initial_carry = (T1 , T0 , initial_sum )
665675
666- # coeffs 现在是一个清晰的 1D 向量,形状为 (n_terms,)
667- coeffs = prefactor * ik_powers * bessel_vals
676+ def scan_body (carry , i ): # type: ignore
677+ """The body of the scan operation."""
678+ Tk , Tkm1 , current_sum = carry
668679
669- # 求和也变得更简单
670- psi_unphased = backend . einsum ( "k,kD->D" , coeffs , T_k_vectors )
680+ # Calculate the next Chebyshev vector using the recurrence relation.
681+ Tkp1 = 2 * apply_h_norm ( Tk ) - Tkm1
671682
672- # 加上全局相位
683+ # Add its contribution to the running sum.
684+ new_sum = current_sum + coeffs [i ] * Tkp1
685+
686+ # Return the updated carry for the next step. No intermediate output is needed.
687+ return (Tkp1 , Tk , new_sum )
688+
689+ # Run the scan over the remaining coefficients (from index 2 to k-1).
690+ final_carry = backend .scan (scan_body , backend .arange (2 , k ), initial_carry )
691+
692+ # The final result is the sum accumulated in the last carry state.
693+ psi_unphased = final_carry [2 ]
694+
695+ # --- 4. Final Step: Apply Phase Correction ---
696+ # This undoes the energy shift from the Hamiltonian normalization.
673697 phase = backend .exp (- 1j * b * t )
674698 psi_final = phase * psi_unphased
675699
@@ -750,19 +774,19 @@ def estimate_spectral_bounds(
750774 r = backend .convert_to_tensor (r ) # in case np.matrix
751775 r = backend .reshape (r , [- 1 ])
752776 if beta != 0 :
753- r -= beta * q_prev
777+ r -= backend . cast ( beta , dtypestr ) * q_prev
754778
755779 alpha = backend .real (backend .sum (backend .conj (q ) * r ))
756780
757781 alphas .append (alpha )
758782
759- r -= alpha * q
783+ r -= backend . cast ( alpha , dtypestr ) * q
760784
761785 q_prev = q
762786 beta = backend .norm (r )
763787 q = r / beta
788+ beta = backend .abs (beta )
764789 betas .append (beta )
765-
766790 if beta < 1e-8 :
767791 break
768792
0 commit comments