Skip to content

Commit 52162af

Browse files
authored
MAINT: simplify power fast path logic (numpy#27901)
* MAINT: remove fast paths from array power * MAINT: Add fast paths to power loops * MAINT: Clean loops for integer power in umath * MAINT: Remove blocking regression test for power fast paths * MAINT: Add helper function for power fast paths * BUG: Change misspelled bitwise and to logical and * BUG: Fix missing value on power helper return * BUG: Fix exponent bitwise logic in power fast paths * MAINT: Add power fast paths to floating point umath * MAINT: Add fast power paths to array power when exponent is python object * MAINT: Fix division by zero runtime warning in test regression * MAINT: Adapt object regression test for linalg to power fast paths * MAINT: Remove incorrect declarations in power fast paths * MAINT: Reduce calls to power fast path helper when scalar is ineligible * MAINT: Fix missing sliding loop * BUG: Fix syntax error * MAINT: Fix semantic misuse of -1 for non-error returns * MAINT: Improve error checking in power fast paths to remove PyErr_Clear * MAINT: Improve type checking in power fast paths * MAINT: Efficient handling of ones arrays in scalar fast paths * MAINT: Simplify outer check for scalar power fast paths * MAINT: Reduce code reuse in float power fast paths and add reciprocal * MAINT: Remove Python scalar checking for fast power paths * MAINT: Add benchmarks for power operators in float binary bench * MAINT: Add scalar power fast paths * BUG: Add missing pointer cast * BUG: Allow scalar power fast paths only for non-integers * MAINT: Restore outdated changes in regression test to master
1 parent 1d77082 commit 52162af

File tree

5 files changed

+105
-192
lines changed

5 files changed

+105
-192
lines changed

benchmarks/benchmarks/bench_ufunc.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,12 @@ def time_pow_2(self, dtype):
588588
def time_pow_half(self, dtype):
589589
np.power(self.a, 0.5)
590590

591+
def time_pow_2_op(self, dtype):
592+
self.a ** 2
593+
594+
def time_pow_half_op(self, dtype):
595+
self.a ** 0.5
596+
591597
def time_atan2(self, dtype):
592598
np.arctan2(self.a, self.b)
593599

numpy/_core/src/multiarray/number.c

Lines changed: 36 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -328,165 +328,53 @@ array_inplace_matrix_multiply(PyArrayObject *self, PyObject *other)
328328
return res;
329329
}
330330

