Skip to content

Commit cacb1b6

Browse files
committed
fixed linalg.norm for a special case
1 parent 674220c commit cacb1b6

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

code/numpy/linalg/linalg.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -391,10 +391,12 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
391391
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
392392
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
393393
mp_float_t *rarray = NULL;
394-
ndarray_obj_t *results;
395-
if(axis != mp_const_none) {
394+
ndarray_obj_t *results = NULL;
395+
if((axis != mp_const_none) && (ndarray->ndim > 1)) {
396396
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
397397
rarray = results->array;
398+
} else {
399+
rarray = m_new(mp_float_t, 1);
398400
}
399401

400402
#if ULAB_MAX_DIMS > 3
@@ -420,8 +422,9 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
420422
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
421423
l++;
422424
} while(l < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
423-
if(axis != mp_const_none) {
424-
*rarray++ = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
425+
*rarray = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
426+
if(results != NULL) {
427+
rarray++;
425428
}
426429
#if ULAB_MAX_DIMS > 1
427430
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
@@ -441,8 +444,8 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
441444
i++;
442445
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 4]);
443446
#endif
444-
if(axis == mp_const_none) {
445-
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
447+
if(results == NULL) {
448+
return mp_obj_new_float(*rarray);
446449
}
447450
return results;
448451
}

0 commit comments

Comments
 (0)