Skip to content

Commit f7d86b5

Browse files
committed
completed any/all implementation
1 parent e485b0c commit f7d86b5

File tree

2 files changed

+184
-31
lines changed

2 files changed

+184
-31
lines changed

code/numpy/numerical/numerical.c

Lines changed: 175 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
#include "numerical.h"
2626

2727
enum NUMERICAL_FUNCTION_TYPE {
28-
NUMERICAL_MIN,
29-
NUMERICAL_MAX,
30-
NUMERICAL_ARGMIN,
28+
NUMERICAL_ALL,
29+
NUMERICAL_ANY,
3130
NUMERICAL_ARGMAX,
32-
NUMERICAL_SUM,
31+
NUMERICAL_ARGMIN,
32+
NUMERICAL_MAX,
3333
NUMERICAL_MEAN,
34+
NUMERICAL_MIN,
3435
NUMERICAL_STD,
36+
NUMERICAL_SUM,
3537
};
3638

3739
//| """Numerical and Statistical functions
@@ -61,36 +63,164 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
6163
}
6264
}
6365

64-
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
65-
static mp_obj_t numerical_all_any_iterable(mp_obj_t oin, bool anytype) {
66-
if(mp_obj_is_int(oin) || mp_obj_is_float(oin)) {
67-
return mp_obj_is_true(oin) ? mp_const_true : mp_const_false;
66+
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
67+
// TODO: replace numerical_reduce_axes with this function, wherever applicable
68+
int8_t ax = mp_obj_get_int(axis);
69+
if(ax < 0) ax += ndarray->ndim;
70+
if((ax < 0) || (ax > ndarray->ndim - 1)) {
71+
mp_raise_ValueError(translate("index out of range"));
6872
}
69-
mp_obj_iter_buf_t iter_buf;
70-
mp_obj_t item, iterable = mp_getiter(oin, &iter_buf);
71-
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
72-
if(!mp_obj_is_true(item) & !anytype) {
73-
return mp_const_false;
74-
} else if(mp_obj_is_true(item) & anytype) {
75-
return mp_const_true;
73+
shape_strides _shape_strides;
74+
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
75+
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
76+
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
77+
_shape_strides.shape = shape;
78+
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
79+
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
80+
_shape_strides.strides = strides;
81+
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
82+
_shape_strides.index = 0;
83+
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
84+
} else {
85+
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
86+
if(i > _shape_strides.index) {
87+
_shape_strides.shape[i] = ndarray->shape[i];
88+
_shape_strides.strides[i] = ndarray->strides[i];
89+
} else {
90+
_shape_strides.shape[i] = ndarray->shape[i-1];
91+
_shape_strides.strides[i] = ndarray->strides[i-1];
92+
}
7693
}
7794
}
78-
return anytype ? mp_const_false : mp_const_true;
79-
}
80-
81-
#if ULAB_NUMPY_HAS_ALL
82-
mp_obj_t numerical_all(mp_obj_t oin) {
83-
return numerical_all_any_iterable(oin, false);
95+
return _shape_strides;
8496
}
85-
MP_DEFINE_CONST_FUN_OBJ_1(numerical_all_obj, numerical_all);
86-
#endif
8797

