Skip to content

Commit 2b74236

Browse files
authored
Take (#688)
* add numpy.take
1 parent c0b3262 commit 2b74236

File tree

14 files changed

+544
-19
lines changed

14 files changed

+544
-19
lines changed

code/numpy/create.c

Lines changed: 230 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* The MIT License (MIT)
77
*
88
* Copyright (c) 2020 Jeff Epler for Adafruit Industries
9-
* 2019-2021 Zoltán Vörös
9+
* 2019-2024 Zoltán Vörös
1010
* 2020 Taku Fukada
1111
*/
1212

@@ -776,6 +776,235 @@ mp_obj_t create_ones(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
776776
MP_DEFINE_CONST_FUN_OBJ_KW(create_ones_obj, 0, create_ones);
777777
#endif
778778

779+
#if ULAB_NUMPY_HAS_TAKE
780+
//| def take(
781+
//| a: ulab.numpy.ndarray,
782+
//| indices: _ArrayLike,
783+
//| axis: Optional[int] = None,
784+
//| out: Optional[ulab.numpy.ndarray] = None,
785+
//| mode: Optional[str] = None) -> ulab.numpy.ndarray:
786+
//| """
787+
//| .. param: a
788+
//| The source array.
789+
//| .. param: indices
790+
//| The indices of the values to extract.
791+
//| .. param: axis
792+
//| The axis over which to select values. By default, the flattened input array is used.
793+
//| .. param: out
794+
//| If provided, the result will be placed in this array. It should be of the appropriate shape and dtype.
795+
//| .. param: mode
796+
//| Specifies how out-of-bounds indices will behave.
797+
//| - `raise`: raise an error (default)
798+
//| - `wrap`: wrap around
799+
//| - `clip`: clip to the range
800+
//| `clip` mode means that all indices that are too large are replaced by the
801+
//| index that addresses the last element along that axis. Note that this disables
802+
//| indexing with negative numbers.
803+
//|
804+
//| Return a new array."""
805+
//| ...
806+
//|
807+
808+
// Policies create_take applies to the entries of `indices` that do not
// address an existing element along the selected axis.
enum CREATE_TAKE_MODE {
    CREATE_TAKE_RAISE = 0, // reject out-of-range indices with an error (default)
    CREATE_TAKE_WRAP = 1,  // reduce indices modulo the axis length
    CREATE_TAKE_CLIP = 2,  // replace too-large indices by the last valid index
};
813+
814+
mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
815+
static const mp_arg_t allowed_args[] = {
816+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
817+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } },
818+
{ MP_QSTR_axis, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
819+
{ MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
820+
{ MP_QSTR_mode, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
821+
};
822+
823+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
824+
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
825+
826+
if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) {
827+
mp_raise_TypeError(MP_ERROR_TEXT("input is not an array"));
828+
}
829+
830+
ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj);
831+
int8_t axis = 0;
832+
int8_t axis_index = 0;
833+
int32_t axis_len;
834+
uint8_t mode = CREATE_TAKE_RAISE;
835+
uint8_t ndim;
836+
837+
// axis keyword argument
838+
if(args[2].u_obj == mp_const_none) {
839+
// work with the flattened array
840+
axis_len = a->len;
841+
ndim = 1;
842+
} else { // i.e., axis is an integer
843+
// TODO: this pops up at quite a few places, write it as a function
844+
axis = mp_obj_get_int(args[2].u_obj);
845+
ndim = a->ndim;
846+
if(axis < 0) axis += a->ndim;
847+
if((axis < 0) || (axis > a->ndim - 1)) {
848+
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
849+
}
850+
axis_index = ULAB_MAX_DIMS - a->ndim + axis;
851+
axis_len = (int32_t)a->shape[axis_index];
852+
}
853+
854+
size_t _len;
855+
// mode keyword argument
856+
if(mp_obj_is_str(args[4].u_obj)) {
857+
const char *_mode = mp_obj_str_get_data(args[4].u_obj, &_len);
858+
if(memcmp(_mode, "raise", 5) == 0) {
859+
mode = CREATE_TAKE_RAISE;
860+
} else if(memcmp(_mode, "wrap", 4) == 0) {
861+
mode = CREATE_TAKE_WRAP;
862+
} else if(memcmp(_mode, "clip", 4) == 0) {
863+
mode = CREATE_TAKE_CLIP;
864+
} else {
865+
mp_raise_ValueError(MP_ERROR_TEXT("mode should be raise, wrap or clip"));
866+
}
867+
}
868+
869+
size_t indices_len = (size_t)mp_obj_get_int(mp_obj_len_maybe(args[1].u_obj));
870+
871+
size_t *indices = m_new(size_t, indices_len);
872+
873+
mp_obj_iter_buf_t buf;
874+
mp_obj_t item, iterable = mp_getiter(args[1].u_obj, &buf);
875+
876+
size_t z = 0;
877+
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
878+
int32_t index = mp_obj_get_int(item);
879+
if(mode == CREATE_TAKE_RAISE) {
880+
if(index < 0) {
881+
index += axis_len;
882+
}
883+
if((index < 0) || (index > axis_len - 1)) {
884+
m_del(size_t, indices, indices_len);
885+
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
886+
}
887+
} else if(mode == CREATE_TAKE_WRAP) {
888+
index %= axis_len;
889+
} else { // mode == CREATE_TAKE_CLIP
890+
if(index < 0) {
891+
m_del(size_t, indices, indices_len);
892+
mp_raise_ValueError(MP_ERROR_TEXT("index must not be negative"));
893+
}
894+
if(index > axis_len - 1) {
895+
index = axis_len - 1;
896+
}
897+
}
898+
indices[z++] = (size_t)index;
899+
}
900+
901+
size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
902+
if(args[2].u_obj == mp_const_none) { // flattened array
903+
shape[ULAB_MAX_DIMS - 1] = indices_len;
904+
} else {
905+
for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) {
906+
shape[i] = a->shape[i];
907+
if(i == axis_index) {
908+
shape[i] = indices_len;
909+
}
910+
}
911+
}
912+
913+
ndarray_obj_t *out = NULL;
914+
if(args[3].u_obj == mp_const_none) {
915+
// no output was supplied
916+
out = ndarray_new_dense_ndarray(ndim, shape, a->dtype);
917+
} else {
918+
// TODO: deal with last argument being false!
919+
out = ulab_tools_inspect_out(args[3].u_obj, a->dtype, ndim, shape, true);
920+
}
921+
922+
#if ULAB_MAX_DIMS > 1 // we can save the hassle, if there is only one possible dimension
923+
if((args[2].u_obj == mp_const_none) || (a->ndim == 1)) { // flattened array
924+
#endif
925+
uint8_t *out_array = (uint8_t *)out->array;
926+
for(size_t x = 0; x < indices_len; x++) {
927+
uint8_t *a_array = (uint8_t *)a->array;
928+
size_t remainder = indices[x];
929+
uint8_t q = ULAB_MAX_DIMS - 1;
930+
do {
931+
size_t div = (remainder / a->shape[q]);
932+
a_array += remainder * a->strides[q];
933+
remainder -= div * a->shape[q];
934+
q--;
935+
} while(q > ULAB_MAX_DIMS - a->ndim);
936+
// NOTE: for floats and complexes, this might be
937+
// better with memcpy(out_array, a_array, a->itemsize)
938+
for(uint8_t p = 0; p < a->itemsize; p++) {
939+
out_array[p] = a_array[p];
940+
}
941+
out_array += a->itemsize;
942+
}
943+
#if ULAB_MAX_DIMS > 1
944+
} else {
945+
// move the axis shape/stride to the leftmost position:
946+
SWAP(size_t, a->shape[0], a->shape[axis_index]);
947+
SWAP(size_t, out->shape[0], out->shape[axis_index]);
948+
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
949+
SWAP(int32_t, out->strides[0], out->strides[axis_index]);
950+
951+
for(size_t x = 0; x < indices_len; x++) {
952+
uint8_t *a_array = (uint8_t *)a->array;
953+
uint8_t *out_array = (uint8_t *)out->array;
954+
a_array += indices[x] * a->strides[0];
955+
out_array += x * out->strides[0];
956+
957+
#if ULAB_MAX_DIMS > 3
958+
size_t j = 0;
959+
do {
960+
#endif
961+
#if ULAB_MAX_DIMS > 2
962+
size_t k = 0;
963+
do {
964+
#endif
965+
size_t l = 0;
966+
do {
967+
// NOTE: for floats and complexes, this might be
968+
// better with memcpy(out_array, a_array, a->itemsize)
969+
for(uint8_t p = 0; p < a->itemsize; p++) {
970+
out_array[p] = a_array[p];
971+
}
972+
out_array += out->strides[ULAB_MAX_DIMS - 1];
973+
a_array += a->strides[ULAB_MAX_DIMS - 1];
974+
l++;
975+
} while(l < a->shape[ULAB_MAX_DIMS - 1]);
976+
#if ULAB_MAX_DIMS > 2
977+
out_array -= out->strides[ULAB_MAX_DIMS - 1] * out->shape[ULAB_MAX_DIMS - 1];
978+
out_array += out->strides[ULAB_MAX_DIMS - 2];
979+
a_array -= a->strides[ULAB_MAX_DIMS - 1] * a->shape[ULAB_MAX_DIMS - 1];
980+
a_array += a->strides[ULAB_MAX_DIMS - 2];
981+
k++;
982+
} while(k < a->shape[ULAB_MAX_DIMS - 2]);
983+
#endif
984+
#if ULAB_MAX_DIMS > 3
985+
out_array -= out->strides[ULAB_MAX_DIMS - 2] * out->shape[ULAB_MAX_DIMS - 2];
986+
out_array += out->strides[ULAB_MAX_DIMS - 3];
987+
a_array -= a->strides[ULAB_MAX_DIMS - 2] * a->shape[ULAB_MAX_DIMS - 2];
988+
a_array += a->strides[ULAB_MAX_DIMS - 3];
989+
j++;
990+
} while(j < a->shape[ULAB_MAX_DIMS - 3]);
991+
#endif
992+
}
993+
994+
// revert back to the original order
995+
SWAP(size_t, a->shape[0], a->shape[axis_index]);
996+
SWAP(size_t, out->shape[0], out->shape[axis_index]);
997+
SWAP(int32_t, a->strides[0], a->strides[axis_index]);
998+
SWAP(int32_t, out->strides[0], out->strides[axis_index]);
999+
}
1000+
#endif /* ULAB_MAX_DIMS > 1 */
1001+
m_del(size_t, indices, indices_len);
1002+
return MP_OBJ_FROM_PTR(out);
1003+
}
1004+
1005+
MP_DEFINE_CONST_FUN_OBJ_KW(create_take_obj, 2, create_take);
1006+
#endif /* ULAB_NUMPY_HAS_TAKE */
1007+
7791008
#if ULAB_NUMPY_HAS_ZEROS
7801009
//| def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: _DType = ulab.numpy.float) -> ulab.numpy.ndarray:
7811010
//| """

