Skip to content

Commit 0e0956b

Browse files
authored
Merge pull request #309 from v923z/any
any/all implementation
2 parents e4fa4cb + 4607f8e commit 0e0956b

File tree

10 files changed

+496
-47
lines changed

10 files changed

+496
-47
lines changed

code/numpy/numerical/numerical.c

Lines changed: 183 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
#include "numerical.h"
2626

2727
enum NUMERICAL_FUNCTION_TYPE {
28-
NUMERICAL_MIN,
29-
NUMERICAL_MAX,
30-
NUMERICAL_ARGMIN,
28+
NUMERICAL_ALL,
29+
NUMERICAL_ANY,
3130
NUMERICAL_ARGMAX,
32-
NUMERICAL_SUM,
31+
NUMERICAL_ARGMIN,
32+
NUMERICAL_MAX,
3333
NUMERICAL_MEAN,
34+
NUMERICAL_MIN,
3435
NUMERICAL_STD,
36+
NUMERICAL_SUM,
3537
};
3638

3739
//| """Numerical and Statistical functions
@@ -61,6 +63,166 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
6163
}
6264
}
6365

66+
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
67+
// TODO: replace numerical_reduce_axes with this function, wherever applicable
68+
int8_t ax = mp_obj_get_int(axis);
69+
if(ax < 0) ax += ndarray->ndim;
70+
if((ax < 0) || (ax > ndarray->ndim - 1)) {
71+
mp_raise_ValueError(translate("index out of range"));
72+
}
73+
shape_strides _shape_strides;
74+
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
75+
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
76+
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
77+
_shape_strides.shape = shape;
78+
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
79+
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
80+
_shape_strides.strides = strides;
81+
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
82+
_shape_strides.index = 0;
83+
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
84+
} else {
85+
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
86+
if(i > _shape_strides.index) {
87+
_shape_strides.shape[i] = ndarray->shape[i];
88+
_shape_strides.strides[i] = ndarray->strides[i];
89+
} else {
90+
_shape_strides.shape[i] = ndarray->shape[i-1];
91+
_shape_strides.strides[i] = ndarray->strides[i-1];
92+
}
93+
}
94+
}
95+
return _shape_strides;
96+
}
97+
98+
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
99+
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
100+
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
101+
if(MP_OBJ_IS_TYPE(oin, &ulab_ndarray_type)) {
102+
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
103+
uint8_t *array = (uint8_t *)ndarray->array;
104+
// always get a float, so that we don't have to resolve the dtype later
105+
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
106+
if(axis == mp_const_none) {
107+
#if ULAB_MAX_DIMS > 3
108+
size_t i = 0;
109+
do {
110+
#endif
111+
#if ULAB_MAX_DIMS > 2
112+
size_t j = 0;
113+
do {
114+
#endif
115+
#if ULAB_MAX_DIMS > 1
116+
size_t k = 0;
117+
do {
118+
#endif
119+
size_t l = 0;
120+
do {
121+
mp_float_t value = func(array);
122+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
123+
// optype = NUMERICAL_ANY
124+
return mp_const_true;
125+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
126+
// optype == NUMERICAL_ALL
127+
return mp_const_false;
128+
}
129+
array += ndarray->strides[ULAB_MAX_DIMS - 1];
130+
l++;
131+
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
132+
#if ULAB_MAX_DIMS > 1
133+
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
134+
array += ndarray->strides[ULAB_MAX_DIMS - 2];
135+
k++;
136+
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
137+
#endif
138+
#if ULAB_MAX_DIMS > 2
139+
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
140+
array += ndarray->strides[ULAB_MAX_DIMS - 3];
141+
j++;
142+
} while(j < ndarray->shape[ULAB_MAX_DIMS - 3]);
143+
#endif
144+
#if ULAB_MAX_DIMS > 3
145+
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
146+
array += ndarray->strides[ULAB_MAX_DIMS - 4];
147+
i++;
148+
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
149+
#endif
150+
} else {
151+
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
152+
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
153+
uint8_t *rarray = (uint8_t *)results->array;
154+
if(optype == NUMERICAL_ALL) {
155+
memset(rarray, 1, results->len);
156+
}
157+
#if ULAB_MAX_DIMS > 3
158+
size_t i = 0;
159+
do {
160+
#endif
161+
#if ULAB_MAX_DIMS > 2
162+
size_t j = 0;
163+
do {
164+
#endif
165+
#if ULAB_MAX_DIMS > 1
166+
size_t k = 0;
167+
do {
168+
#endif
169+
size_t l = 0;
170+
do {
171+
mp_float_t value = func(array);
172+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
173+
// optype == NUMERICAL_ANY
174+
*rarray = 1;
175+
// since we are breaking out of the loop, move the pointer forward
176+
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
177+
break;
178+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
179+
// optype == NUMERICAL_ALL
180+
*rarray = 0;
181+
// since we are breaking out of the loop, move the pointer forward
182+
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
183+
break;
184+
}
185+
array += ndarray->strides[_shape_strides.index];
186+
l++;
187+
} while(l < ndarray->shape[_shape_strides.index]);
188+
#if ULAB_MAX_DIMS > 1
189+
rarray++;
190+
array -= ndarray->strides[_shape_strides.index] * ndarray->shape[_shape_strides.index];
191+
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
192+
k++;
193+
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
194+
#endif
195+
#if ULAB_MAX_DIMS > 2
196+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
197+
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
198+
j++;
199+
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
200+
#endif
201+
#if ULAB_MAX_DIMS > 3
202+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
203+
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
204+
i++;
205+
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3])
206+
#endif
207+
return results;
208+
}
209+
} else if(mp_obj_is_int(oin) || mp_obj_is_float(oin)) {
210+
return mp_obj_is_true(oin) ? mp_const_true : mp_const_false;
211+
} else {
212+
mp_obj_iter_buf_t iter_buf;
213+
mp_obj_t item, iterable = mp_getiter(oin, &iter_buf);
214+
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
215+
if(!mp_obj_is_true(item) & !anytype) {
216+
return mp_const_false;
217+
} else if(mp_obj_is_true(item) & anytype) {
218+
return mp_const_true;
219+
}
220+
}
221+
}
222+
return anytype ? mp_const_true : mp_const_false;
223+
}
224+
#endif
225+
64226
#if ULAB_NUMPY_HAS_SUM | ULAB_NUMPY_HAS_MEAN | ULAB_NUMPY_HAS_STD
65227
static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, size_t ddof) {
66228
mp_float_t value = 0.0, M = 0.0, m = 0.0, S = 0.0, s = 0.0, sum = 0.0;
@@ -403,6 +565,9 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
403565
mp_raise_TypeError(translate("axis must be None, or an integer"));
404566
}
405567

