Skip to content

Commit 1a440d7

Browse files
authored
Fix sort when dtype is uint16 (#563)
Prior to this fix the code was using the mp_float_t data type for uint16 and producing incorrect sort results. Signed-off-by: Damien George <[email protected]> Signed-off-by: Damien George <[email protected]>
1 parent 25a825e commit 1a440d7

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

code/numpy/numerical.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ static mp_obj_t numerical_sort_helper(mp_obj_t oin, mp_obj_t axis, uint8_t inpla
650650
if(ndarray->shape[ax]) {
651651
if((ndarray->dtype == NDARRAY_UINT8) || (ndarray->dtype == NDARRAY_INT8)) {
652652
HEAPSORT(ndarray, uint8_t, array, shape, strides, ax, increment, ndarray->shape[ax]);
653-
} else if((ndarray->dtype == NDARRAY_INT16) || (ndarray->dtype == NDARRAY_INT16)) {
653+
} else if((ndarray->dtype == NDARRAY_UINT16) || (ndarray->dtype == NDARRAY_INT16)) {
654654
HEAPSORT(ndarray, uint16_t, array, shape, strides, ax, increment, ndarray->shape[ax]);
655655
} else {
656656
HEAPSORT(ndarray, mp_float_t, array, shape, strides, ax, increment, ndarray->shape[ax]);

tests/2d/numpy/sort.py.exp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ array([1, 2, 3, 4], dtype=int8)
1111

1212
array([], dtype=uint16)
1313
[]
14-
array([0, 0, 0, 0], dtype=uint16)
14+
array([1, 2, 3, 4], dtype=uint16)
1515
[1, 3, 2, 0]
1616

1717
array([], dtype=int16)

0 commit comments

Comments
 (0)