Skip to content

Commit 0b20b30

Browse files
committed
combined macros for std and mean
1 parent 6499453 commit 0b20b30

File tree

2 files changed

+110
-54
lines changed

2 files changed

+110
-54
lines changed

code/numpy/numerical/numerical.c

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
259259
}
260260
} else {
261261
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
262+
// if(ndarray->ndim == 1) {
263+
// // if we have the single dimension, axis = 0 is equivalent to axis = None
264+
// // the call to tools_reduce_axes() has made sure that axis = 0
265+
// return numerical_sum_mean_std_ndarray(ndarray, mp_const_none, optype, ddof);
266+
// }
262267
ndarray_obj_t *results = NULL;
263268
uint8_t *rarray = NULL;
264269

@@ -278,46 +283,33 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
278283
// for floats, the sum might be inaccurate with the naive summation
279284
// call mean, and multiply with the number of samples
280285
mp_float_t *r = (mp_float_t *)results->array;
281-
RUN_MEAN(mp_float_t, array, results, r, _shape_strides);
286+
RUN_MEAN_STD(mp_float_t, array, r, _shape_strides, 0.0, 0);
282287
mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
283288
// re-wind the array here
284289
r = (mp_float_t *)results->array;
285290
for(size_t i=0; i < results->len; i++) {
286291
*r++ *= norm;
287292
}
288293
}
289-
} else if(optype == NUMERICAL_MEAN) {
290-
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
291-
mp_float_t *r = (mp_float_t *)results->array;
292-
if(ndarray->dtype == NDARRAY_UINT8) {
293-
RUN_MEAN(uint8_t, array, results, r, _shape_strides);
294-
} else if(ndarray->dtype == NDARRAY_INT8) {
295-
RUN_MEAN(int8_t, array, results, r, _shape_strides);
296-
} else if(ndarray->dtype == NDARRAY_UINT16) {
297-
RUN_MEAN(uint16_t, array, results, r, _shape_strides);
298-
} else if(ndarray->dtype == NDARRAY_INT16) {
299-
RUN_MEAN(int16_t, array, results, r, _shape_strides);
300-
} else {
301-
RUN_MEAN(mp_float_t, array, results, r, _shape_strides);
302-
}
303-
} else { // this case is certainly the standard deviation
294+
} else {
295+
bool isStd = optype == NUMERICAL_STD ? 1 : 0;
304296
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
305297
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
306-
if(_shape_strides.shape[0] <= ddof) {
298+
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
307299
return MP_OBJ_FROM_PTR(results);
308300
}
309-
mp_float_t div = (mp_float_t)(_shape_strides.shape[0] - ddof);
310-
mp_float_t *r = (mp_float_t *)results->array;
301+
mp_float_t div = optype == NUMERICAL_STD ? (mp_float_t)(_shape_strides.shape[0] - ddof) : 0.0;
302+
mp_float_t *rarray = (mp_float_t *)results->array;
311303
if(ndarray->dtype == NDARRAY_UINT8) {
312-
RUN_STD(uint8_t, array, results, r, _shape_strides, div);
304+
RUN_MEAN_STD(uint8_t, array, rarray, _shape_strides, div, isStd);
313305
} else if(ndarray->dtype == NDARRAY_INT8) {
314-
RUN_STD(int8_t, array, results, r, _shape_strides, div);
306+
RUN_MEAN_STD(int8_t, array, rarray, _shape_strides, div, isStd);
315307
} else if(ndarray->dtype == NDARRAY_UINT16) {
316-
RUN_STD(uint16_t, array, results, r, _shape_strides, div);
308+
RUN_MEAN_STD(uint16_t, array, rarray, _shape_strides, div, isStd);
317309
} else if(ndarray->dtype == NDARRAY_INT16) {
318-
RUN_STD(int16_t, array, results, r, _shape_strides, div);
310+
RUN_MEAN_STD(int16_t, array, rarray, _shape_strides, div, isStd);
319311
} else {
320-
RUN_STD(mp_float_t, array, results, r, _shape_strides, div);
312+
RUN_MEAN_STD(mp_float_t, array, rarray, _shape_strides, div, isStd);
321313
}
322314
}
323315
if(ndarray->ndim == 1) { // return a scalar here

code/numpy/numerical/numerical.h

Lines changed: 94 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -59,35 +59,46 @@
5959

6060
// The mean could be calculated by simply dividing the sum by
6161
// the number of elements, but that method is numerically unstable
62-
#define RUN_MEAN1(type, array, results, r, ss)\
62+
#define RUN_MEAN1(type, array, rarray, ss)\
6363
({\
64-
mp_float_t M, m;\
65-
M = m = (mp_float_t)(*(type *)(array));\
66-
for(size_t i=1; i < (ss).shape[0]; i++) {\
67-
(array) += (ss).strides[0];\
64+
mp_float_t M = 0.0;\
65+
for(size_t i=0; i < (ss).shape[0]; i++) {\
6866
mp_float_t value = (mp_float_t)(*(type *)(array));\
69-
m = M + (value - M) / (mp_float_t)(i+1);\
70-
M = m;\
67+
M = M + (value - M) / (mp_float_t)(i+1);\
68+
(array) += (ss).strides[0];\
7169
}\
72-
(array) += (ss).strides[0];\
73-
*(r)++ = M;\
70+
*(rarray)++ = M;\
7471
})
7572

7673
// Instead of the straightforward implementation of the definition,
7774
// we take the numerically stable Welford algorithm here
7875
// https://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
79-
#define RUN_STD1(type, array, results, r, ss, div)\
76+
#define RUN_STD1(type, array, rarray, ss, div)\
8077
({\
81-
mp_float_t M = 0.0, m = 0.0, S = 0.0, s = 0.0;\
78+
mp_float_t M = 0.0, m = 0.0, S = 0.0;\
8279
for(size_t i=0; i < (ss).shape[0]; i++) {\
8380
mp_float_t value = (mp_float_t)(*(type *)(array));\
8481
m = M + (value - M) / (mp_float_t)(i+1);\
85-
s = S + (value - M) * (value - m);\
82+
S = S + (value - M) * (value - m);\
83+
M = m;\
84+
(array) += (ss).strides[0];\
85+
}\
86+
*(rarray)++ = MICROPY_FLOAT_C_FUN(sqrt)(S / (div));\
87+
})
88+
89+
#define RUN_MEAN_STD1(type, array, rarray, ss, div, isStd)\
90+
({\
91+
mp_float_t M = 0.0, m = 0.0, S = 0.0;\
92+
for(size_t i=0; i < (ss).shape[0]; i++) {\
93+
mp_float_t value = (mp_float_t)(*(type *)(array));\
94+
m = M + (value - M) / (mp_float_t)(i+1);\
95+
if(isStd) {\
96+
S += (value - M) * (value - m);\
97+
}\
8698
M = m;\
87-
S = s;\
8899
(array) += (ss).strides[0];\
89100
}\
90-
*(r)++ = MICROPY_FLOAT_C_FUN(sqrt)(s / (div));\
101+
*(rarray)++ = isStd ? MICROPY_FLOAT_C_FUN(sqrt)(S / (div)) : M;\
91102
})
92103

93104
#define RUN_DIFF1(ndarray, type, array, results, rarray, index, stencil, N)\
@@ -181,12 +192,16 @@
181192
RUN_SUM1(type, (array), (results), (rarray), (ss));\
182193
} while(0)
183194

184-
#define RUN_MEAN(type, array, results, r, ss) do {\
185-
RUN_MEAN1(type, (array), (results), (r), (ss));\
195+
#define RUN_MEAN(type, array, rarray, ss) do {\
196+
RUN_MEAN1(type, (array), (rarray), (ss));\
186197
} while(0)
187198

