Skip to content

Commit 7124eaa

Browse files
authored
fix concatenate (#575)
1 parent 5be42be commit 7124eaa

File tree

5 files changed

+33
-1
lines changed

5 files changed

+33
-1
lines changed

code/numpy/create.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,14 @@ mp_obj_t create_concatenate(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
248248
size_t *shape = m_new0(size_t, ULAB_MAX_DIMS);
249249
mp_obj_tuple_t *ndarrays = MP_OBJ_TO_PTR(args[0].u_obj);
250250

251+
// first check, whether
252+
253+
for(uint8_t i = 0; i < ndarrays->len; i++) {
254+
if(!mp_obj_is_type(ndarrays->items[i], &ulab_ndarray_type)) {
255+
mp_raise_ValueError(translate("only ndarrays can be concatenated"));
256+
}
257+
}
258+
251259
// first check, whether the arrays are compatible
252260
ndarray_obj_t *_ndarray = MP_OBJ_TO_PTR(ndarrays->items[0]);
253261
uint8_t dtype = _ndarray->dtype;

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.0.2
36+
#define ULAB_VERSION 6.0.3
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Sat, 14 Jan 2023
2+
3+
version 6.0.3
4+
5+
fix how concatenate deals with scalar inputs
6+
17
Tue, 3 Jan 2023
28

39
version 6.0.2

tests/2d/numpy/concatenate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,18 @@
33
except:
44
import numpy as np
55

6+
# test input types; the following should raise ValueErrors
7+
objects = [([1, 2], [3, 4]),
8+
((1, 2), (3, 4)),
9+
(1, 2, 3)]
10+
11+
for obj in objects:
12+
try:
13+
np.concatenate(obj)
14+
except ValueError as e:
15+
print('ValueError: {}; failed with object {}\n'.format(e, obj))
16+
17+
618
a = np.array([1,2,3], dtype=np.float)
719
b = np.array([4,5,6], dtype=np.float)
820

tests/2d/numpy/concatenate.py.exp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
ValueError: only ndarrays can be concatenated; failed with object ([1, 2], [3, 4])
2+
3+
ValueError: only ndarrays can be concatenated; failed with object ((1, 2), (3, 4))
4+
5+
ValueError: only ndarrays can be concatenated; failed with object (1, 2, 3)
6+
17
array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=float64)
28
array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=float64)
39
array([[1.0, 2.0, 3.0],

0 commit comments

Comments
 (0)