331-
/*
332-
* Determine if object is a scalar and if so, convert the object
333-
* to a double and place it in the out_exponent argument
334-
* and return the "scalar kind" as a result. If the object is
335-
* not a scalar (or if there are other error conditions)
336-
* return NPY_NOSCALAR, and out_exponent is undefined.
337-
*/
338-
static NPY_SCALARKIND
339-
is_scalar_with_conversion(PyObject *o2, double* out_exponent)
331+
static int
332+
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace, PyObject **result)
340333
{
341-
PyObject *temp;
342-
const int optimize_fpexps = 1;
343-
344-
if (PyLong_Check(o2)) {
345-
long tmp = PyLong_AsLong(o2);
346-
if (error_converting(tmp)) {
347-
PyErr_Clear();
348-
return NPY_NOSCALAR;
334+
PyObject *fastop = NULL;
335+
if (PyLong_CheckExact(o2)) {
336+
int overflow = 0;
337+
long exp = PyLong_AsLongAndOverflow(o2, &overflow);
338+
if (overflow != 0) {
339+
return -1;
349340
}
350-
*out_exponent = (double)tmp;
351-
return NPY_INTPOS_SCALAR;
352-
}
353341

354-
if (optimize_fpexps && PyFloat_Check(o2)) {
355-
*out_exponent = PyFloat_AsDouble(o2);
356-
return NPY_FLOAT_SCALAR;
357-
}
358-
359-
if (PyArray_Check(o2)) {
360-
if ((PyArray_NDIM((PyArrayObject *)o2) == 0) &&
361-
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
362-
(optimize_fpexps && PyArray_ISFLOAT((PyArrayObject *)o2))))) {
363-
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
364-
if (temp == NULL) {
365-
return NPY_NOSCALAR;
366-
}
367-
*out_exponent = PyFloat_AsDouble(o2);
368-
Py_DECREF(temp);
369-
if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
370-
return NPY_INTPOS_SCALAR;
371-
}
372-
else { /* ISFLOAT */
373-
return NPY_FLOAT_SCALAR;
374-
}
342+
if (exp == -1) {
343+
fastop = n_ops.reciprocal;
375344
}
376-
}
377-
else if (PyArray_IsScalar(o2, Integer) ||
378-
(optimize_fpexps && PyArray_IsScalar(o2, Floating))) {
379-
temp = Py_TYPE(o2)->tp_as_number->nb_float(o2);
380-
if (temp == NULL) {
381-
return NPY_NOSCALAR;
382-
}
383-
*out_exponent = PyFloat_AsDouble(o2);
384-
Py_DECREF(temp);
385-
386-
if (PyArray_IsScalar(o2, Integer)) {
387-
return NPY_INTPOS_SCALAR;
345+
else if (exp == 2) {
346+
fastop = n_ops.square;
388347
}
389-
else { /* IsScalar(o2, Floating) */
390-
return NPY_FLOAT_SCALAR;
348+
else {
349+
return 1;
391350
}
392351
}
393-
else if (PyIndex_Check(o2)) {
394-
PyObject* value = PyNumber_Index(o2);
395-
Py_ssize_t val;
396-
if (value == NULL) {
397-
if (PyErr_Occurred()) {
398-
PyErr_Clear();
399-
}
400-
return NPY_NOSCALAR;
352+
else if (PyFloat_CheckExact(o2)) {
353+
double exp = PyFloat_AsDouble(o2);
354+
if (exp == 0.5) {
355+
fastop = n_ops.sqrt;
401356
}
402-
val = PyLong_AsSsize_t(value);
403-
Py_DECREF(value);
404-
if (error_converting(val)) {
405-
PyErr_Clear();
406-
return NPY_NOSCALAR;
357+
else {
358+
return 1;
407359
}
408-
*out_exponent = (double) val;
409-
return NPY_INTPOS_SCALAR;
410360
}
411-
return NPY_NOSCALAR;
412-
}
361+
else {
362+
return 1;
363+
}
413364

414-
/*
415-
* optimize float array or complex array to a scalar power
416-
* returns 0 on success, -1 if no optimization is possible
417-
* the result is in value (can be NULL if an error occurred)
418-
*/
419-
static int
420-
fast_scalar_power(PyObject *o1, PyObject *o2, int inplace,
421-
PyObject **value)
422-
{
423-
double exponent;
424-
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
425-
426-
if (PyArray_Check(o1) &&
427-
!PyArray_ISOBJECT((PyArrayObject *)o1) &&
428-
((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
429-
PyArrayObject *a1 = (PyArrayObject *)o1;
430-
PyObject *fastop = NULL;
431-
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
432-
if (exponent == 1.0) {
433-
fastop = n_ops.positive;
434-
}
435-
else if (exponent == -1.0) {
436-
fastop = n_ops.reciprocal;
437-
}
438-
else if (exponent == 0.0) {
439-
fastop = n_ops._ones_like;
440-
}
441-
else if (exponent == 0.5) {
442-
fastop = n_ops.sqrt;
443-
}
444-
else if (exponent == 2.0) {
445-
fastop = n_ops.square;
446-
}
447-
else {
448-
return -1;
449-
}
365+
PyArrayObject *a1 = (PyArrayObject *)o1;
366+
if (!(PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1))) {
367+
return 1;
368+
}
450369

451-
if (inplace || can_elide_temp_unary(a1)) {
452-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
453-
}
454-
else {
455-
*value = PyArray_GenericUnaryFunction(a1, fastop);
456-
}
457-
return 0;
458-
}
459-
/* Because this is called with all arrays, we need to
460-
* change the output if the kind of the scalar is different
461-
* than that of the input and inplace is not on ---
462-
* (thus, the input should be up-cast)
463-
*/
464-
else if (exponent == 2.0) {
465-
fastop = n_ops.square;
466-
if (inplace) {
467-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
468-
}
469-
else {
470-
/* We only special-case the FLOAT_SCALAR and integer types */
471-
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
472-
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
473-
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
474-
PyArray_ISFORTRAN(a1));
475-
if (a1 != NULL) {
476-
/* cast always creates a new array */
477-
*value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
478-
Py_DECREF(a1);
479-
}
480-
}
481-
else {
482-
*value = PyArray_GenericUnaryFunction(a1, fastop);
483-
}
484-
}
485-
return 0;
486-
}
370+
if (inplace || can_elide_temp_unary(a1)) {
371+
*result = PyArray_GenericInplaceUnaryFunction(a1, fastop);
487372
}
488-
/* no fast operation found */
489-
return -1;
373+
else {
374+
*result = PyArray_GenericUnaryFunction(a1, fastop);
375+
}
376+
377+
return 0;
490378
}
491379

492380
static PyObject *
@@ -643,7 +531,8 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
643531

644532
INPLACE_GIVE_UP_IF_NEEDED(
645533
a1, o2, nb_inplace_power, array_inplace_power);
646-
if (fast_scalar_power((PyObject *)a1, o2, 1, &value) != 0) {
534+
535+
if (fast_scalar_power((PyObject *) a1, o2, 1, &value) != 0) {
647536
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
648537
}
649538
return value;

numpy/_core/src/umath/loops.c.src

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -486,28 +486,54 @@ _@TYPE@_squared_exponentiation_helper(@type@ base, @type@ exponent_two, int firs
486486
return out;
487487
}
488488

489+
static inline @type@
490+
_@TYPE@_power_fast_path_helper(@type@ in1, @type@ in2, @type@ *op1) {
491+
// Fast path for power calculation
492+
if (in2 == 0 || in1 == 1) {
493+
*op1 = 1;
494+
}
495+
else if (in2 == 1) {
496+
*op1 = in1;
497+
}
498+
else if (in2 == 2) {
499+
*op1 = in1 * in1;
500+
}
501+
else {
502+
return 1;
503+
}
504+
return 0;
505+
}
506+
507+
489508
NPY_NO_EXPORT void
490509
@TYPE@_power(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
491510
{
492511
if (steps[1]==0) {
493512
// stride for second argument is 0
494513
BINARY_DEFS
495514
const @type@ in2 = *(@type@ *)ip2;
496-
#if @SIGNED@
497-
if (in2 < 0) {
498-
npy_gil_error(PyExc_ValueError,
499-
"Integers to negative integer powers are not allowed.");
500-
return;
501-
}
502-
#endif
515+
516+
#if @SIGNED@
517+
if (in2 < 0) {
518+
npy_gil_error(PyExc_ValueError,
519+
"Integers to negative integer powers are not allowed.");
520+
return;
521+
}
522+
#endif
503523

504524
int first_bit = in2 & 1;
505525
@type@ in2start = in2 >> 1;
506526

527+
int fastop_exists = (in2 == 0) || (in2 == 1) || (in2 == 2);
528+
507529
BINARY_LOOP_SLIDING {
508530
@type@ in1 = *(@type@ *)ip1;
509-
510-
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
531+
if (fastop_exists) {
532+
_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1);
533+
}
534+
else {
535+
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2start, first_bit);
536+
}
511537
}
512538
return;
513539
}
@@ -518,22 +544,16 @@ NPY_NO_EXPORT void
518544
#if @SIGNED@
519545
if (in2 < 0) {
520546
npy_gil_error(PyExc_ValueError,
521-
"Integers to negative integer powers are not allowed.");
547+
"Integers to negative integer powers are not allowed.");
522548
return;
523549
}
524550
#endif
525-
if (in2 == 0) {
526-
*((@type@ *)op1) = 1;
527-
continue;
528-
}
529-
if (in1 == 1) {
530-
*((@type@ *)op1) = 1;
531-
continue;
532-
}
533551

534-
int first_bit = in2 & 1;
535-
in2 >>= 1;
536-
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
552+
if (_@TYPE@_power_fast_path_helper(in1, in2, (@type@ *)op1) != 0) {
553+
int first_bit = in2 & 1;
554+
in2 >>= 1;
555+
*((@type@ *) op1) = _@TYPE@_squared_exponentiation_helper(in1, in2, first_bit);
556+
}
537557
}
538558
}
539559
/**end repeat**/

numpy/_core/src/umath/loops_umath_fp.dispatch.c.src

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,30 @@ NPY_NO_EXPORT void NPY_CPU_DISPATCH_CURFX(@TYPE@_@func@)
239239
if (stride_zero) {
240240
BINARY_DEFS
241241
const @type@ in2 = *(@type@ *)ip2;
242-
if (in2 == 2.0) {
243-
BINARY_LOOP_SLIDING {
244-
const @type@ in1 = *(@type@ *)ip1;
242+
int fastop_found = 1;
243+
BINARY_LOOP_SLIDING {
244+
const @type@ in1 = *(@type@ *)ip1;
245+
if (in2 == -1.0) {
246+
*(@type@ *)op1 = 1.0 / in1;
247+
}
248+
else if (in2 == 0.0) {
249+
*(@type@ *)op1 = 1.0;
250+
}
251+
else if (in2 == 0.5) {
252+
*(@type@ *)op1 = @sqrt@(in1);
253+
}
254+
else if (in2 == 1.0) {
255+
*(@type@ *)op1 = in1;
256+
}
257+
else if (in2 == 2.0) {
245258
*(@type@ *)op1 = in1 * in1;
246259
}
260+
else {
261+
fastop_found = 0;
262+
break;
263+
}
264+
}
265+
if (fastop_found) {
247266
return;
248267
}
249268
}

numpy/_core/tests/test_multiarray.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4125,27 +4125,6 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kw):
41254125
assert_equal(A[0], 30)
41264126
assert_(isinstance(A, OutClass))
41274127

4128-
def test_pow_override_with_errors(self):
4129-
# regression test for gh-9112
4130-
class PowerOnly(np.ndarray):
4131-
def __array_ufunc__(self, ufunc, method, *inputs, **kw):
4132-
if ufunc is not np.power:
4133-
raise NotImplementedError
4134-
return "POWER!"
4135-
# explicit cast to float, to ensure the fast power path is taken.
4136-
a = np.array(5., dtype=np.float64).view(PowerOnly)
4137-
assert_equal(a ** 2.5, "POWER!")
4138-
with assert_raises(NotImplementedError):
4139-
a ** 0.5
4140-
with assert_raises(NotImplementedError):
4141-
a ** 0
4142-
with assert_raises(NotImplementedError):
4143-
a ** 1
4144-
with assert_raises(NotImplementedError):
4145-
a ** -1
4146-
with assert_raises(NotImplementedError):
4147-
a ** 2
4148-
41494128
def test_pow_array_object_dtype(self):
41504129
# test pow on arrays of object dtype
41514130
class SomeClass:

0 commit comments

Comments
 (0)