188-
#define RUN_STD(type, array, results, r, ss, div) do {\
189-
RUN_STD1(type, (array), (results), (r), (ss), (div));\
199+
#define RUN_STD(type, array, rarray, ss, div) do {\
200+
RUN_STD1(type, (array), (results), (rarray), (ss), (div));\
201+
} while(0)
202+
203+
#define RUN_MEAN_STD(type, array, rarray, ss, div, isStd) do {\
204+
RUN_MEAN_STD1(type, (array), (results), (rarray), (ss), (div), (isStd));\
190205
} while(0)
191206

192207
#define RUN_ARGMIN(ndarray, type, array, results, rarray, shape, strides, index, op) do {\
@@ -218,26 +233,37 @@
218233
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
219234
} while(0)
220235

221-
#define RUN_MEAN(type, array, results, r, ss) do {\
236+
#define RUN_MEAN(type, array, rarray, ss) do {\
222237
size_t l = 0;\
223238
do {\
224-
RUN_MEAN1(type, (array), (results), (r), (ss));\
239+
RUN_MEAN1(type, (array), (rarray), (ss));\
225240
(array) -= (ss).strides[0] * (ss).shape[0];\
226241
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
227242
l++;\
228243
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
229244
} while(0)
230245

231-
#define RUN_STD(type, array, results, r, ss, div) do {\
246+
#define RUN_STD(type, array, rarray, ss, div) do {\
232247
size_t l = 0;\
233248
do {\
234-
RUN_STD1(type, (array), (results), (r), (ss), (div));\
249+
RUN_STD1(type, (array), (rarray), (ss), (div));\
235250
(array) -= (ss).strides[0] * (ss).shape[0];\
236251
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
237252
l++;\
238253
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
239254
} while(0)
240255

256+
#define RUN_MEAN_STD(type, array, rarray, ss, div, isStd) do {\
257+
size_t l = 0;\
258+
do {\
259+
RUN_MEAN_STD1(type, (array), (rarray), (ss), (div), (isStd));\
260+
(array) -= (ss).strides[0] * (ss).shape[0];\
261+
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
262+
l++;\
263+
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
264+
} while(0)
265+
266+
241267
#define RUN_ARGMIN(ndarray, type, array, results, rarray, shape, strides, index, op) do {\
242268
size_t l = 0;\
243269
do {\
@@ -298,12 +324,28 @@
298324
} while(k < (ss).shape[ULAB_MAX_DIMS - 2]);\
299325
} while(0)
300326

