Skip to content

Commit e68bb70

Browse files
authored
fix vectorize (#568)
1 parent 42172c6 commit e68bb70

File tree

5 files changed

+109
-6
lines changed

5 files changed

+109
-6
lines changed

code/numpy/vector.c

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -759,12 +759,51 @@ static mp_obj_t vector_vectorized_function_call(mp_obj_t self_in, size_t n_args,
759759
if(mp_obj_is_type(args[0], &ulab_ndarray_type)) {
760760
ndarray_obj_t *source = MP_OBJ_TO_PTR(args[0]);
761761
COMPLEX_DTYPE_NOT_IMPLEMENTED(source->dtype)
762+
762763
ndarray_obj_t *ndarray = ndarray_new_dense_ndarray(source->ndim, source->shape, self->otypes);
763-
for(size_t i=0; i < source->len; i++) {
764-
avalue[0] = mp_binary_get_val_array(source->dtype, source->array, i);
765-
fvalue = MP_OBJ_TYPE_GET_SLOT(self->type, call)(self->fun, 1, 0, avalue);
766-
ndarray_set_value(self->otypes, ndarray->array, i, fvalue);
767-
}
764+
uint8_t *sarray = (uint8_t *)source->array;
765+
uint8_t *narray = (uint8_t *)ndarray->array;
766+
767+
#if ULAB_MAX_DIMS > 3
768+
size_t i = 0;
769+
do {
770+
#endif
771+
#if ULAB_MAX_DIMS > 2
772+
size_t j = 0;
773+
do {
774+
#endif
775+
#if ULAB_MAX_DIMS > 1
776+
size_t k = 0;
777+
do {
778+
#endif
779+
size_t l = 0;
780+
do {
781+
avalue[0] = mp_binary_get_val_array(source->dtype, sarray, 0);
782+
fvalue = MP_OBJ_TYPE_GET_SLOT(self->type, call)(self->fun, 1, 0, avalue);
783+
ndarray_set_value(self->otypes, narray, 0, fvalue);
784+
sarray += source->strides[ULAB_MAX_DIMS - 1];
785+
narray += ndarray->itemsize;
786+
l++;
787+
} while(l < source->shape[ULAB_MAX_DIMS - 1]);
788+
#if ULAB_MAX_DIMS > 1
789+
sarray -= source->strides[ULAB_MAX_DIMS - 1] * source->shape[ULAB_MAX_DIMS - 1];
790+
sarray += source->strides[ULAB_MAX_DIMS - 2];
791+
k++;
792+
} while(k < source->shape[ULAB_MAX_DIMS - 2]);
793+
#endif /* ULAB_MAX_DIMS > 1 */
794+
#if ULAB_MAX_DIMS > 2
795+
sarray -= source->strides[ULAB_MAX_DIMS - 2] * source->shape[ULAB_MAX_DIMS - 2];
796+
sarray += source->strides[ULAB_MAX_DIMS - 3];
797+
j++;
798+
} while(j < source->shape[ULAB_MAX_DIMS - 3]);
799+
#endif /* ULAB_MAX_DIMS > 2 */
800+
#if ULAB_MAX_DIMS > 3
801+
sarray -= source->strides[ULAB_MAX_DIMS - 3] * source->shape[ULAB_MAX_DIMS - 3];
802+
sarray += source->strides[ULAB_MAX_DIMS - 4];
803+
i++;
804+
} while(i < source->shape[ULAB_MAX_DIMS - 4]);
805+
#endif /* ULAB_MAX_DIMS > 3 */
806+
768807
return MP_OBJ_FROM_PTR(ndarray);
769808
} else if(mp_obj_is_type(args[0], &mp_type_tuple) || mp_obj_is_type(args[0], &mp_type_list) ||
770809
mp_obj_is_type(args[0], &mp_type_range)) { // i.e., the input is a generic iterable

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.0.1
36+
#define ULAB_VERSION 6.0.2
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Tue, 3 Jan 2023
2+
3+
version 6.0.2
4+
5+
fix vectorize
6+
17
Sat, 5 Nov 2022
28

39
version 6.0.1

tests/2d/numpy/vectorize.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
try:
2+
from ulab import numpy as np
3+
except:
4+
import numpy as np
5+
6+
7+
dtypes = (np.uint8, np.int8, np.uint16, np.int16, np.float)
8+
9+
square = np.vectorize(lambda n: n*n)
10+
11+
for dtype in dtypes:
12+
a = np.array(range(9), dtype=dtype).reshape((3, 3))
13+
print(a)
14+
print(square(a))
15+
16+
b = a[:,2]
17+
print(square(b))
18+
print()

tests/2d/numpy/vectorize.py.exp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
array([[0, 1, 2],
2+
[3, 4, 5],
3+
[6, 7, 8]], dtype=uint8)
4+
array([[0.0, 1.0, 4.0],
5+
[9.0, 16.0, 25.0],
6+
[36.0, 49.0, 64.0]], dtype=float64)
7+
array([4.0, 25.0, 64.0], dtype=float64)
8+
9+
array([[0, 1, 2],
10+
[3, 4, 5],
11+
[6, 7, 8]], dtype=int8)
12+
array([[0.0, 1.0, 4.0],
13+
[9.0, 16.0, 25.0],
14+
[36.0, 49.0, 64.0]], dtype=float64)
15+
array([4.0, 25.0, 64.0], dtype=float64)
16+
17+
array([[0, 1, 2],
18+
[3, 4, 5],
19+
[6, 7, 8]], dtype=uint16)
20+
array([[0.0, 1.0, 4.0],
21+
[9.0, 16.0, 25.0],
22+
[36.0, 49.0, 64.0]], dtype=float64)
23+
array([4.0, 25.0, 64.0], dtype=float64)
24+
25+
array([[0, 1, 2],
26+
[3, 4, 5],
27+
[6, 7, 8]], dtype=int16)
28+
array([[0.0, 1.0, 4.0],
29+
[9.0, 16.0, 25.0],
30+
[36.0, 49.0, 64.0]], dtype=float64)
31+
array([4.0, 25.0, 64.0], dtype=float64)
32+
33+
array([[0.0, 1.0, 2.0],
34+
[3.0, 4.0, 5.0],
35+
[6.0, 7.0, 8.0]], dtype=float64)
36+
array([[0.0, 1.0, 4.0],
37+
[9.0, 16.0, 25.0],
38+
[36.0, 49.0, 64.0]], dtype=float64)
39+
array([4.0, 25.0, 64.0], dtype=float64)
40+

0 commit comments

Comments
 (0)