Skip to content

Commit b5251e7

Browse files
authored
Merge pull request numpy#21499 from mattip/avoid-0dim-reduce
ENH: avoid looping when dimensions[0] == 0 or array.size == 0
2 parents 02d1204 + 7328dba commit b5251e7

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

numpy/core/src/multiarray/array_method.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -863,15 +863,17 @@ generic_masked_strided_loop(PyArrayMethod_Context *context,
863863

864864
/* Process unmasked values */
865865
mask = npy_memchr(mask, 0, mask_stride, N, &subloopsize, 0);
866-
int res = strided_loop(context,
867-
dataptrs, &subloopsize, strides, strided_loop_auxdata);
868-
if (res != 0) {
869-
return res;
870-
}
871-
for (int i = 0; i < nargs; i++) {
872-
dataptrs[i] += subloopsize * strides[i];
866+
if (subloopsize > 0) {
867+
int res = strided_loop(context,
868+
dataptrs, &subloopsize, strides, strided_loop_auxdata);
869+
if (res != 0) {
870+
return res;
871+
}
872+
for (int i = 0; i < nargs; i++) {
873+
dataptrs[i] += subloopsize * strides[i];
874+
}
875+
N -= subloopsize;
873876
}
874-
N -= subloopsize;
875877
} while (N > 0);
876878

877879
return 0;

numpy/core/src/umath/fast_loop_macros.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#ifndef _NPY_UMATH_FAST_LOOP_MACROS_H_
1111
#define _NPY_UMATH_FAST_LOOP_MACROS_H_
1212

13+
#include <assert.h>
14+
1315
/*
1416
* MAX_STEP_SIZE is used to determine if we need to use SIMD version of the ufunc.
1517
* Very large step size can be as slow as processing it using scalar. The
@@ -99,12 +101,19 @@ abs_ptrdiff(char *a, char *b)
99101

100102
#define IS_OUTPUT_CONT(tout) (steps[1] == sizeof(tout))
101103

102-
#define IS_BINARY_REDUCE ((args[0] == args[2])\
104+
/*
105+
* Make sure dimensions is non-zero with an assert, to allow subsequent code
106+
* to ignore problems of accessing invalid memory
107+
*/
108+
109+
#define IS_BINARY_REDUCE (assert(dimensions[0] != 0), \
110+
(args[0] == args[2])\
103111
&& (steps[0] == steps[2])\
104112
&& (steps[0] == 0))
105113

106114
/* input contiguous (for binary reduces only) */
107-
#define IS_BINARY_REDUCE_INPUT_CONT(tin) (steps[1] == sizeof(tin))
115+
#define IS_BINARY_REDUCE_INPUT_CONT(tin) (assert(dimensions[0] != 0), \
116+
steps[1] == sizeof(tin))
108117

109118
/* binary loop input and output contiguous */
110119
#define IS_BINARY_CONT(tin, tout) (steps[0] == sizeof(tin) && \

numpy/core/src/umath/ufunc_object.c

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,10 @@ try_trivial_single_output_loop(PyArrayMethod_Context *context,
13211321
*/
13221322
char *data[NPY_MAXARGS];
13231323
npy_intp count = PyArray_MultiplyList(operation_shape, operation_ndim);
1324+
if (count == 0) {
1325+
/* Nothing to do */
1326+
return 0;
1327+
}
13241328
NPY_BEGIN_THREADS_DEF;
13251329

13261330
PyArrayMethod_StridedLoop *strided_loop;
@@ -2819,7 +2823,7 @@ reduce_loop(PyArrayMethod_Context *context,
28192823
npy_intp const *countptr, NpyIter_IterNextFunc *iternext,
28202824
int needs_api, npy_intp skip_first_count)
28212825
{
2822-
int retval;
2826+
int retval = 0;
28232827
char *dataptrs_copy[4];
28242828
npy_intp strides_copy[4];
28252829
npy_bool masked;
@@ -2849,19 +2853,20 @@ reduce_loop(PyArrayMethod_Context *context,
28492853
count = 0;
28502854
}
28512855
}
2852-
2853-
/* Turn the two items into three for the inner loop */
2854-
dataptrs_copy[0] = dataptrs[0];
2855-
dataptrs_copy[1] = dataptrs[1];
2856-
dataptrs_copy[2] = dataptrs[0];
2857-
strides_copy[0] = strides[0];
2858-
strides_copy[1] = strides[1];
2859-
strides_copy[2] = strides[0];
2860-
2861-
retval = strided_loop(context,
2862-
dataptrs_copy, &count, strides_copy, auxdata);
2863-
if (retval < 0) {
2864-
goto finish_loop;
2856+
if (count > 0) {
2857+
/* Turn the two items into three for the inner loop */
2858+
dataptrs_copy[0] = dataptrs[0];
2859+
dataptrs_copy[1] = dataptrs[1];
2860+
dataptrs_copy[2] = dataptrs[0];
2861+
strides_copy[0] = strides[0];
2862+
strides_copy[1] = strides[1];
2863+
strides_copy[2] = strides[0];
2864+
2865+
retval = strided_loop(context,
2866+
dataptrs_copy, &count, strides_copy, auxdata);
2867+
if (retval < 0) {
2868+
goto finish_loop;
2869+
}
28652870
}
28662871

28672872
/* Advance loop, and abort on error (or finish) */

0 commit comments

Comments
 (0)