Skip to content

Commit 0d1379d

Browse files
committed
linalg.norm should not work with the axis keyword argument
1 parent 2c71434 commit 0d1379d

File tree

3 files changed

+95
-40
lines changed

3 files changed

+95
-40
lines changed

code/numpy/linalg/linalg.c

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,23 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354354
//| ...
355355
//|
356356

357-
static mp_obj_t linalg_norm(mp_obj_t x) {
357+
static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
358+
static const mp_arg_t allowed_args[] = {
359+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none} } ,
360+
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
361+
};
362+
363+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
364+
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
365+
366+
mp_obj_t x = args[0].u_obj;
367+
mp_obj_t axis = args[1].u_obj;
368+
if((axis != mp_const_none) && (!MP_OBJ_IS_INT(axis))) {
369+
mp_raise_TypeError(translate("axis must be None, or an integer"));
370+
}
371+
372+
373+
// static mp_obj_t linalg_norm(mp_obj_t x) {
358374
mp_float_t dot = 0.0, value;
359375
size_t count = 1;
360376

@@ -370,33 +386,71 @@ static mp_obj_t linalg_norm(mp_obj_t x) {
370386
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
371387
} else if(MP_OBJ_IS_TYPE(x, &ulab_ndarray_type)) {
372388
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(x);
373-
if((ndarray->ndim != 1) && (ndarray->ndim != 2)) {
374-
mp_raise_ValueError(translate("norm is defined for 1D and 2D arrays"));
375-
}
376389
uint8_t *array = (uint8_t *)ndarray->array;
377-
390+
// always get a float, so that we don't have to resolve the dtype later
378391
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
392+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
393+
mp_float_t *rarray = NULL;
394+
ndarray_obj_t *results;
395+
if(axis != mp_const_none) {
396+
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
397+
rarray = results->array;
398+
}
379399

380-
size_t k = 0;
400+
#if ULAB_MAX_DIMS > 3
401+
size_t i = 0;
381402
do {
382-
size_t l = 0;
403+
#endif
404+
#if ULAB_MAX_DIMS > 2
405+
size_t j = 0;
383406
do {
384-
value = func(array);
385-
dot = dot + (value * value - dot) / count++;
386-
array += ndarray->strides[ULAB_MAX_DIMS - 1];
387-
l++;
388-
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
389-
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
390-
array += ndarray->strides[ULAB_MAX_DIMS - 2];
391-
k++;
392-
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
393-
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
394-
} else {
395-
mp_raise_TypeError(translate("argument must be an interable or ndarray"));
407+
#endif
408+
#if ULAB_MAX_DIMS > 1
409+
size_t k = 0;
410+
do {
411+
#endif
412+
size_t l = 0;
413+
if(axis != mp_const_none) {
414+
count = 1;
415+
dot = 0.0;
416+
}
417+
do {
418+
value = func(array);
419+
dot = dot + (value * value - dot) / count++;
420+
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
421+
l++;
422+
} while(l < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
423+
if(axis != mp_const_none) {
424+
*rarray++ = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
425+
}
426+
#if ULAB_MAX_DIMS > 1
427+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
428+
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
429+
k++;
430+
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
431+
#endif
432+
#if ULAB_MAX_DIMS > 2
433+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
434+
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
435+
j++;
436+
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 3]);
437+
#endif
438+
#if ULAB_MAX_DIMS > 3
439+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 3] * _shape_strides.shape[ULAB_MAX_DIMS-3];
440+
array += _shape_strides.strides[ULAB_MAX_DIMS - 4];
441+
i++;
442+
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 4]);
443+
#endif
444+
if(axis == mp_const_none) {
445+
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
446+
}
447+
return results;
396448
}
449+
return mp_const_none; // we should never reach this point
397450
}
398451

399-
MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
452+
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
453+
// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
400454

401455
#if ULAB_MAX_DIMS > 1
402456
#if ULAB_LINALG_HAS_TRACE

code/numpy/linalg/linalg.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
2424
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
2525
MP_DECLARE_CONST_FUN_OBJ_1(linalg_trace_obj);
2626
MP_DECLARE_CONST_FUN_OBJ_2(linalg_dot_obj);
27-
MP_DECLARE_CONST_FUN_OBJ_2(linalg_norm_obj);
27+
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
2828
#endif

code/ulab_tools.c

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -161,31 +161,32 @@ void *ndarray_set_float_function(uint8_t dtype) {
161161

162162
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
163163
// TODO: replace numerical_reduce_axes with this function, wherever applicable
164-
int8_t ax = mp_obj_get_int(axis);
165-
if(ax < 0) ax += ndarray->ndim;
166-
if((ax < 0) || (ax > ndarray->ndim - 1)) {
167-
mp_raise_ValueError(translate("index out of range"));
164+
if(!mp_obj_is_int(axis) & (axis != mp_const_none)) {
165+
mp_raise_TypeError(translate("axis must be an interable or ndarray"));
168166
}
169167
shape_strides _shape_strides;
170-
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
171168
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
172-
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
173169
_shape_strides.shape = shape;
174170
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
175-
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
176171
_shape_strides.strides = strides;
177-
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
178-
_shape_strides.index = 0;
179-
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
180-
} else {
181-
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
182-
if(i > _shape_strides.index) {
183-
_shape_strides.shape[i] = ndarray->shape[i];
184-
_shape_strides.strides[i] = ndarray->strides[i];
185-
} else {
186-
_shape_strides.shape[i] = ndarray->shape[i-1];
187-
_shape_strides.strides[i] = ndarray->strides[i-1];
188-
}
172+
173+
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
174+
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
175+
// for axis == mp_const_none, simply return the original shape and strides
176+
if(axis != mp_const_none) {
177+
// move the axis to the rightmost position, and align everything else to the right
178+
int8_t ax = mp_obj_get_int(axis);
179+
if(ax < 0) ax += ndarray->ndim;
180+
if((ax < 0) || (ax > ndarray->ndim - 1)) {
181+
mp_raise_ValueError(translate("index out of range"));
182+
}
183+
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
184+
_shape_strides.shape[ULAB_MAX_DIMS - 1] = ndarray->shape[index];
185+
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->strides[index];
186+
for(uint8_t i = index; i < ULAB_MAX_DIMS - 1; i++) {
187+
// entries to the right of index must be shifted to the left
188+
_shape_strides.shape[i] = ndarray->shape[i+1];
189+
_shape_strides.strides[i] = ndarray->strides[i+1];
189190
}
190191
}
191192
return _shape_strides;

0 commit comments

Comments
 (0)