Skip to content

Commit 6867951

Browse files
committed
fixed indexing glitch in tools_reduce_axes
1 parent cacb1b6 commit 6867951

File tree

3 files changed

+28
-28
lines changed

3 files changed

+28
-28
lines changed

code/numpy/linalg/linalg.c

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -419,30 +419,30 @@ static mp_obj_t linalg_norm(size_t n_args, const mp_obj_t *pos_args, mp_map_t *k
419419
do {
420420
value = func(array);
421421
dot = dot + (value * value - dot) / count++;
422-
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
422+
array += _shape_strides.strides[0];
423423
l++;
424-
} while(l < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
424+
} while(l < _shape_strides.shape[0]);
425425
*rarray = MICROPY_FLOAT_C_FUN(sqrt)(dot * (count - 1));
426426
if(results != NULL) {
427427
rarray++;
428428
}
429429
#if ULAB_MAX_DIMS > 1
430-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
431-
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
430+
array -= _shape_strides.strides[0] * _shape_strides.shape[0];
431+
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
432432
k++;
433-
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
433+
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
434434
#endif
435435
#if ULAB_MAX_DIMS > 2
436-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
437-
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
436+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
437+
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
438438
j++;
439-
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 3]);
439+
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
440440
#endif
441441
#if ULAB_MAX_DIMS > 3
442-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 3] * _shape_strides.shape[ULAB_MAX_DIMS-3];
443-
array += _shape_strides.strides[ULAB_MAX_DIMS - 4];
442+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
443+
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
444444
i++;
445-
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 4]);
445+
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3]);
446446
#endif
447447
if(results == NULL) {
448448
return mp_obj_new_float(*rarray);

code/numpy/numerical/numerical.c

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,19 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
9898
l++;
9999
} while(l < ndarray->shape[ULAB_MAX_DIMS - 1]);
100100
#if ULAB_MAX_DIMS > 1
101-
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS-1];
101+
array -= ndarray->strides[ULAB_MAX_DIMS - 1] * ndarray->shape[ULAB_MAX_DIMS - 1];
102102
array += ndarray->strides[ULAB_MAX_DIMS - 2];
103103
k++;
104104
} while(k < ndarray->shape[ULAB_MAX_DIMS - 2]);
105105
#endif
106106
#if ULAB_MAX_DIMS > 2
107-
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS-2];
107+
array -= ndarray->strides[ULAB_MAX_DIMS - 2] * ndarray->shape[ULAB_MAX_DIMS - 2];
108108
array += ndarray->strides[ULAB_MAX_DIMS - 3];
109109
j++;
110110
} while(j < ndarray->shape[ULAB_MAX_DIMS - 3]);
111111
#endif
112112
#if ULAB_MAX_DIMS > 3
113-
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS-3];
113+
array -= ndarray->strides[ULAB_MAX_DIMS - 3] * ndarray->shape[ULAB_MAX_DIMS - 3];
114114
array += ndarray->strides[ULAB_MAX_DIMS - 4];
115115
i++;
116116
} while(i < ndarray->shape[ULAB_MAX_DIMS - 4]);
@@ -141,33 +141,33 @@ static mp_obj_t numerical_all_any(mp_obj_t oin, mp_obj_t axis, uint8_t optype) {
141141
// optype == NUMERICAL_ANY
142142
*rarray = 1;
143143
// since we are breaking out of the loop, move the pointer forward
144-
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
144+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
145145
break;
146146
} else if((value == MICROPY_FLOAT_CONST(0.0)) & anytype) {
147147
// optype == NUMERICAL_ALL
148148
*rarray = 0;
149149
// since we are breaking out of the loop, move the pointer forward
150-
array += ndarray->strides[_shape_strides.index] * (ndarray->shape[_shape_strides.index] - l);
150+
array += _shape_strides.strides[0] * (_shape_strides.shape[0] - l);
151151
break;
152152
}
153-
array += ndarray->strides[_shape_strides.index];
153+
array += _shape_strides.strides[0];
154154
l++;
155-
} while(l < ndarray->shape[_shape_strides.index]);
155+
} while(l < _shape_strides.shape[0]);
156156
#if ULAB_MAX_DIMS > 1
157157
rarray++;
158-
array -= ndarray->strides[_shape_strides.index] * ndarray->shape[_shape_strides.index];
158+
array -= _shape_strides.strides[0] * _shape_strides.shape[0];
159159
array += _shape_strides.strides[ULAB_MAX_DIMS - 1];
160160
k++;
161161
} while(k < _shape_strides.shape[ULAB_MAX_DIMS - 1]);
162162
#endif
163163
#if ULAB_MAX_DIMS > 2
164-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS-1];
164+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 1] * _shape_strides.shape[ULAB_MAX_DIMS - 1];
165165
array += _shape_strides.strides[ULAB_MAX_DIMS - 2];
166166
j++;
167167
} while(j < _shape_strides.shape[ULAB_MAX_DIMS - 2]);
168168
#endif
169169
#if ULAB_MAX_DIMS > 3
170-
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS-2];
170+
array -= _shape_strides.strides[ULAB_MAX_DIMS - 2] * _shape_strides.shape[ULAB_MAX_DIMS - 2];
171171
array += _shape_strides.strides[ULAB_MAX_DIMS - 3];
172172
i++;
173173
} while(i < _shape_strides.shape[ULAB_MAX_DIMS - 3])

code/ulab_tools.c

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,19 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
174174
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
175175
// for axis == mp_const_none, simply return the original shape and strides
176176
if(axis != mp_const_none) {
177-
// move the axis to the rightmost position, and align everything else to the right
178177
int8_t ax = mp_obj_get_int(axis);
179178
if(ax < 0) ax += ndarray->ndim;
180179
if((ax < 0) || (ax > ndarray->ndim - 1)) {
181180
mp_raise_ValueError(translate("index out of range"));
182181
}
182+
// move the axis to the leftmost position, and align everything else to the right
183183
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
184-
_shape_strides.shape[ULAB_MAX_DIMS - 1] = ndarray->shape[index];
185-
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->strides[index];
186-
for(uint8_t i = index; i < ULAB_MAX_DIMS - 1; i++) {
187-
// entries to the right of index must be shifted to the left
188-
_shape_strides.shape[i] = ndarray->shape[i+1];
189-
_shape_strides.strides[i] = ndarray->strides[i+1];
184+
_shape_strides.shape[0] = ndarray->shape[index];
185+
_shape_strides.strides[0] = ndarray->strides[index];
186+
for(uint8_t i = 0; i < index; i++) {
187+
// entries to the left of index must be shifted to the right
188+
_shape_strides.shape[i + 1] = ndarray->shape[i];
189+
_shape_strides.strides[i + 1] = ndarray->strides[i];
190190
}
191191
}
192192
return _shape_strides;

0 commit comments

Comments
 (0)