@@ -385,14 +385,8 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
385
385
// always get a float, so that we don't have to resolve the dtype later
386
386
mp_float_t (* func )(void * ) = ndarray_get_float_function (ndarray -> dtype );
387
387
shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
388
- mp_float_t * rarray = NULL ;
389
- ndarray_obj_t * results = NULL ;
390
- if ((axis != mp_const_none ) && (ndarray -> ndim > 1 )) {
391
- results = ndarray_new_dense_ndarray (MAX (1 , ndarray -> ndim - 1 ), _shape_strides .shape , NDARRAY_FLOAT );
392
- rarray = results -> array ;
393
- } else {
394
- rarray = m_new (mp_float_t , 1 );
395
- }
388
+ ndarray_obj_t * results = ndarray_new_dense_ndarray (_shape_strides .ndim , _shape_strides .shape , NDARRAY_FLOAT );
389
+ mp_float_t * rarray = (mp_float_t * )results -> array ;
396
390
397
391
#if ULAB_MAX_DIMS > 3
398
392
size_t i = 0 ;
@@ -418,28 +412,26 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
418
412
l ++ ;
419
413
} while (l < _shape_strides .shape [0 ]);
420
414
* rarray = MICROPY_FLOAT_C_FUN (sqrt )(dot * (count - 1 ));
421
- if (results != NULL ) {
422
- rarray ++ ;
423
- }
424
415
#if ULAB_MAX_DIMS > 1
416
+ rarray += _shape_strides .increment ;
425
417
array -= _shape_strides .strides [0 ] * _shape_strides .shape [0 ];
426
418
array += _shape_strides .strides [ULAB_MAX_DIMS - 1 ];
427
419
k ++ ;
428
420
} while (k < _shape_strides .shape [ULAB_MAX_DIMS - 1 ]);
429
421
#endif
430
422
#if ULAB_MAX_DIMS > 2
431
- array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
423
+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 1 ] * _shape_strides .shape [ULAB_MAX_DIMS - 1 ];
432
424
array += _shape_strides .strides [ULAB_MAX_DIMS - 2 ];
433
425
j ++ ;
434
426
} while (j < _shape_strides .shape [ULAB_MAX_DIMS - 2 ]);
435
427
#endif
436
428
#if ULAB_MAX_DIMS > 3
437
- array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
429
+ array -= _shape_strides .strides [ULAB_MAX_DIMS - 2 ] * _shape_strides .shape [ULAB_MAX_DIMS - 2 ];
438
430
array += _shape_strides .strides [ULAB_MAX_DIMS - 3 ];
439
431
i ++ ;
440
432
} while (i < _shape_strides .shape [ULAB_MAX_DIMS - 3 ]);
441
433
#endif
442
- if (results == NULL ) {
434
+ if (results -> ndim == 0 ) {
443
435
return mp_obj_new_float (* rarray );
444
436
}
445
437
return results ;
0 commit comments