568+
if((optype == NUMERICAL_ALL) || (optype == NUMERICAL_ANY)) {
569+
return numerical_all_any(oin, axis, optype);
570+
}
406571
if(MP_OBJ_IS_TYPE(oin, &mp_type_tuple) || MP_OBJ_IS_TYPE(oin, &mp_type_list) ||
407572
MP_OBJ_IS_TYPE(oin, &mp_type_range)) {
408573
switch(optype) {
@@ -493,6 +658,20 @@ static mp_obj_t numerical_sort_helper(mp_obj_t oin, mp_obj_t axis, uint8_t inpla
493658
}
494659
#endif /* ULAB_NUMERICAL_HAS_SORT | NDARRAY_HAS_SORT */
495660

661+
#if ULAB_NUMPY_HAS_ALL
662+
mp_obj_t numerical_all(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
663+
return numerical_function(n_args, pos_args, kw_args, NUMERICAL_ALL);
664+
}
665+
MP_DEFINE_CONST_FUN_OBJ_KW(numerical_all_obj, 1, numerical_all);
666+
#endif
667+
668+
#if ULAB_NUMPY_HAS_ANY
669+
mp_obj_t numerical_any(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
670+
return numerical_function(n_args, pos_args, kw_args, NUMERICAL_ANY);
671+
}
672+
MP_DEFINE_CONST_FUN_OBJ_KW(numerical_any_obj, 1, numerical_any);
673+
#endif
674+
496675
#if ULAB_NUMPY_HAS_ARGMINMAX
497676
//| def argmax(array: _ArrayLike, *, axis: Optional[int] = None) -> int:
498677
//| """Return the index of the maximum element of the 1D array"""

code/numpy/numerical/numerical.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
#include "../../ndarray.h"
1717

1818
// TODO: implement cumsum
19-
//mp_obj_t numerical_cumsum(size_t , const mp_obj_t *, mp_map_t *);
19+
20+
typedef struct {
21+
uint8_t index;
22+
int8_t axis;
23+
size_t *shape;
24+
int32_t *strides;
25+
} shape_strides;
2026

2127
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
2228
({\
@@ -568,6 +574,8 @@
568574

569575
#endif
570576

577+
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_all_obj);
578+
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_any_obj);
571579
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argmax_obj);
572580
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argmin_obj);
573581
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argsort_obj);

code/numpy/numpy.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ static const mp_rom_map_elem_t ulab_numpy_globals_table[] = {
149149
{ MP_OBJ_NEW_QSTR(MP_QSTR_convolve), (mp_obj_t)&filter_convolve_obj },
150150
#endif
151151
// functions of the numerical sub-module
152+
#if ULAB_NUMPY_HAS_ALL
153+
{ MP_OBJ_NEW_QSTR(MP_QSTR_all), (mp_obj_t)&numerical_all_obj },
154+
#endif
155+
#if ULAB_NUMPY_HAS_ANY
156+
{ MP_OBJ_NEW_QSTR(MP_QSTR_any), (mp_obj_t)&numerical_any_obj },
157+
#endif
152158
#if ULAB_NUMPY_HAS_ARGMINMAX
153159
{ MP_OBJ_NEW_QSTR(MP_QSTR_argmax), (mp_obj_t)&numerical_argmax_obj },
154160
{ MP_OBJ_NEW_QSTR(MP_QSTR_argmin), (mp_obj_t)&numerical_argmin_obj },

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#include "user/user.h"
3535

36-
#define ULAB_VERSION 2.2.0
36+
#define ULAB_VERSION 2.3.0
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)

code/ulab.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,14 @@
375375
#define ULAB_FFT_HAS_IFFT (1)
376376
#endif
377377

378+
#ifndef ULAB_NUMPY_HAS_ALL
379+
#define ULAB_NUMPY_HAS_ALL (1)
380+
#endif
381+
382+
#ifndef ULAB_NUMPY_HAS_ANY
383+
#define ULAB_NUMPY_HAS_ANY (1)
384+
#endif
385+
378386
#ifndef ULAB_NUMPY_HAS_ARGMINMAX
379387
#define ULAB_NUMPY_HAS_ARGMINMAX (1)
380388
#endif

docs/manual/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
author = 'Zoltán Vörös'
2828

2929
# The full version, including alpha/beta/rc tags
30-
release = '2.2.0'
30+
release = '2.3.0'
3131

3232

3333
# -- General configuration ---------------------------------------------------

0 commit comments

Comments
 (0)