Skip to content

Commit acfec3e

Browse files
authored
fix reshape (#660)
1 parent 1c37edb commit acfec3e

File tree

4 files changed

+58
-14
lines changed

4 files changed

+58
-14
lines changed

code/ndarray.c

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -558,13 +558,9 @@ ndarray_obj_t *ndarray_new_dense_ndarray(uint8_t ndim, size_t *shape, uint8_t dt
558558
ndarray_obj_t *ndarray_new_ndarray_from_tuple(mp_obj_tuple_t *_shape, uint8_t dtype) {
559559
// creates a dense array from a tuple
560560
// the function should work in the general n-dimensional case
561-
size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
562-
for(size_t i = 0; i < ULAB_MAX_DIMS; i++) {
563-
if(i >= _shape->len) {
564-
shape[ULAB_MAX_DIMS - 1 - i] = 0;
565-
} else {
566-
shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[i]);
567-
}
561+
size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
562+
for(size_t i = 0; i < _shape->len; i++) {
563+
shape[ULAB_MAX_DIMS - 1 - i] = mp_obj_get_int(_shape->items[_shape->len - 1 - i]);
568564
}
569565
return ndarray_new_dense_ndarray(_shape->len, shape, dtype);
570566
}
@@ -2021,7 +2017,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
20212017
mp_obj_t *items = m_new(mp_obj_t, 1);
20222018
items[0] = _shape;
20232019
shape = mp_obj_new_tuple(1, items);
2024-
} else {
2020+
} else { // at this point it's certain that _shape is a tuple
20252021
shape = MP_OBJ_TO_PTR(_shape);
20262022
}
20272023

@@ -2072,11 +2068,7 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
20722068
if(inplace) {
20732069
mp_raise_ValueError(MP_ERROR_TEXT("cannot assign new shape"));
20742070
}
2075-
if(mp_obj_is_type(_shape, &mp_type_tuple)) {
2076-
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
2077-
} else {
2078-
ndarray = ndarray_new_linear_array(source->len, source->dtype);
2079-
}
2071+
ndarray = ndarray_new_dense_ndarray(shape->len, new_shape, source->dtype);
20802072
ndarray_copy_array(source, ndarray, 0);
20812073
}
20822074
return MP_OBJ_FROM_PTR(ndarray);

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.5.0
36+
#define ULAB_VERSION 6.5.1
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

tests/2d/numpy/reshape.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
try:
2+
from ulab import numpy as np
3+
except ImportError:
4+
import numpy as np
5+
6+
dtypes = (np.uint8, np.int8, np.uint16, np.int16, np.float)
7+
8+
for dtype in dtypes:
9+
print()
10+
print('=' * 50)
11+
a = np.array(range(12), dtype=dtype).reshape((3, 4))
12+
print(a)
13+
b = a[0,:]
14+
print(b.reshape((1,4)))
15+
b = a[:,0]
16+
print(b.reshape((1,3)))
17+

tests/2d/numpy/reshape.py.exp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
==================================================
3+
array([[0, 1, 2, 3],
4+
[4, 5, 6, 7],
5+
[8, 9, 10, 11]], dtype=uint8)
6+
array([[0, 1, 2, 3]], dtype=uint8)
7+
array([[0, 4, 8]], dtype=uint8)
8+
9+
==================================================
10+
array([[0, 1, 2, 3],
11+
[4, 5, 6, 7],
12+
[8, 9, 10, 11]], dtype=int8)
13+
array([[0, 1, 2, 3]], dtype=int8)
14+
array([[0, 4, 8]], dtype=int8)
15+
16+
==================================================
17+
array([[0, 1, 2, 3],
18+
[4, 5, 6, 7],
19+
[8, 9, 10, 11]], dtype=uint16)
20+
array([[0, 1, 2, 3]], dtype=uint16)
21+
array([[0, 4, 8]], dtype=uint16)
22+
23+
==================================================
24+
array([[0, 1, 2, 3],
25+
[4, 5, 6, 7],
26+
[8, 9, 10, 11]], dtype=int16)
27+
array([[0, 1, 2, 3]], dtype=int16)
28+
array([[0, 4, 8]], dtype=int16)
29+
30+
==================================================
31+
array([[0.0, 1.0, 2.0, 3.0],
32+
[4.0, 5.0, 6.0, 7.0],
33+
[8.0, 9.0, 10.0, 11.0]], dtype=float64)
34+
array([[0.0, 1.0, 2.0, 3.0]], dtype=float64)
35+
array([[0.0, 4.0, 8.0]], dtype=float64)

0 commit comments

Comments
 (0)