Skip to content

Commit 0ad8b1e

Browse files
ikrommydpfackeldey
andauthored
fix: typetracer nplike's all function leads to infinite recursion for axis=None (#3765)
* fix typetracer nplikes all function * add test * instantiate nplike only once in test --------- Co-authored-by: Peter Fackeldey <fackeldey.peter@gmail.com>
1 parent 0941b37 commit 0ad8b1e

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/awkward/_nplikes/typetracer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1559,7 +1559,7 @@ def all(
15591559
if axis is None:
15601560
return self.all(
15611561
cast(TypeTracerArray, self.reshape(x, (-1,))),
1562-
axis=axis,
1562+
axis=0,
15631563
keepdims=keepdims,
15641564
maybe_out=maybe_out,
15651565
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
7+
from awkward._nplikes.typetracer import TypeTracer, TypeTracerArray
8+
9+
nplike = TypeTracer.instance()
10+
11+
12+
def test_all():
13+
buffer = TypeTracerArray._new(np.dtype("float32"), (4,))
14+
result = nplike.all(buffer, axis=None)
15+
assert isinstance(result, TypeTracerArray)
16+
assert result.dtype == np.dtype("bool")
17+
assert result.shape == ()
18+
19+
buffer = TypeTracerArray._new(np.dtype("float32"), (4, 3))
20+
result = nplike.all(buffer, axis=None)
21+
assert isinstance(result, TypeTracerArray)
22+
assert result.dtype == np.dtype("bool")
23+
assert result.shape == ()
24+
25+
26+
def test_any():
27+
buffer = TypeTracerArray._new(np.dtype("float32"), (4,))
28+
result = nplike.any(buffer, axis=None)
29+
assert isinstance(result, TypeTracerArray)
30+
assert result.dtype == np.dtype("bool")
31+
assert result.shape == ()
32+
33+
buffer = TypeTracerArray._new(np.dtype("float32"), (4, 3))
34+
result = nplike.any(buffer, axis=None)
35+
assert isinstance(result, TypeTracerArray)
36+
assert result.dtype == np.dtype("bool")
37+
assert result.shape == ()

0 commit comments

Comments
 (0)