Skip to content

Commit 5f716a7

Browse files
authored
Merge pull request #312 from v923z/norm
improved linalg.norm
2 parents e00ad9c + 6867951 commit 5f716a7

File tree

9 files changed

+141
-74
lines changed

9 files changed

+141
-74
lines changed

code/ndarray.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
632632
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
633633
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
634634
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
635-
strides[i-2] = strides[i-1] * shape[i-1];
635+
strides[i-2] = strides[i-1] * MAX(1, shape[i-1]);
636636
}
637637
return ndarray_new_ndarray(ndim, shape, strides, dtype);
638638
}

code/numpy/linalg/linalg.c

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,23 @@ MP_DEFINE_CONST_FUN_OBJ_1(linalg_inv_obj, linalg_inv);
354354
//| ...
355355
//|
356356

357-
static mp_obj_t linalg_norm(mp_obj_t x) {
357+
static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
358+
static const mp_arg_t allowed_args[] = {
359+
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = mp_const_none} } ,
360+
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = mp_const_none } },
361+
};
362+
363+
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
364+
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
365+
366+
mp_obj_t x = args[0].u_obj;
367+
mp_obj_t axis = args[1].u_obj;
368+
if((axis != mp_const_none) && (!MP_OBJ_IS_INT(axis))) {
369+
mp_raise_TypeError(translate("axis must be None, or an integer"));
370+
}
371+
372+
373+
// static mp_obj_t linalg_norm(mp_obj_t x) {
358374
mp_float_t dot = 0.0, value;
359375
size_t count = 1;
360376

@@ -370,33 +386,74 @@ static mp_obj_t linalg_norm(mp_obj_t x) {
370386
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
371387
} else if(MP_OBJ_IS_TYPE(x, &ulab_ndarray_type)) {
372388
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(x);
373-
if((ndarray->ndim != 1) && (ndarray->ndim != 2)) {
374-
mp_raise_ValueError(translate("norm is defined for 1D and 2D arrays"));
375-
}
376389
uint8_t *array = (uint8_t *)ndarray->array;
377-
390+
// always get a float, so that we don't have to resolve the dtype later
378391
mp_float_t (*func)(void *) = ndarray_get_float_function(ndarray->dtype);
392+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
393+
mp_float_t *rarray = NULL;
394+
ndarray_obj_t *results = NULL;
395+
if((axis != mp_const_none) && (ndarray->ndim > 1)) {
396+
results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_FLOAT);
397+
rarray = results->array;
398+
} else {
399+
rarray = m_new(mp_float_t, 1);
400+
}
379401

380-
size_t k = 0;
402+
#if ULAB_MAX_DIMS > 3
403+
size_t i = 0;
381404
do {
382-
size_t l = 0;
405+
#endif
406+
#if ULAB_MAX_DIMS > 2
407+
size_t j = 0;
383408
do {
384-
value = func(array);
385-
dot = dot + (value * value - dot) / count++;
386-
array += ndarray->strides[ULAB_MAX_DIMS - 1];
387-
l++;
388-
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
389-
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
390-
array += ndarray->strides[ULAB_MAX_DIMS - 2];
391-
k++;
392-
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
393-
return mp_obj_new_float(MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1)));
394-
} else {
395-
mp_raise_TypeError(translate("argument must be an interable or ndarray"));
409+
#endif
410+
#if ULAB_MAX_DIMS > 1
411+
size_t k = 0;
412+
do {
413+
#endif
414+
size_t l = 0;
415+
if(axis != mp_const_none) {
416+
count = 1;
417+
dot = 0.0;
418+
}
419+
do {
420+
value = func(array);
421+
dot = dot + (value * value - dot) / count++;
422+
array += _shape_strides.strides[0];
423+
l++;
424+
} while(l < _shape_strides.shape[0]);
425+
*rarray = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
426+
if(results != NULL) {
427+
rarray++;
428+
}
429+
#if ULAB_MAX_DIMS > 1
430+
array -= _shape_strides.strides[0] * _shape_strides.shape[0];
431+
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
432+
k++;
433+
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
434+
#endif
435+
#if ULAB_MAX_DIMS > 2
436+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
437+
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
438+
j++;
439+
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
440+
#endif
441+
#if ULAB_MAX_DIMS > 3
442+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
443+
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
444+
i++;
445+
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3]);
446+
#endif
447+
if(results == NULL) {
448+
return mp_obj_new_float(*rarray);
449+
}
450+
return results;
396451
}
452+
return mp_const_none; // we should never reach this point
397453
}
398454

399-
MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
455+
MP_DEFINE_CONST_FUN_OBJ_KW(linalg_norm_obj, 1, linalg_norm);
456+
// MP_DEFINE_CONST_FUN_OBJ_1(linalg_norm_obj, linalg_norm);
400457

