Skip to content

Commit 0401d4b

Browse files
add backend.argsort
1 parent 4b40245 commit 0401d4b

File tree

7 files changed

+48
-1
lines changed

7 files changed

+48
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
- Introducing pyproject.toml finally
2222

23+
- Add `argsort` method for backends
24+
2325
### Fixed
2426

2527
- Fixed `one_hot` in numpy backend.

examples/hamiltonian_building.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import jax
88
import numpy as np
99
import quimb
10+
from quimb import tensor as qt
1011
import scipy
1112
import tensorflow as tf
1213

@@ -90,7 +91,7 @@
9091
print("hamiltonian building with quimb")
9192
print("quimb version: ", quimb.__version__)
9293

93-
builder = quimb.tensor.SpinHam1D()
94+
builder = qt.SpinHam1D()
9495
# spin operator instead of Pauli matrix
9596
builder += 4, "Z", "Z"
9697
builder += -2, "X"

tensorcircuit/backends/abstract_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,21 @@ def argmin(self: Any, a: Tensor, axis: int = 0) -> Tensor:
581581
"Backend '{}' has not implemented `argmin`.".format(self.name)
582582
)
583583

584+
def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
585+
"""
586+
return the indices that would sort an array.
587+
588+
:param a: the tensor to be sorted
589+
:type a: Tensor
590+
:param axis: the sorted axis, defaults to -1
591+
:type axis: int
592+
:return: the sorted indices
593+
:rtype: Tensor
594+
"""
595+
raise NotImplementedError(
596+
"Backend '{}' has not implemented `argsort`.".format(self.name)
597+
)
598+
584599
def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
585600
"""
586601
Find the unique elements and their corresponding counts of the given tensor ``a``.

tensorcircuit/backends/jax_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
387387
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
388388
return jnp.argmin(a, axis=axis)
389389

390+
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
391+
return jnp.argsort(a, axis=axis)
392+
390393
def unique_with_counts( # type: ignore
391394
self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
392395
) -> Tuple[Tensor, Tensor]:

tensorcircuit/backends/numpy_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def special_jv(self, v: int, z: Tensor, M: int) -> Tensor:
251251
def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
252252
return np.searchsorted(a, v, side=side) # type: ignore
253253

254+
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
255+
return np.argsort(a, axis=axis)
256+
254257
def set_random_state(
255258
self, seed: Optional[int] = None, get_only: bool = False
256259
) -> Any:

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,9 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
530530
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
531531
return tf.math.argmin(a, axis=axis)
532532

533+
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
534+
return tf.argsort(a, axis=axis)
535+
533536
def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
534537
r = tf.unique_with_counts(a)
535538
order = tf.argsort(r.y)

tests/test_backends.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,26 @@ def test_arg_cmp(backend):
481481
)
482482

483483

484+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
485+
def test_argsort(backend):
486+
# Test basic argsort functionality
487+
a = tc.array_to_tensor(np.array([3, 1, 2]), dtype="float32")
488+
result = tc.backend.argsort(a)
489+
expected = np.array([1, 2, 0]) # indices that would sort the array
490+
np.testing.assert_allclose(result, expected)
491+
492+
# Test argsort with 2D array, default axis=-1
493+
b = tc.array_to_tensor(np.array([[3, 1, 2], [4, 0, 1]]), dtype="float32")
494+
result = tc.backend.argsort(b)
495+
expected = np.array([[1, 2, 0], [1, 2, 0]])
496+
np.testing.assert_allclose(result, expected)
497+
498+
# Test argsort with 2D array, axis=0
499+
result = tc.backend.argsort(b, axis=0)
500+
expected = np.array([[0, 1, 1], [1, 0, 0]])
501+
np.testing.assert_allclose(result, expected)
502+
503+
484504
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
485505
def test_tree_map(backend):
486506
def f(a, b):

0 commit comments

Comments
 (0)