Skip to content

Commit 0ea7871

Browse files
authored
Merge pull request #316 from v923z/stats
simplify array contraction
2 parents 96a944c + 7de1d09 commit 0ea7871

File tree

12 files changed

+339
-264
lines changed

12 files changed

+339
-264
lines changed

code/micropython.mk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ SRC_USERMOD += $(USERMODULES_DIR)/numpy/linalg/linalg.c
1818
SRC_USERMOD += $(USERMODULES_DIR)/numpy/linalg/linalg_tools.c
1919
SRC_USERMOD += $(USERMODULES_DIR)/numpy/numerical/numerical.c
2020
SRC_USERMOD += $(USERMODULES_DIR)/numpy/poly/poly.c
21+
SRC_USERMOD += $(USERMODULES_DIR)/numpy/stats/stats.c
2122
SRC_USERMOD += $(USERMODULES_DIR)/numpy/vector/vector.c
2223
SRC_USERMOD += $(USERMODULES_DIR)/user/user.c
2324

code/ndarray.c

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,6 @@ bool ndarray_is_dense(ndarray_obj_t *ndarray) {
598598

599599
ndarray_obj_t *ndarray_new_ndarray(uint8_t ndim, size_t *shape, int32_t *strides, uint8_t dtype) {
600600
// Creates the base ndarray with shape, and initialises the values to straight 0s
601-
// the function should work in the general n-dimensional case
602601
ndarray_obj_t *ndarray = m_new_obj(ndarray_obj_t);
603602
ndarray->base.type = &ulab_ndarray_type;
604603
ndarray->dtype = dtype == NDARRAY_BOOL ? NDARRAY_UINT8 : dtype;
@@ -618,10 +617,12 @@ ndarray_obj_t *ndarray_new_ndarray(uint8_t ndim, size_t *shape, int32_t *strides
618617
ndarray->len *= shape[i-1];
619618
}
620619

621-
uint8_t *array = m_new(byte, ndarray->itemsize * ndarray->len);
620+
// if the length is 0, still allocate a single item, so that contractions can be handled
621+
size_t len = ndarray->itemsize * MAX(1, ndarray->len);
622+
uint8_t *array = m_new(byte, len);
622623
// this should set all elements to 0, irrespective of the of the dtype (all bits are zero)
623624
// we could, perhaps, leave this step out, and initialise the array only, when needed
624-
memset(array, 0, ndarray->len * ndarray->itemsize);
625+
memset(array, 0, len);
625626
ndarray->array = array;
626627
return ndarray;
627628
}

code/numpy/linalg/linalg.c

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -385,14 +385,8 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
385385
// always get a float, so that we don't have to resolve the dtype later
386386
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
387387
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
388-
mp_float_t *rarray = NULL;
389-
ndarray_obj_t *results = NULL;
390-
if((axis != mp_const_none) && (ndarray->ndim > 1)) {
391-
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
392-
rarray = results->array;
393-
} else {
394-
rarray = m_new(mp_float_t, 1);
395-
}
388+
ndarray_obj_t *results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
389+
mp_float_t *rarray = (mp_float_t *)results->array;
396390

397391
#if ULAB_MAX_DIMS > 3
398392
size_t i = 0;
@@ -418,28 +412,26 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
418412
l++;
419413
} while(l < _shape_strides.shape[0]);
420414
*rarray = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
421-
if(results != NULL) {
422-
rarray++;
423-
}
424415
#if ULAB_MAX_DIMS > 1
416+
rarray += _shape_strides.increment;
425417
array -= _shape_strides.strides[0] * _shape_strides.shape[0];
426418
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
427419
k++;
428420
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
429421
#endif
430422
#if ULAB_MAX_DIMS > 2
431-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
423+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
432424
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
433425
j++;
434426
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
435427
#endif
436428
#if ULAB_MAX_DIMS > 3
437-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
429+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS - 2];
438430
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
439431
i++;
440432
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3]);
441433
#endif
442-
if(results == NULL) {
434+
if(results->ndim == 0) {
443435
return mp_obj_new_float(*rarray);
444436
}
445437
return results;

code/numpy/numerical/numerical.c

Lines changed: 81 additions & 131 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)