@@ -2084,24 +2084,51 @@ MP_DEFINE_CONST_FUN_OBJ_1(ndarray_transpose_obj, ndarray_transpose);
2084
2084
#if NDARRAY_HAS_RESHAPE
2085
2085
mp_obj_t ndarray_reshape_core (mp_obj_t oin , mp_obj_t _shape , bool inplace ) {
2086
2086
ndarray_obj_t * source = MP_OBJ_TO_PTR (oin );
2087
- if (!mp_obj_is_type (_shape , & mp_type_tuple )) {
2088
- mp_raise_TypeError (translate ("shape must be a tuple" ));
2087
+ if (!mp_obj_is_type (_shape , & mp_type_tuple ) && !mp_obj_is_int (_shape )) {
2088
+ mp_raise_TypeError (translate ("shape must be integer or tuple of integers" ));
2089
+ }
2090
+
2091
+ mp_obj_tuple_t * shape ;
2092
+
2093
+ if (mp_obj_is_int (_shape )) {
2094
+ mp_obj_t * items = m_new (mp_obj_t , 1 );
2095
+ items [0 ] = _shape ;
2096
+ shape = mp_obj_new_tuple (1 , items );
2097
+ } else {
2098
+ shape = MP_OBJ_TO_PTR (_shape );
2089
2099
}
2090
2100
2091
- mp_obj_tuple_t * shape = MP_OBJ_TO_PTR (_shape );
2092
2101
if (shape -> len > ULAB_MAX_DIMS ) {
2093
2102
mp_raise_ValueError (translate ("maximum number of dimensions is " MP_STRINGIFY (ULAB_MAX_DIMS )));
2094
2103
}
2095
- size_t * new_shape = m_new0 (size_t , ULAB_MAX_DIMS );
2096
2104
2097
2105
size_t new_length = 1 ;
2098
- for (uint8_t i = 0 ; i < shape -> len ; i ++ ) {
2099
- new_shape [ULAB_MAX_DIMS - i - 1 ] = mp_obj_get_int (shape -> items [shape -> len - i - 1 ]);
2100
- new_length *= new_shape [ULAB_MAX_DIMS - i - 1 ];
2106
+ size_t * new_shape = m_new0 (size_t , ULAB_MAX_DIMS );
2107
+ uint8_t unknown_dim = 0 ;
2108
+ uint8_t unknown_index = 0 ;
2109
+
2110
+ for (uint8_t i = 0 ; i < shape -> len ; i ++ ) {
2111
+ int32_t ax_len = mp_obj_get_int (shape -> items [shape -> len - i - 1 ]);
2112
+ if (ax_len >= 0 ) {
2113
+ new_shape [ULAB_MAX_DIMS - i - 1 ] = (size_t )ax_len ;
2114
+ new_length *= new_shape [ULAB_MAX_DIMS - i - 1 ];
2115
+ } else {
2116
+ unknown_dim ++ ;
2117
+ unknown_index = ULAB_MAX_DIMS - i - 1 ;
2118
+ }
2119
+ }
2120
+
2121
+ if (unknown_dim > 1 ) {
2122
+ mp_raise_ValueError (translate ("can only specify one unknown dimension" ));
2123
+ } else if (unknown_dim == 1 ) {
2124
+ new_shape [unknown_index ] = source -> len / new_length ;
2125
+ new_length = source -> len ;
2101
2126
}
2127
+
2102
2128
if (source -> len != new_length ) {
2103
- mp_raise_ValueError (translate ("input and output shapes are not compatible " ));
2129
+ mp_raise_ValueError (translate ("cannot reshape array " ));
2104
2130
}
2131
+
2105
2132
ndarray_obj_t * ndarray ;
2106
2133
if (ndarray_is_dense (source )) {
2107
2134
int32_t * new_strides = strides_from_shape (new_shape , source -> dtype );
@@ -2118,7 +2145,11 @@ mp_obj_t ndarray_reshape_core(mp_obj_t oin, mp_obj_t _shape, bool inplace) {
2118
2145
if (inplace ) {
2119
2146
mp_raise_ValueError (translate ("cannot assign new shape" ));
2120
2147
}
2121
- ndarray = ndarray_new_ndarray_from_tuple (shape , source -> dtype );
2148
+ if (mp_obj_is_type (_shape , & mp_type_tuple )) {
2149
+ ndarray = ndarray_new_ndarray_from_tuple (shape , source -> dtype );
2150
+ } else {
2151
+ ndarray = ndarray_new_linear_array (source -> len , source -> dtype );
2152
+ }
2122
2153
ndarray_copy_array (source , ndarray , 0 );
2123
2154
}
2124
2155
return MP_OBJ_FROM_PTR (ndarray );
0 commit comments