Skip to content

Commit 9d95a9f

Browse files
add csr sparse support and accelerate krylov/chebyshev
1 parent 695d1d8 commit 9d95a9f

File tree

6 files changed

+64
-1
lines changed

6 files changed

+64
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
- Add `jaxy_scan` for numpy and jax backends, contrasting to the tf style original backend scan method.
1212

13+
- Add `sparse_csr_from_coo` method for numpy and jax backends to convert COO format to CSR format, the latter is more efficient for `sparse_dense_matmul`.
14+
1315
- Add `krylov_evol` method for krylov evolution.
1416

1517
- Add `chebyshev_evol` method for chebyshev polynomial evolution.

tensorcircuit/backends/abstract_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,26 @@ def sparse_dense_matmul(
12661266
"Backend '{}' has not implemented `sparse_dense_matmul`.".format(self.name)
12671267
)
12681268

1269+
def sparse_csr_from_coo(self: Any, coo: Tensor, strict: bool = False) -> Tensor:
1270+
"""
1271+
transform a coo matrix to a csr matrix
1272+
1273+
:param coo: a coo matrix
1274+
:type coo: Tensor
1275+
:param strict: whether to enforce the transform, defaults to False,
1276+
corresponding to return the coo matrix if there is no implementation for specific backend.
1277+
:type strict: bool, optional
1278+
:return: a csr matrix
1279+
:rtype: Tensor
1280+
"""
1281+
if strict:
1282+
raise NotImplementedError(
1283+
"Backend '{}' has not implemented `sparse_csr_from_coo`.".format(
1284+
self.name
1285+
)
1286+
)
1287+
return coo
1288+
12691289
def to_dense(self: Any, sp_a: Tensor) -> Tensor:
12701290
"""
12711291
Convert a sparse matrix to dense tensor.

tensorcircuit/backends/jax_backend.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,11 +649,20 @@ def sparse_dense_matmul(
649649
) -> Tensor:
650650
return sp_a @ b
651651

652+
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
653+
try:
654+
return sparse.BCSR.from_bcoo(coo) # type: ignore
655+
except AttributeError as e:
656+
if not strict:
657+
return coo
658+
else:
659+
raise e
660+
652661
def to_dense(self, sp_a: Tensor) -> Tensor:
653662
return sp_a.todense()
654663

655664
def is_sparse(self, a: Tensor) -> bool:
656-
return isinstance(a, sparse.BCOO) # type: ignore
665+
return isinstance(a, sparse.JAXSparse) # type: ignore
657666

658667
def device(self, a: Tensor) -> str:
659668
(dev,) = a.devices()

tensorcircuit/backends/numpy_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ def sparse_dense_matmul(
333333
) -> Tensor:
334334
return sp_a @ b
335335

336+
def sparse_csr_from_coo(self, coo: Tensor, strict: bool = False) -> Tensor:
337+
return coo.tocsr()
338+
336339
def to_dense(self, sp_a: Tensor) -> Tensor:
337340
return sp_a.todense()
338341

tensorcircuit/timeevol.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def lanczos_iteration_scan(
3232
:rtype: Tuple[Tensor, Tensor]
3333
"""
3434
state_size = backend.shape_tuple(initial_vector)[0]
35+
if backend.is_sparse(hamiltonian):
36+
hamiltonian = backend.sparse_csr_from_coo(hamiltonian)
3537

3638
# Main scan body for the outer loop (iterating j)
3739
def lanczos_step(carry: Tuple[Any, ...], j: int) -> Tuple[Any, ...]:
@@ -196,6 +198,9 @@ def lanczos_iteration(
196198
# Add first basis vector
197199
basis_vectors.append(vector)
198200

201+
if backend.is_sparse(hamiltonian):
202+
hamiltonian = backend.sparse_csr_from_coo(hamiltonian)
203+
199204
# Lanczos iteration (fixed number of iterations for JIT compatibility)
200205
for j in range(subspace_dimension):
201206
# Calculate H|v_j>
@@ -633,6 +638,9 @@ def chebyshev_evol(
633638
b = (E_max + E_min) / 2.0
634639
tau = a * t # Rescaled time parameter
635640

641+
if backend.is_sparse(hamiltonian):
642+
hamiltonian = backend.sparse_csr_from_coo(hamiltonian)
643+
636644
def apply_h_norm(psi: Any) -> Any:
637645
"""Applies the normalized Hamiltonian to a state."""
638646
return ((hamiltonian @ psi) - b * psi) / a

tests/test_backends.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ def f(x):
6161
np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
6262

6363

64+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
65+
def test_sparse_csr_from_coo(backend):
66+
# Create a sparse matrix in COO format
67+
values = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0]))
68+
values = tc.backend.cast(values, "complex64")
69+
indices = tc.backend.convert_to_tensor(np.array([[0, 0], [1, 1], [2, 3]]))
70+
indices = tc.backend.cast(indices, "int64")
71+
coo_matrix = tc.backend.coo_sparse_matrix(indices, values, shape=[4, 4])
72+
73+
# Convert COO to CSR
74+
csr_matrix = tc.backend.sparse_csr_from_coo(coo_matrix)
75+
76+
# Check that the result is still recognized as sparse
77+
assert tc.backend.is_sparse(csr_matrix) is True
78+
79+
# Check that the conversion preserves values by comparing dense representations
80+
coo_dense = tc.backend.to_dense(coo_matrix)
81+
csr_dense = tc.backend.to_dense(csr_matrix)
82+
np.testing.assert_allclose(coo_dense, csr_dense, atol=1e-5)
83+
84+
6485
def test_sparse_tensor_matmul_monkey_patch(tfb):
6586
"""
6687
Test the monkey-patched __matmul__ method for tf.SparseTensor.

0 commit comments

Comments
 (0)