
Commit 894d9a6

BCSR support for mul
1 parent 76ea91b commit 894d9a6

4 files changed, +32 -3 lines changed


examples/analog_rydberg.py

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,9 @@
 
 # Instantiate the AnalogCircuit
 ac = tc.AnalogCircuit(nqubits)
-
+ac.set_solver_options(
+    ode_backend="diffrax", max_steps=20000
+) # more efficient and stable than the default one by jax
 # 1. Create the sparsely excited state
 ac.x([i for i in range(nqubits) if i % 4 == 0])
 

@@ -49,5 +51,7 @@ def rydberg_hamiltonian_func(t):
 # ac.rx(i, theta=-np.pi/4)
 
 # 5. Sample from the final state in the computational basis
+state = ac.state()
+np.testing.assert_allclose(tc.backend.norm(state), 1, atol=1e-3)
 sample = ac.sample(batch=1024, allow_state=True, format="count_dict_bin")
 print("\nSampled bitstrings:\n", sample)

tensorcircuit/backends/jax_backend.py

Lines changed: 26 additions & 0 deletions
@@ -175,6 +175,27 @@ def _eigh_jax(self: Any, tensor: Tensor) -> Tensor:
     return adaware_eigh(tensor)
 
 
+def bcsr_scalar_mul(self: Tensor, other: Tensor) -> Tensor:
+    """
+    Implements scalar multiplication for BCSR matrices (self * scalar).
+    """
+    import jax.numpy as jnp
+    from jax.experimental.sparse import BCSR
+
+    if jnp.isscalar(other):
+        # The core logic: only the data array is affected by scalar multiplication.
+        # The sparsity pattern (indices, indptr) remains the same.
+        new_data = self.data * other
+
+        # Return a new BCSR instance with the scaled data.
+        return BCSR((new_data, self.indices, self.indptr), shape=self.shape)
+
+    # For any other type of multiplication (e.g., element-wise with another matrix),
+    # return NotImplemented. This allows Python to try other operations,
+    # like other.__rmul__(self).
+    return NotImplemented
+
+
 tensornetwork.backends.jax.jax_backend.JaxBackend.convert_to_tensor = (
     _convert_to_tensor_jax
 )
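
Aside (not part of the commit): the reason bcsr_scalar_mul only rescales data is that a BCSR matrix is stored as three arrays (data with the nonzero values, plus indices and indptr encoding where those values live), so a scalar product never changes the sparsity pattern. A minimal sketch of the same idea using the public jax.experimental.sparse API:

import jax.numpy as jnp
from jax.experimental.sparse import BCSR

m = BCSR.fromdense(jnp.array([[0.0, 3.0], [0.0, 0.0]]))
print(m.data, m.indices, m.indptr)  # nonzero values / column indices / row pointers
# Rebuilding with scaled data mirrors what bcsr_scalar_mul does:
scaled = BCSR((m.data * 2.0, m.indices, m.indptr), shape=m.shape)
print(scaled.todense())  # [[0. 6.] [0. 0.]]
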
@@ -224,6 +245,11 @@ def __init__(self) -> None:
 
         self.name = "jax"
 
+        # --- Monkey-patch the BCSR class ---
+
+        sparse.BCSR.__mul__ = bcsr_scalar_mul  # type: ignore
+        sparse.BCSR.__rmul__ = bcsr_scalar_mul  # type: ignore
+
         # it is already child of numpy backend, and self.np = self.jax.np
     def eye(
         self, N: int, dtype: Optional[str] = None, M: Optional[int] = None
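
Quick check (a sketch, not from the commit) that both orientations of the product dispatch through the patched methods; it assumes the jax backend has been initialized, e.g. via tc.set_backend("jax"), since the patch is installed in JaxBackend.__init__:

import jax.numpy as jnp
from jax.experimental.sparse import BCSR
import tensorcircuit as tc

tc.set_backend("jax")  # instantiating the jax backend installs the __mul__/__rmul__ patch

m = BCSR.fromdense(jnp.array([[1.0, 0.0], [0.0, 2.0]]))
print((m * 3.0).todense())  # __mul__: [[3. 0.] [0. 6.]]
print((3.0 * m).todense())  # __rmul__, reached after float.__mul__ returns NotImplemented
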

tensorcircuit/backends/pytorch_backend.py

Lines changed: 0 additions & 1 deletion
@@ -24,7 +24,6 @@
 
 # TODO(@refraction-ray): lack stateful random methods implementation for now
 # TODO(@refraction-ray): lack scatter impl for now
-# TODO(@refraction-ray): lack sparse relevant methods for now
 # To be added once pytorch backend is ready
 
 

tensorcircuit/timeevol.py

Lines changed: 1 addition & 1 deletion
@@ -616,7 +616,7 @@ def ode_evol_global(
 
     def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
        h = -1.0j * hamiltonian(t, *args)
-        return backend.sparse_dense_matmul(h, y)
+        return h @ y
 
     s1 = _solve_ode(f, initial_state, times, args, solver_kws)
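
The switch to h @ y ties into the BCSR patch above. A sketch under two assumptions (hamiltonian(t, *args) may return a jax.experimental.sparse.BCSR matrix, and the jax backend has been set so the __rmul__ patch is active): -1.0j * hamiltonian(t, *args) is the scalar-times-sparse product that bcsr_scalar_mul enables, and h @ y then relies on the sparse matrix supporting the @ operator against the dense state vector, so the same body of f serves both dense and sparse Hamiltonians:

import jax.numpy as jnp
from jax.experimental.sparse import BCSR
import tensorcircuit as tc

tc.set_backend("jax")  # installs the BCSR __mul__/__rmul__ patch

y = jnp.array([1.0, 0.0], dtype=jnp.complex64)
h_dense = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
h_sparse = BCSR.fromdense(h_dense)

for h in (h_dense, h_sparse):
    print((-1.0j * h) @ y)  # identical code path for dense and sparse h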
