Skip to content

Commit e329206

Browse files
authored
Fix (u)int8 upcasting as per docs and numpy (#650)
* fix wrong #if guard in ndarray_inplace_ams * implement (u)int8 upcasting rules as per documentation * bump version
1 parent 4bde4ef commit e329206

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

code/ndarray_operators.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
181181

182182
if(lhs->dtype == NDARRAY_UINT8) {
183183
if(rhs->dtype == NDARRAY_UINT8) {
184-
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
185-
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
184+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
185+
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
186186
} else if(rhs->dtype == NDARRAY_INT8) {
187187
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
188188
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, +);
@@ -264,8 +264,8 @@ mp_obj_t ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,
264264

265265
if(lhs->dtype == NDARRAY_UINT8) {
266266
if(rhs->dtype == NDARRAY_UINT8) {
267-
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
268-
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
267+
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
268+
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
269269
} else if(rhs->dtype == NDARRAY_INT8) {
270270
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
271271
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, *);
@@ -1059,7 +1059,7 @@ mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rs
10591059
UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, +=);
10601060
}
10611061
#endif
1062-
#if NDARRAY_HAS_INPLACE_ADD
1062+
#if NDARRAY_HAS_INPLACE_MULTIPLY
10631063
if(optype == MP_BINARY_OP_INPLACE_MULTIPLY) {
10641064
UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, *=);
10651065
}

code/ulab.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.4.1
36+
#define ULAB_VERSION 6.4.2
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

docs/ulab-change-log.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
Thu, 11 Dec 2023
2+
3+
version 6.4.2
4+
5+
fix upcasting with two uint8 operands (#650)
6+
17
Thu, 10 Aug 2023
28

39
version 6.4.1

tests/2d/numpy/operators.py.exp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ array([1.0, 2.0, 3.0], dtype=float64)
9494
array([1.0, 32.0, 729.0], dtype=float64)
9595
array([1.0, 32.0, 729.0], dtype=float64)
9696
array([1.0, 32.0, 729.0], dtype=float64)
97-
array([5, 7, 9], dtype=uint16)
97+
array([5, 7, 9], dtype=uint8)
9898
array([5, 7, 9], dtype=int16)
9999
array([5, 7, 9], dtype=int8)
100100
array([5, 7, 9], dtype=uint16)

0 commit comments

Comments
 (0)