Skip to content

Commit e00ad9c

Browse files
authored
Merge pull request #311 from v923z/norm
improved accuracy of linalg.norm, and extended it to generic iterables
2 parents a726c1d + 7c4f4db commit e00ad9c

File tree

3 files changed

+44
-25
lines changed

3 files changed

+44
-25
lines changed

code/numpy/linalg/linalg.c

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -354,33 +354,46 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354354
//| ...
355355
//|
356356

357-
static mp_obj_t linalg_norm(mp_obj_t _x) {
358-
if (!MP_OBJ_IS_TYPE(_x, &ulab_ndarray_type)) {
359-
mp_raise_TypeError(translate("argument must be ndarray"));
360-
}
361-
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(_x);
362-
if((ndarray->ndim != 1) && (ndarray->ndim != 2)) {
363-
mp_raise_ValueError(translate("norm is defined for 1D and 2D arrays"));
364-
}
365-
mp_float_t dot = 0.0;
366-
uint8_t *array = (uint8_t *)ndarray->array;
357+
static mp_obj_t linalg_norm(mp_obj_t x) {
358+
mp_float_t dot = 0.0, value;
359+
size_t count = 1;
360+
361+
if(MP_OBJ_IS_TYPE(x, &mp_type_tuple) || MP_OBJ_IS_TYPE(x, &mp_type_list) || MP_OBJ_IS_TYPE(x, &mp_type_range)) {
362+
mp_obj_iter_buf_t iter_buf;
363+
mp_obj_t item, iterable = mp_getiter(x, &iter_buf);
364+
while((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
365+
value = mp_obj_get_float(item);
366+
// we could simply take the sum of value ** 2,
367+
// but this method is numerically stable
368+
dot = dot + (value * value - dot) / count++;
369+
}
370+
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
371+
} else if(MP_OBJ_IS_TYPE(x, &ulab_ndarray_type)) {
372+
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(x);
373+
if((ndarray->ndim != 1) && (ndarray->ndim != 2)) {
374+
mp_raise_ValueError(translate("norm is defined for 1D and 2D arrays"));
375+
}
376+
uint8_t *array = (uint8_t *)ndarray->array;
367377

368-
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
378+
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
369379

370-
size_t k = 0;
371-
do {
372-
size_t l = 0;
380+
size_t k = 0;
373381
do {
374-
mp_float_t v = func(array);
375-
array += ndarray->strides[ULAB_MAX_DIMS - 1];
376-
dot += v*v;
377-
l++;
378-
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
379-
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
380-
array += ndarray->strides[ULAB_MAX_DIMS - 2];
381-
k++;
382-
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
383-
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot));
382+
size_t l = 0;
383+
do {
384+
value = func(array);
385+
dot = dot + (value * value - dot) / count++;
386+
array += ndarray->strides[ULAB_MAX_DIMS - 1];
387+
l++;
388+
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
389+
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
390+
array += ndarray->strides[ULAB_MAX_DIMS - 2];
391+
k++;
392+
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
393+
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
394+
} else {
395+
mp_raise_TypeError(translate("argument must be an interable or ndarray"));
396+
}
384397
}
385398

386399
MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#include "user/user.h"
3535

36-
#define ULAB_VERSION 2.3.1
36+
#define ULAB_VERSION 2.3.2
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
Mon, 8 Feb 2021
22

3+
version 2.3.2
4+
5+
improved the accuracy of linalg.norm, and extended it to generic iterables
6+
7+
Mon, 8 Feb 2021
8+
39
version 2.3.1
410

511
partially fix https://github.com/v923z/micropython-ulab/issues/304, and len unary operator

0 commit comments

Comments
 (0)