@@ -354,7 +354,23 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354
354
//| ...
355
355
//|
356
356
357
- static mp_obj_t linalg_norm (mp_obj_t x ) {
357
+ static mp_obj_t linalg_norm (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
358
+ static const mp_arg_t allowed_args [] = {
359
+ { MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , { .u_rom_obj = mp_const_none } } ,
360
+ { MP_QSTR_axis , MP_ARG_OBJ , { .u_rom_obj = mp_const_none } },
361
+ };
362
+
363
+ mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
364
+ mp_arg_parse_all (n_args , pos_args , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
365
+
366
+ mp_obj_t x = args [0 ].u_obj ;
367
+ mp_obj_t axis = args [1 ].u_obj ;
368
+ if ((axis != mp_const_none ) && (!MP_OBJ_IS_INT (axis ))) {
369
+ mp_raise_TypeError (translate ("axis must be None, or an integer" ));
370
+ }
371
+
372
+
373
+ // static mp_obj_t linalg_norm(mp_obj_t x) {
358
374
mp_float_t dot = 0.0 , value ;
359
375
size_t count = 1 ;
360
376
@@ -370,33 +386,71 @@ static mp_obj_t linalg_norm(mp_obj_t x) {
370
386
return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
371
387
} else if (MP_OBJ_IS_TYPE (x , & ulab_ndarray_type )) {
372
388
ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (x );
373
- if ((ndarray -> ndim != 1 ) && (ndarray -> ndim != 2 )) {
374
- mp_raise_ValueError (translate ("norm is defined for 1D and 2D arrays" ));
375
- }
376
389
uint8_t * array = (uint8_t * )ndarray -> array ;
377
-
390
+ // always get a float, so that we don't have to resolve the dtype later
378
391
mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
392
+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
393
+ mp_float_t * rarray = NULL ;
394
+ ndarray_obj_t * results ;
395
+ if (axis != mp_const_none ) {
396
+ results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
397
+ rarray = results -> array ;
398
+ }
379
399
380
- size_t k = 0 ;
400
+ #if ULAB_MAX_DIMS > 3
401
+ size_t i = 0 ;
381
402
do {
382
- size_t l = 0 ;
403
+ #endif
404
+ #if ULAB_MAX_DIMS > 2
405
+ size_t j = 0 ;
383
406
do {
384
- value = func (array );
385
- dot = dot + (value * value - dot ) / count ++ ;
386
- array += ndarray -> strides [ULAB_MAX_DIMS - 1 ];
387
- l ++ ;
388
- } while (l < ndarray -> shape [ULAB_MAX_DIMS - 1 ]);
389
- array -= ndarray -> strides [ULAB_MAX_DIMS - 1 ] * ndarray -> shape [ULAB_MAX_DIMS - 1 ];
390
- array += ndarray -> strides [ULAB_MAX_DIMS - 2 ];
391
- k ++ ;
392
- } while (k < ndarray -> shape [ULAB_MAX_DIMS - 2 ]);
393
- return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
394
- } else {
395
- mp_raise_TypeError (translate ("argument must be an interable or ndarray" ));
407
+ #endif
408
+ #if ULAB_MAX_DIMS > 1
409
+ size_t k = 0 ;
410
+ do {
411
+ #endif
412
+ size_t l = 0 ;
413
+ if (axis != mp_const_none ) {
414
+ count = 1 ;
415
+ dot = 0.0 ;
416
+ }
417
+ do {
418
+ value = func (array );
419
+ dot = dot + (value * value - dot ) / count ++ ;
420
+ array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
421
+ l ++ ;
422
+ } while (l < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
423
+ if (axis != mp_const_none ) {
424
+ * rarray ++ = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
425
+ }
426
+ #if ULAB_MAX_DIMS > 1
427
+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
428
+ array += _shape_strides .strides [ULAB_MAX_DIMS - 2 ];
429
+ k ++ ;
430
+ } while (k < _shape_strides .shape [ULAB_MAX_DIMS - 2 ]);
431
+ #endif
432
+ #if ULAB_MAX_DIMS > 2
433
+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
434
+ array += _shape_strides .strides [ULAB_MAX_DIMS - 3 ];
435
+ j ++ ;
436
+ } while (j < _shape_strides .shape [ULAB_MAX_DIMS - 3 ]);
437
+ #endif
438
+ #if ULAB_MAX_DIMS > 3
439
+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 3 ] * _shape_strides .shape [ULAB_MAX_DIMS - 3 ];
440
+ array += _shape_strides .strides [ULAB_MAX_DIMS - 4 ];
441
+ i ++ ;
442
+ } while (i < _shape_strides .shape [ULAB_MAX_DIMS - 4 ]);
443
+ #endif
444
+ if (axis == mp_const_none ) {
445
+ return mp_obj_new_float (MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 )));
446
+ }
447
+ return results ;
396
448
}
449
+ return mp_const_none ; // we should never reach this point
397
450
}
398
451
399
- MP_DEFINE_CONST_FUN_OBJ_1 (linalg_norm_obj , linalg_norm );
452
+ MP_DEFINE_CONST_FUN_OBJ_KW (linalg_norm_obj , 1 , linalg_norm );
453
+ // MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
400
454
401
455
#if ULAB_MAX_DIMS > 1
402
456
#if ULAB_LINALG_HAS_TRACE
0 commit comments