Skip to content

Commit beda4c1

Browse files
authored
implement unknown shape dimension in reshape (#612)
1 parent 412b13f commit beda4c1

File tree

3 files changed

+47
-10
lines changed

3 files changed

+47
-10
lines changed

code/ndarray.c

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2084,24 +2084,51 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
20842084
#if NDARRAY_HAS_RESHAPE
20852085
mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
20862086
ndarray_obj_t *source = MP_OBJ_TO_PTR(oin);
2087-
if(!mp_obj_is_type(_shape, &mp_type_tuple)) {
2088-
mp_raise_TypeError(translate("shape must be a tuple"));
2087+
if(!mp_obj_is_type(_shape, &mp_type_tuple) && !mp_obj_is_int(_shape)) {
2088+
mp_raise_TypeError(translate("shape must be integer or tuple of integers"));
2089+
}
2090+
2091+
mp_obj_tuple_t *shape;
2092+
2093+
if(mp_obj_is_int(_shape)) {
2094+
mp_obj_t *items = m_new(mp_obj_t, 1);
2095+
items[0] = _shape;
2096+
shape = mp_obj_new_tuple(1, items);
2097+
} else {
2098+
shape = MP_OBJ_TO_PTR(_shape);
20892099
}
20902100

2091-
mp_obj_tuple_t *shape = MP_OBJ_TO_PTR(_shape);
20922101
if(shape->len > ULAB_MAX_DIMS) {
20932102
mp_raise_ValueError(translate("maximum number of dimensions is " MP_STRINGIFY(ULAB_MAX_DIMS)));
20942103
}
2095-
size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS);
20962104

20972105
size_t new_length = 1;
2098-
for(uint8_t i=0; i < shape->len; i++) {
2099-
new_shape[ULAB_MAX_DIMS - i - 1] = mp_obj_get_int(shape->items[shape->len - i - 1]);
2100-
new_length *= new_shape[ULAB_MAX_DIMS - i - 1];
2106+
size_t *new_shape = m_new0(size_t, ULAB_MAX_DIMS);
2107+
uint8_t unknown_dim = 0;
2108+
uint8_t unknown_index = 0;
2109+
2110+
for(uint8_t i = 0; i < shape->len; i++) {
2111+
int32_t ax_len = mp_obj_get_int(shape->items[shape->len - i - 1]);
2112+
if(ax_len >= 0) {
2113+
new_shape[ULAB_MAX_DIMS - i - 1] = (size_t)ax_len;
2114+
new_length *= new_shape[ULAB_MAX_DIMS - i - 1];
2115+
} else {
2116+
unknown_dim++;
2117+
unknown_index = ULAB_MAX_DIMS - i - 1;
2118+
}
2119+
}
2120+
2121+
if(unknown_dim > 1) {
2122+
mp_raise_ValueError(translate("can only specify one unknown dimension"));
2123+
} else if(unknown_dim == 1) {
2124+
new_shape[unknown_index] = source->len / new_length;
2125+
new_length = source->len;
21012126
}
2127+
21022128
if(source->len != new_length) {
2103-
mp_raise_ValueError(translate("input and output shapes are not compatible"));
2129+
mp_raise_ValueError(translate("cannot reshape array"));
21042130
}
2131+
21052132
ndarray_obj_t *ndarray;
21062133
if(ndarray_is_dense(source)) {
21072134
int32_t *new_strides = strides_from_shape(new_shape, source->dtype);
@@ -2118,7 +2145,11 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
21182145
if(inplace) {
21192146
mp_raise_ValueError(translate("cannot assign new shape"));
21202147
}
2121-
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
2148+
if(mp_obj_is_type(_shape, &mp_type_tuple)) {
2149+
ndarray = ndarray_new_ndarray_from_tuple(shape, source->dtype);
2150+
} else {
2151+
ndarray = ndarray_new_linear_array(source->len, source->dtype);
2152+
}
21222153
ndarray_copy_array(source, ndarray, 0);
21232154
}
21242155
return MP_OBJ_FROM_PTR(ndarray);

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.0.10
36+
#define ULAB_VERSION 6.0.11
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
Sat, 6 May 2023
22

3+
version 6.0.11
4+
5+
.reshape can now interpret unknown shape dimension
6+
7+
Sat, 6 May 2023
8+
39
version 6.0.10
410

511
fix binary division

0 commit comments

Comments
 (0)