Skip to content

Commit 93f70d1

Browse files
committed
rationalised code in sum/mean/std
1 parent fd8a225 commit 93f70d1

File tree

2 files changed

+109
-115
lines changed

2 files changed

+109
-115
lines changed

code/numpy/numerical/numerical.c

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -262,72 +262,66 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
262262
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(S / (ndarray->len - ddof)));
263263
}
264264
} else {
265-
int8_t ax = mp_obj_get_int(axis);
266-
if(ax < 0) ax += ndarray->ndim;
267-
if((ax < 0) || (ax > ndarray->ndim - 1)) {
268-
mp_raise_ValueError(translate("index out of range"));
269-
}
270-
numerical_reduce_axes(ndarray, ax, shape, strides);
271-
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
265+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
272266
ndarray_obj_t *results = NULL;
273267
uint8_t *rarray = NULL;
274268

275269
if(optype == NUMERICAL_SUM) {
276-
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), shape, ndarray->dtype);
270+
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, ndarray->dtype);
277271
rarray = (uint8_t *)results->array;
278272
// TODO: numpy promotes the output to the highest integer type
279273
if(ndarray->dtype == NDARRAY_UINT8) {
280-
RUN_SUM(ndarray, uint8_t, array, results, rarray, shape, strides, index);
274+
RUN_SUM(uint8_t, array, results, rarray, _shape_strides);
281275
} else if(ndarray->dtype == NDARRAY_INT8) {
282-
RUN_SUM(ndarray, int8_t, array, results, rarray, shape, strides, index);
276+
RUN_SUM(int8_t, array, results, rarray, _shape_strides);
283277
} else if(ndarray->dtype == NDARRAY_UINT16) {
284-
RUN_SUM(ndarray, uint16_t, array, results, rarray, shape, strides, index);
278+
RUN_SUM(uint16_t, array, results, rarray, _shape_strides);
285279
} else if(ndarray->dtype == NDARRAY_INT16) {
286-
RUN_SUM(ndarray, int16_t, array, results, rarray, shape, strides, index);
280+
RUN_SUM(int16_t, array, results, rarray, _shape_strides);
287281
} else {
288282
// for floats, the sum might be inaccurate with the naive summation
289283
// call mean, and multiply with the number of samples
290284
mp_float_t *r = (mp_float_t *)results->array;
291-
RUN_MEAN(ndarray, mp_float_t, array, results, r, shape, strides, index);
292-
mp_float_t norm = (mp_float_t)ndarray->shape[index];
285+
RUN_MEAN(mp_float_t, array, results, r, _shape_strides);
286+
mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
293287
// re-wind the array here
294288
r = (mp_float_t *)results->array;
295289
for(size_t i=0; i < results->len; i++) {
296290
*r++ *= norm;
297291
}
298292
}
299293
} else if(optype == NUMERICAL_MEAN) {
300-
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), shape, NDARRAY_FLOAT);
294+
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
301295
mp_float_t *r = (mp_float_t *)results->array;
302296
if(ndarray->dtype == NDARRAY_UINT8) {
303-
RUN_MEAN(ndarray, uint8_t, array, results, r, shape, strides, index);
297+
RUN_MEAN(uint8_t, array, results, r, _shape_strides);
304298
} else if(ndarray->dtype == NDARRAY_INT8) {
305-
RUN_MEAN(ndarray, int8_t, array, results, r, shape, strides, index);
299+
RUN_MEAN(int8_t, array, results, r, _shape_strides);
306300
} else if(ndarray->dtype == NDARRAY_UINT16) {
307-
RUN_MEAN(ndarray, uint16_t, array, results, r, shape, strides, index);
301+
RUN_MEAN(uint16_t, array, results, r, _shape_strides);
308302
} else if(ndarray->dtype == NDARRAY_INT16) {
309-
RUN_MEAN(ndarray, int16_t, array, results, r, shape, strides, index);
303+
RUN_MEAN(int16_t, array, results, r, _shape_strides);
310304
} else {
311-
RUN_MEAN(ndarray, mp_float_t, array, results, r, shape, strides, index);
305+
RUN_MEAN(mp_float_t, array, results, r, _shape_strides);
312306
}
313307
} else { // this case is certainly the standard deviation
314308
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), shape, NDARRAY_FLOAT);
315309
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
316-
if(ndarray->shape[index] <= ddof) {
310+
if(_shape_strides.shape[0] <= ddof) {
317311
return MP_OBJ_FROM_PTR(results);
318312
}
319-
mp_float_t div = (mp_float_t)(ndarray->shape[index] - ddof);
313+
mp_float_t div = (mp_float_t)(_shape_strides.shape[0] - ddof);
320314
mp_float_t *r = (mp_float_t *)results->array;
321315
if(ndarray->dtype == NDARRAY_UINT8) {
322-
RUN_STD(ndarray, uint8_t, array, results, r, shape, strides, index, div);
316+
RUN_STD(uint8_t, array, results, r, _shape_strides, div);
323317
} else if(ndarray->dtype == NDARRAY_INT8) {
324-
RUN_STD(ndarray, int8_t, array, results, r, shape, strides, index, div);
318+
RUN_STD(int8_t, array, results, r, _shape_strides, div);
325319
} else if(ndarray->dtype == NDARRAY_UINT16) {
326-
RUN_STD(ndarray, uint16_t, array, results, r, shape, strides, index, div);
320+
RUN_STD(uint16_t, array, results, r, _shape_strides, div);
327321
} else if(ndarray->dtype == NDARRAY_INT16) {
328-
RUN_STD(ndarray, int16_t, array, results, r, shape, strides, index, div);
322+
RUN_STD(int16_t, array, results, r, _shape_strides, div);
329323
} else {
330-
RUN_STD(ndarray, mp_float_t, array, results, r, shape, strides, index, div);
324+
RUN_STD(mp_float_t, array, results, r, _shape_strides, div);
331325
}
332326
}
333327
if(ndarray->ndim == 1) { // return a scalar here

0 commit comments

Comments
 (0)