Skip to content

Commit fb664ed

Browse files
fix scatter in jax backend
1 parent 2aad7be commit fb664ed

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

tensorcircuit/backends/jax_backend.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -609,16 +609,16 @@ def f_jax(*args: Any, **kws: Any) -> Any:
609609
return carry
610610

611611
def scatter(self, operand: Tensor, indices: Tensor, updates: Tensor) -> Tensor:
612-
updates = jnp.reshape(updates, indices.shape)
613-
return operand.at[indices].set(updates)
614-
# rank = len(operand.shape)
615-
# dnums = libjax.lax.ScatterDimensionNumbers(
616-
# update_window_dims=(),
617-
# inserted_window_dims=tuple([i for i in range(rank)]),
618-
# scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
619-
# )
620-
# r = libjax.lax.scatter(operand, indices, updates, dnums)
621-
# return r
612+
# updates = jnp.reshape(updates, indices.shape)
613+
# return operand.at[indices].set(updates)
614+
rank = len(operand.shape)
615+
dnums = libjax.lax.ScatterDimensionNumbers(
616+
update_window_dims=(),
617+
inserted_window_dims=tuple([i for i in range(rank)]),
618+
scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
619+
)
620+
r = libjax.lax.scatter(operand, indices, updates, dnums)
621+
return r
622622

623623
def coo_sparse_matrix(
624624
self, indices: Tensor, values: Tensor, shape: Tensor

tensorcircuit/quantum.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,9 @@ def quimb2qop(qb_mpo: Any) -> QuOperator:
11611161
return qop
11621162

11631163

1164+
# TODO(@refraction-ray): Z2 analogy or more general analogies for the following u1 functions
1165+
1166+
11641167
def u1_inds(n: int, m: int) -> Tensor:
11651168
"""
11661169
Generate all the combination index of m down spins in n sites.

0 commit comments

Comments
 (0)