@@ -609,16 +609,16 @@ def f_jax(*args: Any, **kws: Any) -> Any:
609609 return carry
610610
611611 def scatter (self , operand : Tensor , indices : Tensor , updates : Tensor ) -> Tensor :
612- updates = jnp .reshape (updates , indices .shape )
613- return operand .at [indices ].set (updates )
614- # rank = len(operand.shape)
615- # dnums = libjax.lax.ScatterDimensionNumbers(
616- # update_window_dims=(),
617- # inserted_window_dims=tuple([i for i in range(rank)]),
618- # scatter_dims_to_operand_dims=tuple([i for i in range(rank)]),
619- # )
620- # r = libjax.lax.scatter(operand, indices, updates, dnums)
621- # return r
612+ # updates = jnp.reshape(updates, indices.shape)
613+ # return operand.at[indices].set(updates)
614+ rank = len (operand .shape )
615+ dnums = libjax .lax .ScatterDimensionNumbers (
616+ update_window_dims = (),
617+ inserted_window_dims = tuple ([i for i in range (rank )]),
618+ scatter_dims_to_operand_dims = tuple ([i for i in range (rank )]),
619+ )
620+ r = libjax .lax .scatter (operand , indices , updates , dnums )
621+ return r
622622
623623 def coo_sparse_matrix (
624624 self , indices : Tensor , values : Tensor , shape : Tensor
0 commit comments