88-
#if ULAB_NUMPY_HAS_ANY
89-
mp_obj_t numerical_any(mp_obj_t oin) {
90-
return numerical_all_any_iterable(oin, true);
98+
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
99+
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
100+
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
101+
if(MP_OBJ_IS_TYPE(oin, &ulab_ndarray_type)) {
102+
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
103+
uint8_t *array = (uint8_t *)ndarray->array;
104+
// always get a float, so that we don't have to resolve the dtype later
105+
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
106+
if(axis == mp_const_none) {
107+
#if ULAB_MAX_DIMS > 3
108+
size_t i = 0;
109+
do {
110+
#endif
111+
#if ULAB_MAX_DIMS > 2
112+
size_t j = 0;
113+
do {
114+
#endif
115+
#if ULAB_MAX_DIMS > 1
116+
size_t k = 0;
117+
do {
118+
#endif
119+
size_t l = 0;
120+
do {
121+
mp_float_t value = func(array);
122+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
123+
// optype = NUMERICAL_ANY
124+
return mp_const_true;
125+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
126+
// optype == NUMERICAL_ALL
127+
return mp_const_false;
128+
}
129+
array += ndarray->strides[ULAB_MAX_DIMS - 1];
130+
l++;
131+
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
132+
#if ULAB_MAX_DIMS > 1
133+
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
134+
array += ndarray->strides[ULAB_MAX_DIMS - 2];
135+
k++;
136+
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
137+
#endif
138+
#if ULAB_MAX_DIMS > 2
139+
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
140+
array += ndarray->strides[ULAB_MAX_DIMS - 3];
141+
j++;
142+
} while(j < ndarray->shape[ULAB_MAX_DIMS - 3]);
143+
#endif
144+
#if ULAB_MAX_DIMS > 3
145+
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
146+
array += ndarray->strides[ULAB_MAX_DIMS - 4];
147+
i++;
148+
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
149+
#endif
150+
} else {
151+
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
152+
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
153+
uint8_t *rarray = (uint8_t *)results->array;
154+
if(optype == NUMERICAL_ALL) {
155+
memset(rarray, 1, results->len);
156+
}
157+
#if ULAB_MAX_DIMS > 3
158+
size_t i = 0;
159+
do {
160+
#endif
161+
#if ULAB_MAX_DIMS > 2
162+
size_t j = 0;
163+
do {
164+
#endif
165+
#if ULAB_MAX_DIMS > 1
166+
size_t k = 0;
167+
do {
168+
#endif
169+
size_t l = 0;
170+
do {
171+
mp_float_t value = func(array);
172+
if((value != MICROPY_FLOAT_CONST(0.0)) & !anytype) {
173+
// optype == NUMERICAL_ANY
174+
*rarray = 1;
175+
// since we are breaking out of the loop, move the pointer forward
176+
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
177+
break;
178+
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
179+
// optype == NUMERICAL_ALL
180+
*rarray = 0;
181+
// since we are breaking out of the loop, move the pointer forward
182+
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
183+
break;
184+
}
185+
array += ndarray->strides[_shape_strides.index];
186+
l++;
187+
} while(l < ndarray->shape[_shape_strides.index]);
188+
#if ULAB_MAX_DIMS > 1
189+
rarray++;
190+
array -= ndarray->strides[_shape_strides.index] * ndarray->shape[_shape_strides.index];
191+
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
192+
k++;
193+
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
194+
#endif
195+
#if ULAB_MAX_DIMS > 2
196+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
197+
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
198+
j++;
199+
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
200+
#endif
201+
#if ULAB_MAX_DIMS > 3
202+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
203+
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
204+
i++;
205+
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3])
206+
#endif
207+
return results;
208+
}
209+
} else if(mp_obj_is_int(oin) || mp_obj_is_float(oin)) {
210+
return mp_obj_is_true(oin) ? mp_const_true : mp_const_false;
211+
} else {
212+
mp_obj_iter_buf_t iter_buf;
213+
mp_obj_t item, iterable = mp_getiter(oin, &iter_buf);
214+
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
215+
if(!mp_obj_is_true(item) & !anytype) {
216+
return mp_const_false;
217+
} else if(mp_obj_is_true(item) & anytype) {
218+
return mp_const_true;
219+
}
220+
}
221+
}
222+
return anytype ? mp_const_true : mp_const_false;
91223
}
92-
MP_DEFINE_CONST_FUN_OBJ_1(numerical_any_obj, numerical_any);
93-
#endif
94224
#endif
95225

96226
#if ULAB_NUMPY_HAS_SUM | ULAB_NUMPY_HAS_MEAN | ULAB_NUMPY_HAS_STD
@@ -435,6 +565,9 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
435565
mp_raise_TypeError(translate("axis must be None, or an integer"));
436566
}
437567

568+
if((optype == NUMERICAL_ALL) || (optype == NUMERICAL_ANY)) {
569+
return numerical_all_any(oin, axis, optype);
570+
}
438571
if(MP_OBJ_IS_TYPE(oin, &mp_type_tuple) || MP_OBJ_IS_TYPE(oin, &mp_type_list) ||
439572
MP_OBJ_IS_TYPE(oin, &mp_type_range)) {
440573
switch(optype) {
@@ -525,6 +658,20 @@ static mp_obj_t numerical_sort_helper(mp_obj_t oin, mp_obj_t axis, uint8_t inpla
525658
}
526659
#endif /* ULAB_NUMERICAL_HAS_SORT | NDARRAY_HAS_SORT */
527660

661+
#if ULAB_NUMPY_HAS_ALL
662+
mp_obj_t numerical_all(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
663+
return numerical_function(n_args, pos_args, kw_args, NUMERICAL_ALL);
664+
}
665+
MP_DEFINE_CONST_FUN_OBJ_KW(numerical_all_obj, 1, numerical_all);
666+
#endif
667+
668+
#if ULAB_NUMPY_HAS_ANY
669+
mp_obj_t numerical_any(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
670+
return numerical_function(n_args, pos_args, kw_args, NUMERICAL_ANY);
671+
}
672+
MP_DEFINE_CONST_FUN_OBJ_KW(numerical_any_obj, 1, numerical_any);
673+
#endif
674+
528675
#if ULAB_NUMPY_HAS_ARGMINMAX
529676
//| def argmax(array: _ArrayLike, *, axis: Optional[int] = None) -> int:
530677
//| """Return the index of the maximum element of the 1D array"""

code/numpy/numerical/numerical.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
#include "../../ndarray.h"
1717

1818
// TODO: implement cumsum
19-
//mp_obj_t numerical_cumsum(size_t , const mp_obj_t *, mp_map_t *);
19+
20+
typedef struct {
21+
uint8_t index;
22+
int8_t axis;
23+
size_t *shape;
24+
int32_t *strides;
25+
} shape_strides;
2026

2127
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
2228
({\
@@ -568,8 +574,8 @@
568574

569575
#endif
570576

571-
MP_DECLARE_CONST_FUN_OBJ_1(numerical_all_obj);
572-
MP_DECLARE_CONST_FUN_OBJ_1(numerical_any_obj);
577+
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_all_obj);
578+
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_any_obj);
573579
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argmax_obj);
574580
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argmin_obj);
575581
MP_DECLARE_CONST_FUN_OBJ_KW(numerical_argsort_obj);

0 commit comments

Comments
 (0)