@@ -481,6 +481,26 @@ def test_arg_cmp(backend):
481481 )
482482
483483
484+ @pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" )])
485+ def test_argsort (backend ):
486+ # Test basic argsort functionality
487+ a = tc .array_to_tensor (np .array ([3 , 1 , 2 ]), dtype = "float32" )
488+ result = tc .backend .argsort (a )
489+ expected = np .array ([1 , 2 , 0 ]) # indices that would sort the array
490+ np .testing .assert_allclose (result , expected )
491+
492+ # Test argsort with 2D array, default axis=-1
493+ b = tc .array_to_tensor (np .array ([[3 , 1 , 2 ], [4 , 0 , 1 ]]), dtype = "float32" )
494+ result = tc .backend .argsort (b )
495+ expected = np .array ([[1 , 2 , 0 ], [1 , 2 , 0 ]])
496+ np .testing .assert_allclose (result , expected )
497+
498+ # Test argsort with 2D array, axis=0
499+ result = tc .backend .argsort (b , axis = 0 )
500+ expected = np .array ([[0 , 1 , 1 ], [1 , 0 , 0 ]])
501+ np .testing .assert_allclose (result , expected )
502+
503+
484504@pytest .mark .parametrize ("backend" , [lf ("npb" ), lf ("tfb" ), lf ("jaxb" ), lf ("torchb" )])
485505def test_tree_map (backend ):
486506 def f (a , b ):
0 commit comments