301-
#define RUN_MEAN(type, array, results, r, ss) do {\
327+
#define RUN_MEAN(type, array, rarray, ss) do {\
328+
size_t k = 0;\
329+
do {\
330+
size_t l = 0;\
331+
do {\
332+
RUN_MEAN1(type, (array), (rarray), (ss));\
333+
(array) -= (ss).strides[0] * (ss).shape[0];\
334+
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
335+
l++;\
336+
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
337+
(array) -= (ss).strides[ULAB_MAX_DIMS - 1] * (ss).shape[ULAB_MAX_DIMS-1];\
338+
(array) += (ss).strides[ULAB_MAX_DIMS - 2];\
339+
k++;\
340+
} while(k < (ss).shape[ULAB_MAX_DIMS - 2]);\
341+
} while(0)
342+
343+
#define RUN_STD(type, array, rarray, ss, div) do {\
302344
size_t k = 0;\
303345
do {\
304346
size_t l = 0;\
305347
do {\
306-
RUN_MEAN1(type, (array), (results), (r), (ss));\
348+
RUN_STD1(type, (array), (rarray), (ss), (div));\
307349
(array) -= (ss).strides[0] * (ss).shape[0];\
308350
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
309351
l++;\
@@ -314,12 +356,12 @@
314356
} while(k < (ss).shape[ULAB_MAX_DIMS - 2]);\
315357
} while(0)
316358

317-
#define RUN_STD(type, array, results, r, ss, div) do {\
359+
#define RUN_MEAN_STD(type, array, rarray, ss, div, isStd) do {\
318360
size_t k = 0;\
319361
do {\
320362
size_t l = 0;\
321363
do {\
322-
RUN_STD1(type, (array), (results), (r), (ss), (div));\
364+
RUN_MEAN_STD1(type, (array), (rarray), (ss), (div), (isStd));\
323365
(array) -= (ss).strides[0] * (ss).shape[0];\
324366
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
325367
l++;\
@@ -424,14 +466,36 @@
424466
} while(j < (ss).shape[ULAB_MAX_DIMS - 3]);\
425467
} while(0)
426468

427-
#define RUN_MEAN(type, array, results, r, ss) do {\
469+
#define RUN_MEAN(type, array, rarray, ss) do {\
470+
size_t j = 0;\
471+
do {\
472+
size_t k = 0;\
473+
do {\
474+
size_t l = 0;\
475+
do {\
476+
RUN_MEAN1(type, (array), (rarray), (ss));\
477+
(array) -= (ss).strides[0] * (ss).shape[0];\
478+
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
479+
l++;\
480+
} while(l < (ss).shape[ULAB_MAX_DIMS - 1]);\
481+
(array) -= (ss).strides[ULAB_MAX_DIMS - 1] * (ss).shape[ULAB_MAX_DIMS-1];\
482+
(array) += (ss).strides[ULAB_MAX_DIMS - 2];\
483+
k++;\
484+
} while(k < (ss).shape[ULAB_MAX_DIMS - 2]);\
485+
(array) -= (ss).strides[ULAB_MAX_DIMS - 2] * (ss).shape[ULAB_MAX_DIMS-2];\
486+
(array) += (ss).strides[ULAB_MAX_DIMS - 3];\
487+
j++;\
488+
} while(j < (ss).shape[ULAB_MAX_DIMS - 3]);\
489+
} while(0)
490+
491+
#define RUN_STD(type, array, rarray, ss, div) do {\
428492
size_t j = 0;\
429493
do {\
430494
size_t k = 0;\
431495
do {\
432496
size_t l = 0;\
433497
do {\
434-
RUN_MEAN1(type, (array), (results), (r), (ss));\
498+
RUN_STD1(type, (array), (rarray), (ss), (div));\
435499
(array) -= (ss).strides[0] * (ss).shape[0];\
436500
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
437501
l++;\
@@ -446,14 +510,14 @@
446510
} while(j < (ss).shape[ULAB_MAX_DIMS - 3]);\
447511
} while(0)
448512

449-
#define RUN_STD(type, array, results, r, ss, div) do {\
513+
#define RUN_MEAN_STD(type, array, rarray, ss, div, isStd) do {\
450514
size_t j = 0;\
451515
do {\
452516
size_t k = 0;\
453517
do {\
454518
size_t l = 0;\
455519
do {\
456-
RUN_STD1(type, (array), (results), (r), (ss), (div));\
520+
RUN_MEAN_STD1(type, (array), (rarray), (ss), (div), (isStd));\
457521
(array) -= (ss).strides[0] * (ss).shape[0];\
458522
(array) += (ss).strides[ULAB_MAX_DIMS - 1];\
459523
l++;\

0 commit comments

Comments
 (0)