code/numpy/create.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ mp_obj_t create_ones(size_t , const mp_obj_t *, mp_map_t *);
6262
MP_DECLARE_CONST_FUN_OBJ_KW(create_ones_obj);
6363
#endif
6464

65+
#if ULAB_NUMPY_HAS_TAKE
66+
mp_obj_t create_take(size_t , const mp_obj_t *, mp_map_t *);
67+
MP_DECLARE_CONST_FUN_OBJ_KW(create_take_obj);
68+
#endif
69+
6570
#if ULAB_NUMPY_HAS_ZEROS
6671
mp_obj_t create_zeros(size_t , const mp_obj_t *, mp_map_t *);
6772
MP_DECLARE_CONST_FUN_OBJ_KW(create_zeros_obj);

code/numpy/numpy.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,9 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
291291
#if ULAB_NUMPY_HAS_SUM
292292
{ MP_ROM_QSTR(MP_QSTR_sum), MP_ROM_PTR(&numerical_sum_obj) },
293293
#endif
294+
#if ULAB_NUMPY_HAS_TAKE
295+
{ MP_ROM_QSTR(MP_QSTR_take), MP_ROM_PTR(&create_take_obj) },
296+
#endif
294297
// functions of the poly sub-module
295298
#if ULAB_NUMPY_HAS_POLYFIT
296299
{ MP_ROM_QSTR(MP_QSTR_polyfit), MP_ROM_PTR(&poly_polyfit_obj) },

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.5.5
36+
#define ULAB_VERSION 6.6.0
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

