|
6 | 6 | * The MIT License (MIT)
|
7 | 7 | *
|
8 | 8 | * Copyright (c) 2020 Jeff Epler for Adafruit Industries
|
9 |
| - * 2019-2021 Zoltán Vörös |
| 9 | + * 2019-2024 Zoltán Vörös |
10 | 10 | * 2020 Taku Fukada
|
11 | 11 | */
|
12 | 12 |
|
@@ -776,6 +776,235 @@ mp_obj_t create_ones(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args)
|
776 | 776 | MP_DEFINE_CONST_FUN_OBJ_KW(create_ones_obj, 0, create_ones);
|
777 | 777 | #endif
|
778 | 778 |
|
| 779 | +#if ULAB_NUMPY_HAS_TAKE |
| 780 | +//| def take( |
| 781 | +//| a: ulab.numpy.ndarray, |
| 782 | +//| indices: _ArrayLike, |
| 783 | +//| axis: Optional[int] = None, |
| 784 | +//| out: Optional[ulab.numpy.ndarray] = None, |
| 785 | +//| mode: Optional[str] = None) -> ulab.numpy.ndarray: |
| 786 | +//| """ |
| 787 | +//| .. param: a |
| 788 | +//| The source array. |
| 789 | +//| .. param: indices |
| 790 | +//| The indices of the values to extract. |
| 791 | +//| .. param: axis |
| 792 | +//| The axis over which to select values. By default, the flattened input array is used. |
| 793 | +//| .. param: out |
| 794 | +//| If provided, the result will be placed in this array. It should be of the appropriate shape and dtype. |
| 795 | +//| .. param: mode |
| 796 | +//| Specifies how out-of-bounds indices will behave. |
| 797 | +//| - `raise`: raise an error (default) |
| 798 | +//| - `wrap`: wrap around |
| 799 | +//| - `clip`: clip to the range |
| 800 | +//| `clip` mode means that all indices that are too large are replaced by the |
| 801 | +//| index that addresses the last element along that axis. Note that this disables |
| 802 | +//| indexing with negative numbers. |
| 803 | +//| |
| 804 | +//| Return a new array.""" |
| 805 | +//| ... |
| 806 | +//| |
| 807 | + |
| 808 | +enum CREATE_TAKE_MODE { |
| 809 | + CREATE_TAKE_RAISE, |
| 810 | + CREATE_TAKE_WRAP, |
| 811 | + CREATE_TAKE_CLIP, |
| 812 | +}; |
| 813 | + |
| 814 | +mp_obj_t create_take(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { |
| 815 | + static const mp_arg_t allowed_args[] = { |
| 816 | + { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } }, |
| 817 | + { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_obj = MP_OBJ_NULL } }, |
| 818 | + { MP_QSTR_axis, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, |
| 819 | + { MP_QSTR_out, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, |
| 820 | + { MP_QSTR_mode, MP_ARG_KW_ONLY | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } }, |
| 821 | + }; |
| 822 | + |
| 823 | + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; |
| 824 | + mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); |
| 825 | + |
| 826 | + if(!mp_obj_is_type(args[0].u_obj, &ulab_ndarray_type)) { |
| 827 | + mp_raise_TypeError(MP_ERROR_TEXT("input is not an array")); |
| 828 | + } |
| 829 | + |
| 830 | + ndarray_obj_t *a = MP_OBJ_TO_PTR(args[0].u_obj); |
| 831 | + int8_t axis = 0; |
| 832 | + int8_t axis_index = 0; |
| 833 | + int32_t axis_len; |
| 834 | + uint8_t mode = CREATE_TAKE_RAISE; |
| 835 | + uint8_t ndim; |
| 836 | + |
| 837 | + // axis keyword argument |
| 838 | + if(args[2].u_obj == mp_const_none) { |
| 839 | + // work with the flattened array |
| 840 | + axis_len = a->len; |
| 841 | + ndim = 1; |
| 842 | + } else { // i.e., axis is an integer |
| 843 | + // TODO: this pops up at quite a few places, write it as a function |
| 844 | + axis = mp_obj_get_int(args[2].u_obj); |
| 845 | + ndim = a->ndim; |
| 846 | + if(axis < 0) axis += a->ndim; |
| 847 | + if((axis < 0) || (axis > a->ndim - 1)) { |
| 848 | + mp_raise_ValueError(MP_ERROR_TEXT("index out of range")); |
| 849 | + } |
| 850 | + axis_index = ULAB_MAX_DIMS - a->ndim + axis; |
| 851 | + axis_len = (int32_t)a->shape[axis_index]; |
| 852 | + } |
| 853 | + |
| 854 | + size_t _len; |
| 855 | + // mode keyword argument |
| 856 | + if(mp_obj_is_str(args[4].u_obj)) { |
| 857 | + const char *_mode = mp_obj_str_get_data(args[4].u_obj, &_len); |
| 858 | + if(memcmp(_mode, "raise", 5) == 0) { |
| 859 | + mode = CREATE_TAKE_RAISE; |
| 860 | + } else if(memcmp(_mode, "wrap", 4) == 0) { |
| 861 | + mode = CREATE_TAKE_WRAP; |
| 862 | + } else if(memcmp(_mode, "clip", 4) == 0) { |
| 863 | + mode = CREATE_TAKE_CLIP; |
| 864 | + } else { |
| 865 | + mp_raise_ValueError(MP_ERROR_TEXT("mode should be raise, wrap or clip")); |
| 866 | + } |
| 867 | + } |
| 868 | + |
| 869 | + size_t indices_len = (size_t)mp_obj_get_int(mp_obj_len_maybe(args[1].u_obj)); |
| 870 | + |
| 871 | + size_t *indices = m_new(size_t, indices_len); |
| 872 | + |
| 873 | + mp_obj_iter_buf_t buf; |
| 874 | + mp_obj_t item, iterable = mp_getiter(args[1].u_obj, &buf); |
| 875 | + |
| 876 | + size_t z = 0; |
| 877 | + while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { |
| 878 | + int32_t index = mp_obj_get_int(item); |
| 879 | + if(mode == CREATE_TAKE_RAISE) { |
| 880 | + if(index < 0) { |
| 881 | + index += axis_len; |
| 882 | + } |
| 883 | + if((index < 0) || (index > axis_len - 1)) { |
| 884 | + m_del(size_t, indices, indices_len); |
| 885 | + mp_raise_ValueError(MP_ERROR_TEXT("index out of range")); |
| 886 | + } |
| 887 | + } else if(mode == CREATE_TAKE_WRAP) { |
| 888 | + index %= axis_len; |
| 889 | + } else { // mode == CREATE_TAKE_CLIP |
| 890 | + if(index < 0) { |
| 891 | + m_del(size_t, indices, indices_len); |
| 892 | + mp_raise_ValueError(MP_ERROR_TEXT("index must not be negative")); |
| 893 | + } |
| 894 | + if(index > axis_len - 1) { |
| 895 | + index = axis_len - 1; |
| 896 | + } |
| 897 | + } |
| 898 | + indices[z++] = (size_t)index; |
| 899 | + } |
| 900 | + |
| 901 | + size_t *shape = m_new0(size_t, ULAB_MAX_DIMS); |
| 902 | + if(args[2].u_obj == mp_const_none) { // flattened array |
| 903 | + shape[ULAB_MAX_DIMS - 1] = indices_len; |
| 904 | + } else { |
| 905 | + for(uint8_t i = 0; i < ULAB_MAX_DIMS; i++) { |
| 906 | + shape[i] = a->shape[i]; |
| 907 | + if(i == axis_index) { |
| 908 | + shape[i] = indices_len; |
| 909 | + } |
| 910 | + } |
| 911 | + } |
| 912 | + |
| 913 | + ndarray_obj_t *out = NULL; |
| 914 | + if(args[3].u_obj == mp_const_none) { |
| 915 | + // no output was supplied |
| 916 | + out = ndarray_new_dense_ndarray(ndim, shape, a->dtype); |
| 917 | + } else { |
| 918 | + // TODO: deal with last argument being false! |
| 919 | + out = ulab_tools_inspect_out(args[3].u_obj, a->dtype, ndim, shape, true); |
| 920 | + } |
| 921 | + |
| 922 | + #if ULAB_MAX_DIMS > 1 // we can save the hassle, if there is only one possible dimension |
| 923 | + if((args[2].u_obj == mp_const_none) || (a->ndim == 1)) { // flattened array |
| 924 | + #endif |
| 925 | + uint8_t *out_array = (uint8_t *)out->array; |
| 926 | + for(size_t x = 0; x < indices_len; x++) { |
| 927 | + uint8_t *a_array = (uint8_t *)a->array; |
| 928 | + size_t remainder = indices[x]; |
| 929 | + uint8_t q = ULAB_MAX_DIMS - 1; |
| 930 | + do { |
| 931 | + size_t div = (remainder / a->shape[q]); |
| 932 | + a_array += remainder * a->strides[q]; |
| 933 | + remainder -= div * a->shape[q]; |
| 934 | + q--; |
| 935 | + } while(q > ULAB_MAX_DIMS - a->ndim); |
| 936 | + // NOTE: for floats and complexes, this might be |
| 937 | + // better with memcpy(out_array, a_array, a->itemsize) |
| 938 | + for(uint8_t p = 0; p < a->itemsize; p++) { |
| 939 | + out_array[p] = a_array[p]; |
| 940 | + } |
| 941 | + out_array += a->itemsize; |
| 942 | + } |
| 943 | + #if ULAB_MAX_DIMS > 1 |
| 944 | + } else { |
| 945 | + // move the axis shape/stride to the leftmost position: |
| 946 | + SWAP(size_t, a->shape[0], a->shape[axis_index]); |
| 947 | + SWAP(size_t, out->shape[0], out->shape[axis_index]); |
| 948 | + SWAP(int32_t, a->strides[0], a->strides[axis_index]); |
| 949 | + SWAP(int32_t, out->strides[0], out->strides[axis_index]); |
| 950 | + |
| 951 | + for(size_t x = 0; x < indices_len; x++) { |
| 952 | + uint8_t *a_array = (uint8_t *)a->array; |
| 953 | + uint8_t *out_array = (uint8_t *)out->array; |
| 954 | + a_array += indices[x] * a->strides[0]; |
| 955 | + out_array += x * out->strides[0]; |
| 956 | + |
| 957 | + #if ULAB_MAX_DIMS > 3 |
| 958 | + size_t j = 0; |
| 959 | + do { |
| 960 | + #endif |
| 961 | + #if ULAB_MAX_DIMS > 2 |
| 962 | + size_t k = 0; |
| 963 | + do { |
| 964 | + #endif |
| 965 | + size_t l = 0; |
| 966 | + do { |
| 967 | + // NOTE: for floats and complexes, this might be |
| 968 | + // better with memcpy(out_array, a_array, a->itemsize) |
| 969 | + for(uint8_t p = 0; p < a->itemsize; p++) { |
| 970 | + out_array[p] = a_array[p]; |
| 971 | + } |
| 972 | + out_array += out->strides[ULAB_MAX_DIMS - 1]; |
| 973 | + a_array += a->strides[ULAB_MAX_DIMS - 1]; |
| 974 | + l++; |
| 975 | + } while(l < a->shape[ULAB_MAX_DIMS - 1]); |
| 976 | + #if ULAB_MAX_DIMS > 2 |
| 977 | + out_array -= out->strides[ULAB_MAX_DIMS - 1] * out->shape[ULAB_MAX_DIMS - 1]; |
| 978 | + out_array += out->strides[ULAB_MAX_DIMS - 2]; |
| 979 | + a_array -= a->strides[ULAB_MAX_DIMS - 1] * a->shape[ULAB_MAX_DIMS - 1]; |
| 980 | + a_array += a->strides[ULAB_MAX_DIMS - 2]; |
| 981 | + k++; |
| 982 | + } while(k < a->shape[ULAB_MAX_DIMS - 2]); |
| 983 | + #endif |
| 984 | + #if ULAB_MAX_DIMS > 3 |
| 985 | + out_array -= out->strides[ULAB_MAX_DIMS - 2] * out->shape[ULAB_MAX_DIMS - 2]; |
| 986 | + out_array += out->strides[ULAB_MAX_DIMS - 3]; |
| 987 | + a_array -= a->strides[ULAB_MAX_DIMS - 2] * a->shape[ULAB_MAX_DIMS - 2]; |
| 988 | + a_array += a->strides[ULAB_MAX_DIMS - 3]; |
| 989 | + j++; |
| 990 | + } while(j < a->shape[ULAB_MAX_DIMS - 3]); |
| 991 | + #endif |
| 992 | + } |
| 993 | + |
| 994 | + // revert back to the original order |
| 995 | + SWAP(size_t, a->shape[0], a->shape[axis_index]); |
| 996 | + SWAP(size_t, out->shape[0], out->shape[axis_index]); |
| 997 | + SWAP(int32_t, a->strides[0], a->strides[axis_index]); |
| 998 | + SWAP(int32_t, out->strides[0], out->strides[axis_index]); |
| 999 | + } |
| 1000 | + #endif /* ULAB_MAX_DIMS > 1 */ |
| 1001 | + m_del(size_t, indices, indices_len); |
| 1002 | + return MP_OBJ_FROM_PTR(out); |
| 1003 | +} |
| 1004 | + |
| 1005 | +MP_DEFINE_CONST_FUN_OBJ_KW(create_take_obj, 2, create_take); |
| 1006 | +#endif /* ULAB_NUMPY_HAS_TAKE */ |
| 1007 | + |
779 | 1008 | #if ULAB_NUMPY_HAS_ZEROS
|
780 | 1009 | //| def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: _DType = ulab.numpy.float) -> ulab.numpy.ndarray:
|
781 | 1010 | //| """
|
|
0 commit comments