@@ -391,10 +391,12 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
391
391
mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
392
392
shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
393
393
mp_float_t * rarray = NULL ;
394
- ndarray_obj_t * results ;
395
- if (axis != mp_const_none ) {
394
+ ndarray_obj_t * results = NULL ;
395
+ if (( axis != mp_const_none ) && ( ndarray -> ndim > 1 ) ) {
396
396
results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
397
397
rarray = results -> array ;
398
+ } else {
399
+ rarray = m_new (mp_float_t , 1 );
398
400
}
399
401
400
402
#if ULAB_MAX_DIMS > 3
@@ -420,8 +422,9 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
420
422
array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
421
423
l ++ ;
422
424
} while (l < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
423
- if (axis != mp_const_none ) {
424
- * rarray ++ = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
425
+ * rarray = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
426
+ if (results != NULL ) {
427
+ rarray ++ ;
425
428
}
426
429
#if ULAB_MAX_DIMS > 1
427
430
array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
@@ -441,8 +444,8 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
441
444
i ++ ;
442
445
} while (i < _shape_strides .shape [ULAB_MAX_DIMS - 4 ]);
443
446
#endif
444
- if (axis == mp_const_none ) {
445
- return mp_obj_new_float (MICROPY_FLOAT_C_FUN ( sqrt )( dot * ( count - 1 )) );
447
+ if (results == NULL ) {
448
+ return mp_obj_new_float (* rarray );
446
449
}
447
450
return results ;
448
451
}
0 commit comments