Skip to content

Commit 7aeb73a

Browse files
committed
ndarray: Fix memoryview(ulab.array(...))
For now this only handles the 1D case. In theory it would work for any dense array, however, I found that ndarray_is_dense didn't behave for me so I implemented this instead. Add a test. Before the change, this test would segfault. Closes #328.
1 parent 743d864 commit 7aeb73a

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

code/ndarray.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,8 +1986,15 @@ mp_obj_t ndarray_info(mp_obj_t obj_in) {
19861986
MP_DEFINE_CONST_FUN_OBJ_1(ndarray_info_obj, ndarray_info);
19871987
#endif
19881988

1989+
// (the get_buffer protocol returns 0 for success, 1 for failure)
19891990
mp_int_t ndarray_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) {
19901991
ndarray_obj_t *self = MP_OBJ_TO_PTR(self_in);
1991-
// buffer_p.get_buffer() returns zero for success, while mp_get_buffer returns true for success
1992-
return !mp_get_buffer(self->array, bufinfo, flags);
1992+
if (self->ndim != 1 || self->strides[0] > 1) {
1993+
// For now, only allow fetching buffer of a 1d-array
1994+
return 1;
1995+
}
1996+
bufinfo->len = self->itemsize * self->len;
1997+
bufinfo->buf = self->array;
1998+
bufinfo->typecode = self->dtype;
1999+
return 0;
19932000
}

tests/common/buffer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
try:
2+
import ulab as np
3+
except:
4+
import numpy as np
5+
6+
def print_as_buffer(a):
7+
print(len(memoryview(a)), list(memoryview(a)))
8+
print_as_buffer(np.ones(3))
9+
print_as_buffer(np.zeros(3))
10+
print_as_buffer(np.ones(1, dtype=np.int8))
11+
print_as_buffer(np.ones(2, dtype=np.uint8))
12+
print_as_buffer(np.ones(3, dtype=np.int16))
13+
print_as_buffer(np.ones(4, dtype=np.uint16))
14+
print_as_buffer(np.ones(5, dtype=np.float))
15+
print_as_buffer(np.linspace(0, 1, 9))
16+

tests/common/buffer.py.exp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
3 [1.0, 1.0, 1.0]
2+
3 [0.0, 0.0, 0.0]
3+
1 [1]
4+
2 [1, 1]
5+
3 [1, 1, 1]
6+
4 [1, 1, 1, 1]
7+
5 [1.0, 1.0, 1.0, 1.0, 1.0]
8+
9 [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

0 commit comments

Comments
 (0)