@@ -262,72 +262,66 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
262
262
return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(S / (ndarray -> len - ddof )));
263
263
}
264
264
} else {
265
- int8_t ax = mp_obj_get_int (axis );
266
- if (ax < 0 ) ax += ndarray -> ndim ;
267
- if ((ax < 0 ) || (ax > ndarray -> ndim - 1 )) {
268
- mp_raise_ValueError (translate ("index out of range" ));
269
- }
270
- numerical_reduce_axes (ndarray , ax , shape , strides );
271
- uint8_t index = ULAB_MAX_DIMS - ndarray -> ndim + ax ;
265
+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
272
266
ndarray_obj_t * results = NULL ;
273
267
uint8_t * rarray = NULL ;
274
268
275
269
if (optype == NUMERICAL_SUM ) {
276
- results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), shape , ndarray -> dtype );
270
+ results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides . shape , ndarray -> dtype );
277
271
rarray = (uint8_t * )results -> array ;
278
272
// TODO: numpy promotes the output to the highest integer type
279
273
if (ndarray -> dtype == NDARRAY_UINT8 ) {
280
- RUN_SUM (ndarray , uint8_t , array , results , rarray , shape , strides , index );
274
+ RUN_SUM (uint8_t , array , results , rarray , _shape_strides );
281
275
} else if (ndarray -> dtype == NDARRAY_INT8 ) {
282
- RUN_SUM (ndarray , int8_t , array , results , rarray , shape , strides , index );
276
+ RUN_SUM (int8_t , array , results , rarray , _shape_strides );
283
277
} else if (ndarray -> dtype == NDARRAY_UINT16 ) {
284
- RUN_SUM (ndarray , uint16_t , array , results , rarray , shape , strides , index );
278
+ RUN_SUM (uint16_t , array , results , rarray , _shape_strides );
285
279
} else if (ndarray -> dtype == NDARRAY_INT16 ) {
286
- RUN_SUM (ndarray , int16_t , array , results , rarray , shape , strides , index );
280
+ RUN_SUM (int16_t , array , results , rarray , _shape_strides );
287
281
} else {
288
282
// for floats, the sum might be inaccurate with the naive summation
289
283
// call mean, and multiply with the number of samples
290
284
mp_float_t * r = (mp_float_t * )results -> array ;
291
- RUN_MEAN (ndarray , mp_float_t , array , results , r , shape , strides , index );
292
- mp_float_t norm = (mp_float_t )ndarray -> shape [index ];
285
+ RUN_MEAN (mp_float_t , array , results , r , _shape_strides );
286
+ mp_float_t norm = (mp_float_t )_shape_strides . shape [0 ];
293
287
// re-wind the array here
294
288
r = (mp_float_t * )results -> array ;
295
289
for (size_t i = 0 ; i < results -> len ; i ++ ) {
296
290
* r ++ *= norm ;
297
291
}
298
292
}
299
293
} else if (optype == NUMERICAL_MEAN ) {
300
- results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), shape , NDARRAY_FLOAT );
294
+ results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides . shape , NDARRAY_FLOAT );
301
295
mp_float_t * r = (mp_float_t * )results -> array ;
302
296
if (ndarray -> dtype == NDARRAY_UINT8 ) {
303
- RUN_MEAN (ndarray , uint8_t , array , results , r , shape , strides , index );
297
+ RUN_MEAN (uint8_t , array , results , r , _shape_strides );
304
298
} else if (ndarray -> dtype == NDARRAY_INT8 ) {
305
- RUN_MEAN (ndarray , int8_t , array , results , r , shape , strides , index );
299
+ RUN_MEAN (int8_t , array , results , r , _shape_strides );
306
300
} else if (ndarray -> dtype == NDARRAY_UINT16 ) {
307
- RUN_MEAN (ndarray , uint16_t , array , results , r , shape , strides , index );
301
+ RUN_MEAN (uint16_t , array , results , r , _shape_strides );
308
302
} else if (ndarray -> dtype == NDARRAY_INT16 ) {
309
- RUN_MEAN (ndarray , int16_t , array , results , r , shape , strides , index );
303
+ RUN_MEAN (int16_t , array , results , r , _shape_strides );
310
304
} else {
311
- RUN_MEAN (ndarray , mp_float_t , array , results , r , shape , strides , index );
305
+ RUN_MEAN (mp_float_t , array , results , r , _shape_strides );
312
306
}
313
307
} else { // this case is certainly the standard deviation
314
308
results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), shape , NDARRAY_FLOAT );
315
309
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
316
- if (ndarray -> shape [index ] <= ddof ) {
310
+ if (_shape_strides . shape [0 ] <= ddof ) {
317
311
return MP_OBJ_FROM_PTR (results );
318
312
}
319
- mp_float_t div = (mp_float_t )(ndarray -> shape [index ] - ddof );
313
+ mp_float_t div = (mp_float_t )(_shape_strides . shape [0 ] - ddof );
320
314
mp_float_t * r = (mp_float_t * )results -> array ;
321
315
if (ndarray -> dtype == NDARRAY_UINT8 ) {
322
- RUN_STD (ndarray , uint8_t , array , results , r , shape , strides , index , div );
316
+ RUN_STD (uint8_t , array , results , r , _shape_strides , div );
323
317
} else if (ndarray -> dtype == NDARRAY_INT8 ) {
324
- RUN_STD (ndarray , int8_t , array , results , r , shape , strides , index , div );
318
+ RUN_STD (int8_t , array , results , r , _shape_strides , div );
325
319
} else if (ndarray -> dtype == NDARRAY_UINT16 ) {
326
- RUN_STD (ndarray , uint16_t , array , results , r , shape , strides , index , div );
320
+ RUN_STD (uint16_t , array , results , r , _shape_strides , div );
327
321
} else if (ndarray -> dtype == NDARRAY_INT16 ) {
328
- RUN_STD (ndarray , int16_t , array , results , r , shape , strides , index , div );
322
+ RUN_STD (int16_t , array , results , r , _shape_strides , div );
329
323
} else {
330
- RUN_STD (ndarray , mp_float_t , array , results , r , shape , strides , index , div );
324
+ RUN_STD (mp_float_t , array , results , r , _shape_strides , div );
331
325
}
332
326
}
333
327
if (ndarray -> ndim == 1 ) { // return a scalar here
0 commit comments