Skip to content

Commit 2c71434

Browse files
committed
moved the reduce_axes helper function to ulab_tools
1 parent 7c4f4db commit 2c71434

File tree

5 files changed

+45
-42
lines changed

5 files changed

+45
-42
lines changed

code/ndarray.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
632632
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
633633
strides[ULAB_MAX_DIMS-1] = dtype == NDARRAY_BOOL ? 1 : mp_binary_get_size('@', dtype, NULL);
634634
for(size_t i=ULAB_MAX_DIMS; i > 1; i--) {
635-
strides[i-2] = strides[i-1] * shape[i-1];
635+
strides[i-2] = strides[i-1] * MAX(1, shape[i-1]);
636636
}
637637
return ndarray_new_ndarray(ndim, shape, strides, dtype);
638638
}

code/numpy/numerical/numerical.c

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -63,38 +63,6 @@ static void numerical_reduce_axes(ndarray_obj_t *ndarray, int8_t axis, size_t *s
6363
}
6464
}
6565

66-
static shape_strides numerical_reduce_axes_(ndarray_obj_t *ndarray, mp_obj_t axis) {
67-
// TODO: replace numerical_reduce_axes with this function, wherever applicable
68-
int8_t ax = mp_obj_get_int(axis);
69-
if(ax < 0) ax += ndarray->ndim;
70-
if((ax < 0) || (ax > ndarray->ndim - 1)) {
71-
mp_raise_ValueError(translate("index out of range"));
72-
}
73-
shape_strides _shape_strides;
74-
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
75-
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
76-
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
77-
_shape_strides.shape = shape;
78-
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
79-
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
80-
_shape_strides.strides = strides;
81-
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
82-
_shape_strides.index = 0;
83-
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
84-
} else {
85-
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
86-
if(i > _shape_strides.index) {
87-
_shape_strides.shape[i] = ndarray->shape[i];
88-
_shape_strides.strides[i] = ndarray->strides[i];
89-
} else {
90-
_shape_strides.shape[i] = ndarray->shape[i-1];
91-
_shape_strides.strides[i] = ndarray->strides[i-1];
92-
}
93-
}
94-
}
95-
return _shape_strides;
96-
}
97-
9866
#if ULAB_NUMPY_HAS_ALL | ULAB_NUMPY_HAS_ANY
9967
static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
10068
bool anytype = optype == NUMERICAL_ALL ? 1 : 0;
@@ -148,7 +116,7 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
148116
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
149117
#endif
150118
} else {
151-
shape_strides _shape_strides = numerical_reduce_axes_(ndarray, axis);
119+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
152120
ndarray_obj_t *results = ndarray_new_dense_ndarray(MAX(1, ndarray->ndim-1), _shape_strides.shape, NDARRAY_BOOL);
153121
uint8_t *rarray = (uint8_t *)results->array;
154122
if(optype == NUMERICAL_ALL) {

code/numpy/numerical/numerical.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,6 @@
1717

1818
// TODO: implement cumsum
1919

20-
typedef struct {
21-
uint8_t index;
22-
int8_t axis;
23-
size_t *shape;
24-
int32_t *strides;
25-
} shape_strides;
26-
2720
#define RUN_ARGMIN1(ndarray, type, array, results, rarray, index, op)\
2821
({\
2922
uint16_t best_index = 0;\

code/ulab_tools.c

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
*/
1010

1111

12-
12+
#include <string.h>
1313
#include "py/runtime.h"
1414

1515
#include "ulab.h"
@@ -158,3 +158,35 @@ void *ndarray_set_float_function(uint8_t dtype) {
158158
}
159159
}
160160
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
161+
162+
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
163+
// TODO: replace numerical_reduce_axes with this function, wherever applicable
164+
int8_t ax = mp_obj_get_int(axis);
165+
if(ax < 0) ax += ndarray->ndim;
166+
if((ax < 0) || (ax > ndarray->ndim - 1)) {
167+
mp_raise_ValueError(translate("index out of range"));
168+
}
169+
shape_strides _shape_strides;
170+
_shape_strides.index = ULAB_MAX_DIMS - ndarray->ndim + ax;
171+
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
172+
memset(shape, 0, sizeof(size_t)*ULAB_MAX_DIMS);
173+
_shape_strides.shape = shape;
174+
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
175+
memset(strides, 0, sizeof(uint32_t)*ULAB_MAX_DIMS);
176+
_shape_strides.strides = strides;
177+
if((ndarray->ndim == 1) && (_shape_strides.axis == 0)) {
178+
_shape_strides.index = 0;
179+
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
180+
} else {
181+
for(uint8_t i = ULAB_MAX_DIMS - 1; i > 0; i--) {
182+
if(i > _shape_strides.index) {
183+
_shape_strides.shape[i] = ndarray->shape[i];
184+
_shape_strides.strides[i] = ndarray->strides[i];
185+
} else {
186+
_shape_strides.shape[i] = ndarray->shape[i-1];
187+
_shape_strides.strides[i] = ndarray->strides[i-1];
188+
}
189+
}
190+
}
191+
return _shape_strides;
192+
}

code/ulab_tools.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,17 @@
1111
#ifndef _TOOLS_
1212
#define _TOOLS_
1313

14+
#include "ndarray.h"
15+
1416
#define SWAP(t, a, b) { t tmp = a; a = b; b = tmp; }
1517

18+
typedef struct _shape_strides_t {
19+
uint8_t index;
20+
int8_t axis;
21+
size_t *shape;
22+
int32_t *strides;
23+
} shape_strides;
24+
1625
mp_float_t ndarray_get_float_uint8(void *);
1726
mp_float_t ndarray_get_float_int8(void *);
1827
mp_float_t ndarray_get_float_uint16(void *);
@@ -23,4 +32,5 @@ void *ndarray_get_float_function(uint8_t );
2332
uint8_t ndarray_upcast_dtype(uint8_t , uint8_t );
2433
void *ndarray_set_float_function(uint8_t );
2534

35+
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
2636
#endif

0 commit comments

Comments
 (0)