401458
#if ULAB_MAX_DIMS > 1
402459
#if ULAB_LINALG_HAS_TRACE

code/numpy/linalg/linalg.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@ MP_DECLARE_CONST_FUN_OBJ_1(linalg_eig_obj);
2424
MP_DECLARE_CONST_FUN_OBJ_1(linalg_inv_obj);
2525
MP_DECLARE_CONST_FUN_OBJ_1(linalg_trace_obj);
2626
MP_DECLARE_CONST_FUN_OBJ_2(linalg_dot_obj);
27-
MP_DECLARE_CONST_FUN_OBJ_2(linalg_norm_obj);
27+
MP_DECLARE_CONST_FUN_OBJ_KW(linalg_norm_obj);
2828
#endif

code/numpy/numerical/numerical.c

Lines changed: 11 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
6363
}
6464
}
6565

66-
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
67-
// TODO: replace numerical_reduce_axes with this function, wherever applicable
68-
int8_t ax = mp_obj_get_int(axis);
69-
if(ax < 0) ax += ndarray->ndim;
70-
if((ax < 0) || (ax > ndarray->ndim - 1)) {
71-
mp_raise_ValueError(translate("index out of range"));
72-
}
73-
shape_strides _shape_strides;
74-
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
75-
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
76-
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
77-
_shape_strides.shape = shape;
78-
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
79-
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
80-
_shape_strides.strides = strides;
81-
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
82-
_shape_strides.index = 0;
83-
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
84-
} else {
85-
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
86-
if(i > _shape_strides.index) {
87-
_shape_strides.shape[i] = ndarray->shape[i];
88-
_shape_strides.strides[i] = ndarray->strides[i];
89-
} else {
90-
_shape_strides.shape[i] = ndarray->shape[i-1];
91-
_shape_strides.strides[i] = ndarray->strides[i-1];
92-
}
93-
}
94-
}
95-
return _shape_strides;
96-
}
97-
9866
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
9967
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
10068
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
@@ -130,25 +98,25 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
13098
l++;
13199
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
132100
#if ULAB_MAX_DIMS > 1
133-
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
101+
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
134102
array += ndarray->strides[ULAB_MAX_DIMS - 2];
135103
k++;
136104
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
137105
#endif
138106
#if ULAB_MAX_DIMS > 2
139-
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
107+
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS - 2];
140108
array += ndarray->strides[ULAB_MAX_DIMS - 3];
141109
j++;
142110
} while(j < ndarray->shape[ULAB_MAX_DIMS - 3]);
143111
#endif
144112
#if ULAB_MAX_DIMS > 3
145-
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
113+
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS - 3];
146114
array += ndarray->strides[ULAB_MAX_DIMS - 4];
147115
i++;
148116
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
149117
#endif
150118
} else {
151-
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
119+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
152120
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
153121
uint8_t *rarray = (uint8_t *)results->array;
154122
if(optype == NUMERICAL_ALL) {
@@ -173,33 +141,33 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
173141
// optype == NUMERICAL_ANY
174142
*rarray = 1;
175143
// since we are breaking out of the loop, move the pointer forward
176-
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
144+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
177145
break;
178146
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
179147
// optype == NUMERICAL_ALL
180148
*rarray = 0;
181149
// since we are breaking out of the loop, move the pointer forward
182-
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
150+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
183151
break;
184152
}
185-
array += ndarray->strides[_shape_strides.index];
153+
array += _shape_strides.strides[0];
186154
l++;
187-
} while(l < ndarray->shape[_shape_strides.index]);
155+
} while(l < _shape_strides.shape[0]);
188156
#if ULAB_MAX_DIMS > 1
189157
rarray++;
190-
array -= ndarray->strides[_shape_strides.index] * ndarray->shape[_shape_strides.index];
158+
array -= _shape_strides.strides[0] * _shape_strides.shape[0];
191159
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
192160
k++;
193161
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
194162
#endif
195163
#if ULAB_MAX_DIMS > 2
196-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
164+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
197165
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
198166
j++;
199167
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
200168
#endif
201169
#if ULAB_MAX_DIMS > 3
202-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
170+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS - 2];
203171
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
204172
i++;
205173
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3])

code/numpy/numerical/numerical.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@
1717

1818
// TODO: implement cumsum
1919