code/ulab.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,10 @@
559559
#define ULAB_NUMPY_HAS_SUM (1)
560560
#endif
561561

562+
#ifndef ULAB_NUMPY_HAS_TAKE
563+
#define ULAB_NUMPY_HAS_TAKE (1)
564+
#endif
565+
562566
#ifndef ULAB_NUMPY_HAS_TRACE
563567
#define ULAB_NUMPY_HAS_TRACE (1)
564568
#endif

code/ulab_tools.c

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,31 @@ bool ulab_tools_mp_obj_is_scalar(mp_obj_t obj) {
274274
}
275275
#endif
276276
}
277+
278+
// Validates a user-supplied `out` object against the expected result layout.
//
// `out` must be an ndarray whose dtype, number of dimensions, and every one
// of the ULAB_MAX_DIMS shape entries agree with `dtype`, `ndim` and `shape`.
// If `dense_only` is true, the array must, in addition, pass
// ndarray_is_dense. The validated object is returned as an ndarray pointer.
//
// Raises TypeError if `out` is not an ndarray, and ValueError for any of the
// other mismatches.
ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t out, uint8_t dtype, uint8_t ndim, size_t *shape, bool dense_only) {
    if(!mp_obj_is_type(out, &ulab_ndarray_type)) {
        mp_raise_TypeError(MP_ERROR_TEXT("out has wrong type"));
    }

    ndarray_obj_t *target = MP_OBJ_TO_PTR(out);

    if(target->dtype != dtype) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dtype"));
    }
    if(target->ndim != ndim) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong dimension"));
    }

    // all shape slots are compared, including those in front of the
    // leading dimension
    bool same_shape = true;
    for(uint8_t dim = 0; same_shape && (dim < ULAB_MAX_DIMS); dim++) {
        same_shape = (target->shape[dim] == shape[dim]);
    }
    if(!same_shape) {
        mp_raise_ValueError(MP_ERROR_TEXT("out array has wrong shape"));
    }

    if(dense_only && !ndarray_is_dense(target)) {
        mp_raise_ValueError(MP_ERROR_TEXT("output array must be contiguous"));
    }

    return target;
}

code/ulab_tools.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ void ulab_rescale_float_strides(int32_t *);
4444

4545
bool ulab_tools_mp_obj_is_scalar(mp_obj_t );
4646

47-
#if ULAB_NUMPY_HAS_RANDOM_MODULE
48-
ndarray_obj_t *ulab_tools_create_out(mp_obj_tuple_t , mp_obj_t , uint8_t , bool );
49-
#endif
47+
ndarray_obj_t *ulab_tools_inspect_out(mp_obj_t , uint8_t , uint8_t , size_t *, bool );
48+
5049
#endif

docs/manual/source/conf.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
author = 'Zoltán Vörös'
2828

2929
# The full version, including alpha/beta/rc tags
30-
release = '6.5.5'
31-
30+
release = '6.6.0'
3231

3332
# -- General configuration ---------------------------------------------------
3433

0 commit comments

Comments
 (0)