Skip to content

Commit 695d1d8

Browse files
memory efficient chebyshev evol
1 parent d35da1d commit 695d1d8

File tree

5 files changed

+143
-38
lines changed

5 files changed

+143
-38
lines changed

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,38 @@ def _svd_tf(
360360
tensornetwork.backends.tensorflow.tensorflow_backend.TensorFlowBackend.svd = _svd_tf
361361

362362

363+
def sparse_tensor_matmul(self: Tensor, other: Tensor) -> Tensor:
    """
    Matrix multiplication (``@``) of a ``tf.SparseTensor`` with a dense operand.

    This function is designed to be monkey-patched onto ``tf.SparseTensor``
    as ``__matmul__``. ``tf.sparse.sparse_dense_matmul`` only accepts a rank-2
    dense operand, so a rank-1 vector is temporarily promoted to a column
    matrix and the result is squeezed back to a vector.

    :param self: The sparse left operand.
    :type self: Tensor
    :param other: The dense right operand, either a vector of shape ``[N]``
        or a matrix of shape ``[N, K]``. Array-likes that are not TensorFlow
        tensors (e.g. numpy arrays) are converted with
        ``tf.convert_to_tensor``. If its dtype is not compatible with the
        dtype of ``self``, it is cast to the dtype of ``self``.
    :type other: Tensor
    :return: The product, of shape ``[M]`` if ``other`` is a vector and
        ``[M, K]`` otherwise.
    :rtype: Tensor
    """
    # Array-likes such as numpy arrays carry a non-TensorFlow dtype object
    # without ``is_compatible_with``, so convert them before the dtype check.
    if not tf.is_tensor(other):
        other = tf.convert_to_tensor(other)

    # Ensure the dense operand shares the dtype of the sparse operand.
    if not other.dtype.is_compatible_with(self.dtype):
        other = tf.cast(other, self.dtype)

    # ``shape.rank`` is ``None`` when the static rank is unknown, whereas
    # ``len(shape)`` would raise in that case; an unknown rank is simply
    # passed through unchanged to the underlying op.
    if other.shape.rank != 1:
        return tf.sparse.sparse_dense_matmul(self, other)

    # Promote the vector to a column matrix [N] -> [N, 1], multiply, then
    # demote the result back to a vector [M, 1] -> [M].
    column = tf.expand_dims(other, axis=1)
    product = tf.sparse.sparse_dense_matmul(self, column)
    return tf.squeeze(product, axis=1)
393+
394+
363395
class TensorFlowBackend(tensorflow_backend.TensorFlowBackend, ExtendedBackend): # type: ignore
364396
"""
365397
See the original backend API at `tensorflow backend
@@ -378,6 +410,8 @@ def __init__(self) -> None:
378410
)
379411
tf = tensorflow
380412
tf.sparse.SparseTensor.__add__ = tf.sparse.add
413+
tf.SparseTensor.__matmul__ = sparse_tensor_matmul
414+
381415
self.minor = int(tf.__version__.split(".")[1])
382416
self.name = "tensorflow"
383417
logger = tf.get_logger() # .setLevel('ERROR')

tensorcircuit/fgs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def get_alpha(self) -> Tensor:
227227
return self.alpha
228228

229229
def get_cmatrix(self, now_i: bool = True, now_j: bool = True) -> Tensor:
230-
"""
230+
r"""
231231
Calculates the correlation matrix.
232232
233233
The correlation matrix is defined as :math:`C_{ij} = \langle c_i^\dagger c_j \rangle`.
@@ -509,7 +509,7 @@ def orthogonal(self) -> None:
509509

510510
@staticmethod
511511
def hopping(chi: Tensor, i: int, j: int, L: int) -> Tensor:
512-
"""
512+
r"""
513513
Constructs the hopping Hamiltonian between two sites.
514514
515515
The hopping Hamiltonian is given by :math:`\chi c_i^\dagger c_j + h.c.`.
@@ -550,7 +550,7 @@ def evol_hp(self, i: int, j: int, chi: Tensor = 0) -> None:
550550

551551
@staticmethod
552552
def chemical_potential(chi: Tensor, i: int, L: int) -> Tensor:
553-
"""
553+
r"""
554554
Constructs the chemical potential Hamiltonian for a single site.
555555
556556
The chemical potential Hamiltonian is given by :math:`\chi c_i^\dagger c_i`.
@@ -572,7 +572,7 @@ def chemical_potential(chi: Tensor, i: int, L: int) -> Tensor:
572572

573573
@staticmethod
574574
def sc_pairing(chi: Tensor, i: int, j: int, L: int) -> Tensor:
575-
"""
575+
r"""
576576
Constructs the superconducting pairing Hamiltonian between two sites.
577577
578578
The superconducting pairing Hamiltonian is given by :math:`\chi c_i^\dagger c_j^\dagger + h.c.`.
@@ -637,7 +637,7 @@ def evol_icp(self, i: int, chi: Tensor = 0) -> None:
637637
self.evol_ihamiltonian(self.chemical_potential(chi, i, self.L))
638638

639639
def get_bogoliubov_uv(self) -> Tuple[Tensor, Tensor]:
640-
"""
640+
r"""
641641
Returns the u and v matrices of the Bogoliubov transformation.
642642
643643
The Bogoliubov transformation is defined as:

tensorcircuit/timeevol.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -624,52 +624,76 @@ def chebyshev_evol(
624624
:return: Evolved state
625625
:rtype: Tensor
626626
"""
627-
627+
# TODO(@refraction-ray): no support for tf backend as bessel function has no implementation
628628
E_max, E_min = spectral_bounds
629629
if E_max <= E_min:
630630
raise ValueError("E_max must be > E_min.")
631631

632632
a = (E_max - E_min) / 2.0
633633
b = (E_max + E_min) / 2.0
634-
tau = a * t # tau is now a scalar
634+
tau = a * t # Rescaled time parameter
635635

636636
def apply_h_norm(psi: Any) -> Any:
637+
"""Applies the normalized Hamiltonian to a state."""
637638
return ((hamiltonian @ psi) - b * psi) / a
638639

639-
T0 = initial_state
640-
if k == 1:
641-
T_k_vectors = T0[None, :]
642-
else:
643-
T1 = apply_h_norm(T0)
640+
# Handle edge case where no evolution is needed.
641+
if k == 0:
642+
# The phase factor still applies even for zero evolution of the series part.
643+
phase = backend.exp(-1j * b * t)
644+
return phase * backend.zeros_like(initial_state)
644645

645-
def scan_body(carry, _): # type: ignore
646-
Tk, Tkm1 = carry
647-
Tkp1 = 2 * apply_h_norm(Tk) - Tkm1
648-
return (Tkp1, Tk), Tk
646+
# --- 2. Calculate Chebyshev Expansion Coefficients ---
647+
k_indices = backend.arange(k)
648+
bessel_vals = backend.special_jv(k, tau, M)
649649

650-
# 假设 backend.jaxy_scan 已正确实现
651-
_, T_k_stack_1_onwards = backend.jaxy_scan(
652-
scan_body, (T1, T0), xs=backend.arange(k - 1)
650+
# Prefactor is 1 for k=0 and 2 for k>0.
651+
prefactor = backend.ones([k])
652+
if k > 1:
653+
# Using concat for backend compatibility (vs. jax's .at[1:].set(2.0))
654+
prefactor = backend.concat(
655+
[backend.ones([1]), backend.ones([k - 1]) * 2.0], axis=0
653656
)
654-
T_k_vectors = backend.concat([T0[None, :], T_k_stack_1_onwards], axis=0)
655657

656-
bessel_vals = backend.special_jv(k, tau, M)
658+
ik_powers = backend.power(0 - 1j, k_indices)
659+
coeffs = prefactor * ik_powers * bessel_vals
657660

658-
k_indices = backend.arange(k)
659-
first_element = backend.ones([1])
661+
# --- 3. Iteratively build the result using a scan ---
660662

661-
remaining_elements = backend.ones([k - 1]) * 2.0
663+
# Handle the simple case of k=1 separately.
664+
if k == 1:
665+
psi_unphased = coeffs[0] * initial_state
666+
else: # k >= 2, use the scan operation.
667+
# Initialize the first two Chebyshev vectors and the initial sum.
668+
T0 = initial_state
669+
T1 = apply_h_norm(T0)
670+
initial_sum = coeffs[0] * T0 + coeffs[1] * T1
662671

663-
prefactor = backend.concat([first_element, remaining_elements], axis=0)
664-
ik_powers = backend.power(0 - 1j, k_indices)
672+
# The carry for the scan holds the state needed for the next iteration:
673+
# (current vector T_k, previous vector T_{k-1}, and the running sum).
674+
initial_carry = (T1, T0, initial_sum)
665675

666-
# coeffs 现在是一个清晰的 1D 向量,形状为 (n_terms,)
667-
coeffs = prefactor * ik_powers * bessel_vals
676+
def scan_body(carry, i): # type: ignore
677+
"""The body of the scan operation."""
678+
Tk, Tkm1, current_sum = carry
668679

669-
# 求和也变得更简单
670-
psi_unphased = backend.einsum("k,kD->D", coeffs, T_k_vectors)
680+
# Calculate the next Chebyshev vector using the recurrence relation.
681+
Tkp1 = 2 * apply_h_norm(Tk) - Tkm1
671682

672-
# 加上全局相位
683+
# Add its contribution to the running sum.
684+
new_sum = current_sum + coeffs[i] * Tkp1
685+
686+
# Return the updated carry for the next step. No intermediate output is needed.
687+
return (Tkp1, Tk, new_sum)
688+
689+
# Run the scan over the remaining coefficients (from index 2 to k-1).
690+
final_carry = backend.scan(scan_body, backend.arange(2, k), initial_carry)
691+
692+
# The final result is the sum accumulated in the last carry state.
693+
psi_unphased = final_carry[2]
694+
695+
# --- 4. Final Step: Apply Phase Correction ---
696+
# This undoes the energy shift from the Hamiltonian normalization.
673697
phase = backend.exp(-1j * b * t)
674698
psi_final = phase * psi_unphased
675699

@@ -750,19 +774,19 @@ def estimate_spectral_bounds(
750774
r = backend.convert_to_tensor(r) # in case np.matrix
751775
r = backend.reshape(r, [-1])
752776
if beta != 0:
753-
r -= beta * q_prev
777+
r -= backend.cast(beta, dtypestr) * q_prev
754778

755779
alpha = backend.real(backend.sum(backend.conj(q) * r))
756780

757781
alphas.append(alpha)
758782

759-
r -= alpha * q
783+
r -= backend.cast(alpha, dtypestr) * q
760784

761785
q_prev = q
762786
beta = backend.norm(r)
763787
q = r / beta
788+
beta = backend.abs(beta)
764789
betas.append(beta)
765-
766790
if beta < 1e-8:
767791
break
768792

tests/test_backends.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,49 @@ def f(x):
6161
np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
6262

6363

64+
def test_sparse_tensor_matmul_monkey_patch(tfb):
    """
    Check the ``@`` operator that the backend patches onto tf.SparseTensor.

    Covers a rank-1 dense operand, the same operand as an explicit column
    matrix, a general rank-2 dense operand, and agreement with the backend's
    own ``sparse_dense_matmul``.
    """
    # 4x4 sparse matrix in COO format with entries (0,0)=1, (1,1)=2, (2,3)=3.
    sparse_matrix = tf.SparseTensor(
        indices=tf.constant([[0, 0], [1, 1], [2, 3]], dtype=tf.int64),
        values=tf.constant([1.0, 2.0, 3.0], dtype=tf.complex64),
        dense_shape=[4, 4],
    )

    # Case 1a: sparse matrix @ rank-1 vector yields a rank-1 vector.
    vec = tf.constant([1.0, 2.0, 3.0, 4.0], dtype=tf.complex64)
    vec_product = sparse_matrix @ vec
    vec_expected = tf.constant([1.0, 4.0, 12.0, 0.0], dtype=tf.complex64)
    np.testing.assert_allclose(vec_product, vec_expected, atol=1e-6)

    # Case 1b: the same data as a [4, 1] column keeps its matrix shape.
    col = tc.backend.reshape(vec, [4, 1])
    col_product = sparse_matrix @ col
    col_expected = tc.backend.reshape(vec_expected, [4, 1])
    np.testing.assert_allclose(col_product, col_expected, atol=1e-6)

    # Case 2: sparse matrix @ dense [4, 2] matrix matches the raw TF op.
    dense = tf.constant(
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=tf.complex64
    )
    dense_product = sparse_matrix @ dense
    dense_expected = tf.sparse.sparse_dense_matmul(sparse_matrix, dense)
    np.testing.assert_allclose(
        dense_product.numpy(), dense_expected.numpy(), atol=1e-6
    )

    # Case 3: the rank-1 result agrees with the backend's sparse_dense_matmul
    # applied to the column operand and flattened.
    reference = tc.backend.reshape(
        tc.backend.sparse_dense_matmul(sparse_matrix, col), [-1]
    )
    np.testing.assert_allclose(vec_product, reference, atol=1e-6)
106+
64107
@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
65108
def test_backend_jv(backend, highp):
66109
def calculate_M(k, x_val):

tests/test_timeevol.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,6 @@ def test_krylov_evol_heisenberg_6_sites(backend):
279279

280280
# Generate Heisenberg Hamiltonian
281281
h = tc.quantum.heisenberg_hamiltonian(g, hzz=1.0, hxx=1.0, hyy=1.0, sparse=False)
282-
print(h.dtype)
283282
# Initial state - all spins up except last one down
284283
psi0 = np.zeros((2**n,))
285284
psi0[62] = 1.0
@@ -454,15 +453,18 @@ def loss_function(t):
454453
print(gradient)
455454

456455

457-
@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
458-
def test_chebyshev_evol_basic(backend, highp):
456+
@pytest.mark.parametrize(
457+
"backend, sparse",
458+
[[lf("npb"), True], [lf("npb"), False], [lf("jaxb"), True], [lf("jaxb"), False]],
459+
)
460+
def test_chebyshev_evol_basic(backend, highp, sparse):
459461
n = 6
460462
# Create a 1D chain graph
461463
g = tc.templates.graphs.Line1D(n, pbc=False)
462464

463465
# Generate Heisenberg Hamiltonian (dense or sparse, depending on the `sparse` parameter)
464466
h = tc.quantum.heisenberg_hamiltonian(
465-
g, hzz=1.0, hxx=1.0, hyy=1.0, hx=0.2, sparse=False
467+
g, hzz=1.0, hxx=1.0, hyy=1.0, hx=0.2, sparse=sparse
466468
)
467469

468470
# Initial Neel state: |↑↓↑↓⟩
@@ -490,6 +492,8 @@ def test_chebyshev_evol_basic(backend, highp):
490492
np.testing.assert_allclose(norm, 1.0, atol=1e-3)
491493

492494
# Compare with exact evolution for small system
495+
if sparse is True:
496+
h = tc.backend.to_dense(h)
493497
psi_exact = tc.timeevol.ed_evol(h, psi0, 1.0j * tc.backend.convert_to_tensor([t]))[
494498
0
495499
]

0 commit comments

Comments
 (0)