@@ -354,33 +354,46 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354
354
//| ...
355
355
//|
356
356
357
- static mp_obj_t linalg_norm (mp_obj_t _x ) {
358
- if (!MP_OBJ_IS_TYPE (_x , & ulab_ndarray_type )) {
359
- mp_raise_TypeError (translate ("argument must be ndarray" ));
360
- }
361
- ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (_x );
362
- if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
363
- mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
364
- }
365
- mp_float_t dot = 0.0 ;
366
- uint8_t * array = (uint8_t * )ndarray -> array ;
357
+ static mp_obj_t linalg_norm (mp_obj_t x ) {
358
+ mp_float_t dot = 0.0 , value ;
359
+ size_t count = 1 ;
360
+
361
+ if (MP_OBJ_IS_TYPE (x , & mp_type_tuple ) || MP_OBJ_IS_TYPE (x , & mp_type_list ) || MP_OBJ_IS_TYPE (x , & mp_type_range )) {
362
+ mp_obj_iter_buf_t iter_buf ;
363
+ mp_obj_t item , iterable = mp_getiter (x , & iter_buf );
364
+ while ((item = mp_iternext (iterable )) != MP_OBJ_STOP_ITERATION ) {
365
+ value = mp_obj_get_float (item );
366
+ // we could simply take the sum of value ** 2,
367
+ // but this method is numerically stable
368
+ dot = dot + (value * value - dot ) / count ++ ;
369
+ }
370
+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
371
+ } else if (MP_OBJ_IS_TYPE (x , & ulab_ndarray_type )) {
372
+ ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (x );
373
+ if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
374
+ mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
375
+ }
376
+ uint8_t * array = (uint8_t * )ndarray -> array ;
367
377
368
- mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
378
+ mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
369
379
370
- size_t k = 0 ;
371
- do {
372
- size_t l = 0 ;
380
+ size_t k = 0 ;
373
381
do {
374
- mp_float_t v = func (array );
375
- array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
376
- dot += v * v ;
377
- l ++ ;
378
- } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
379
- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
380
- array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
381
- k ++ ;
382
- } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
383
- return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot ));
382
+ size_t l = 0 ;
383
+ do {
384
+ value = func (array );
385
+ dot = dot + (value * value - dot ) / count ++ ;
386
+ array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
387
+ l ++ ;
388
+ } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
389
+ array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
390
+ array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
391
+ k ++ ;
392
+ } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
393
+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
394
+ } else {
395
+ mp_raise_TypeError (translate ("argument must be an interable or ndarray" ));
396
+ }
384
397
}
385
398
386
399
MP_DEFINE_CONST_FUN_OBJ_1 (linalg_norm_obj , linalg_norm );
0 commit comments