20-
typedef struct {
21-
uint8_t index;
22-
int8_t axis;
23-
size_t *shape;
24-
int32_t *strides;
25-
} shape_strides;
26-
2720
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
2821
({\
2922
uint16_t best_index = 0;\

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
#include "user/user.h"
3535

36-
#define ULAB_VERSION 2.3.2
36+
#define ULAB_VERSION 2.3.3
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939
#define ULAB_VERSION_STRING xstr(ULAB_VERSION) xstr(-) xstr(ULAB_MAX_DIMS) xstr(D)

code/ulab_tools.c

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
*/
1010

1111

12-
12+
#include <string.h>
1313
#include "py/runtime.h"
1414

1515
#include "ulab.h"
@@ -158,3 +158,36 @@ void *ndarray_set_float_function(uint8_t dtype) {
158158
}
159159
}
160160
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
161+
162+
/*
 * Build the shape/strides pair used to iterate over `ndarray` while reducing
 * along `axis`.
 *
 * ndarray: the array whose shape and strides are read (it is not modified).
 * axis:    mp_const_none, or an integer object. A negative integer has
 *          ndarray->ndim added to it once before the range check.
 *
 * Returns a shape_strides whose `shape` and `strides` members point to newly
 * m_new-allocated arrays of ULAB_MAX_DIMS entries each:
 *   - axis == mp_const_none: plain copies of ndarray->shape / ndarray->strides;
 *   - otherwise: the entry of the selected axis is placed in slot 0, the
 *     entries that were to its left are shifted one slot to the right, and
 *     the entries to its right stay where they were.
 * The `index` and `axis` members are not used by this helper; they are zeroed
 * so that the returned struct never carries indeterminate values.
 *
 * Raises TypeError if axis is neither None nor an integer, and ValueError if
 * the adjusted axis lies outside [0, ndarray->ndim - 1].
 */
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
    // TODO: replace numerical_reduce_axes with this function, wherever applicable
    if((axis != mp_const_none) && !mp_obj_is_int(axis)) {
        mp_raise_TypeError(translate("axis must be None, or an integer"));
    }

    // validate the axis before allocating anything
    uint8_t index = 0;
    if(axis != mp_const_none) {
        // keep the full integer width: narrowing to int8_t first would let
        // out-of-range values (e.g. 256) wrap around into the valid range
        mp_int_t ax = mp_obj_get_int(axis);
        if(ax < 0) ax += ndarray->ndim;
        if((ax < 0) || (ax > ndarray->ndim - 1)) {
            mp_raise_ValueError(translate("index out of range"));
        }
        // position of the selected axis in the right-aligned shape/strides arrays
        index = (uint8_t)(ULAB_MAX_DIMS - ndarray->ndim + ax);
    }

    shape_strides _shape_strides = { 0 };
    _shape_strides.shape = m_new(size_t, ULAB_MAX_DIMS);
    _shape_strides.strides = m_new(int32_t, ULAB_MAX_DIMS);

    // for axis == mp_const_none, simply return the original shape and strides
    memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
    memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

    if(axis != mp_const_none) {
        // move the axis to the leftmost position, and align everything else to the right
        _shape_strides.shape[0] = ndarray->shape[index];
        _shape_strides.strides[0] = ndarray->strides[index];
        for(uint8_t i = 0; i < index; i++) {
            // entries to the left of index must be shifted to the right
            _shape_strides.shape[i + 1] = ndarray->shape[i];
            _shape_strides.strides[i + 1] = ndarray->strides[i];
        }
    }
    return _shape_strides;
}

code/ulab_tools.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,17 @@
1111
#ifndef _TOOLS_
1212
#define _TOOLS_
1313

14+
#include "ndarray.h"
15+
1416
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
1517

18+
// Shape/strides bundle returned by value from tools_reduce_axes() (prototype further down in this header).
typedef struct _shape_strides_t {
19+
uint8_t index; // NOTE(review): meaning is not evident from this header; confirm against tools_reduce_axes() before relying on it
20+
int8_t axis; // NOTE(review): meaning is not evident from this header; confirm against tools_reduce_axes() before relying on it
21+
size_t *shape; // pointer to an array of dimension lengths; presumably ULAB_MAX_DIMS entries, verify against the allocator in ulab_tools.c
22+
int32_t *strides; // pointer to an array of strides, parallel to shape; presumably ULAB_MAX_DIMS entries, verify likewise
23+
} shape_strides;
24+
1625
mp_float_t ndarray_get_float_uint8(void *);
1726
mp_float_t ndarray_get_float_int8(void *);
1827
mp_float_t ndarray_get_float_uint16(void *);
@@ -23,4 +32,5 @@ void *ndarray_get_float_function(uint8_t );
2332
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
2433
void *ndarray_set_float_function(uint8_t );
2534

35+
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
2636
#endif

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Tue, 9 Feb 2021
2+
3+
version 2.3.3
4+
5+
linalg.norm should now work with the axis keyword argument
6+
17
Mon, 8 Feb 2021
28

39
version 2.3.2

0 commit comments

Comments
 (0)