Skip to content

Commit df7fbff

Browse files
ikrommyd and ianna authored
fix: make to_packed work for typetracer backed IndexedOptionArray, BitMaskedArray, and UnionArray (#3608)
* make to_packed work for indexedoptionarray with typetracer
* add test
* better use known_data
* fix other layouts too and add more tests
* better like this
* guard the length, not the nplike

---------

Co-authored-by: Ianna Osborne <ianna.osborne@cern.ch>
1 parent 3e82eef commit df7fbff

File tree

5 files changed

+178
-6
lines changed

5 files changed

+178
-6
lines changed

src/awkward/contents/bitmaskedarray.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,9 +810,13 @@ def _to_packed(self, recursive: bool = True) -> Self:
810810
)
811811

812812
else:
813-
excess_length = math.ceil(self.length / 8.0)
813+
if self.length is not unknown_length:
814+
excess_length = math.ceil(self.length / 8.0)
815+
else:
816+
excess_length = unknown_length
814817
if (
815818
self._mask.length is not unknown_length
819+
and excess_length is not unknown_length
816820
and self._mask.length == excess_length
817821
):
818822
mask = self._mask

src/awkward/contents/bytemaskedarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def to_ByteMaskedArray(self, valid_when):
335335
def to_BitMaskedArray(self, valid_when, lsb_order):
336336
if not self._backend.nplike.known_data:
337337
self._touch_data(recursive=False)
338-
if self._backend.nplike.known_data:
338+
if self.length is not unknown_length:
339339
excess_length = math.ceil(self.length / 8.0)
340340
else:
341341
excess_length = unknown_length

src/awkward/contents/indexedoptionarray.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import copy
6-
import operator
76
from collections.abc import Mapping, MutableMapping, Sequence
87

98
import awkward as ak
@@ -1722,7 +1721,7 @@ def _to_packed(self, recursive: bool = True) -> Self:
17221721
nplike = self._backend.nplike
17231722
original_index = self._index.data
17241723
is_none = original_index < 0
1725-
num_none = operator.index(nplike.count_nonzero(is_none))
1724+
num_none = nplike.index_as_shape_item(nplike.count_nonzero(is_none))
17261725
new_index = nplike.empty(self._index.length, dtype=self._index.dtype)
17271726
if isinstance(nplike, Jax):
17281727
new_index = new_index.at[is_none].set(-1)

src/awkward/contents/unionarray.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,15 +1674,21 @@ def continuation():
16741674
def _to_packed(self, recursive: bool = True) -> Self:
16751675
nplike = self._backend.nplike
16761676
tags = self._tags.data
1677-
original_index = index = self._index.data[: tags.shape[0]]
1677+
original_index = index = self._index.data[
1678+
: nplike.shape_item_as_index(tags.shape[0])
1679+
]
16781680

16791681
contents = list(self._contents)
16801682

16811683
for tag in range(len(self._contents)):
16821684
is_tag = tags == tag
16831685
num_tag = nplike.index_as_shape_item(nplike.count_nonzero(is_tag))
16841686

1685-
if len(contents[tag]) > num_tag:
1687+
if (
1688+
contents[tag].length is not unknown_length
1689+
and num_tag is not unknown_length
1690+
and contents[tag].length > num_tag
1691+
):
16861692
if original_index is index:
16871693
index = index.copy()
16881694
new_index_values = self._backend.nplike.arange(
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
2+
3+
from __future__ import annotations
4+
5+
import numpy as np
6+
import pytest
7+
8+
import awkward as ak
9+
10+
11+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_numpy_array(forget_length, recursive):
    """Packing a typetracer NumpyArray must give the same form as packing the concrete one."""
    # Take the first column of a 64-element range reshaped to 8 rows.
    first_column = np.arange(64).reshape(8, -1)[:, 0]
    layout = ak.contents.NumpyArray(first_column)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
20+
21+
22+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_empty_array(forget_length, recursive):
    """Packing a typetracer EmptyArray must give the same form as packing the concrete one."""
    layout = ak.contents.EmptyArray()

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
30+
31+
32+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_indexed_option_array(forget_length, recursive):
    """Packing a typetracer IndexedOptionArray must give the same form as the concrete one."""
    # Negative index entries are the ones _to_packed treats as missing values.
    layout = ak.contents.IndexedOptionArray(
        ak.index.Index64(np.r_[0, -1, 2, -1, 4]),
        ak.contents.NumpyArray(np.arange(8)),
    )

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
42+
43+
44+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_indexed_array(forget_length, recursive):
    """Packing a typetracer IndexedArray must give the same form as packing the concrete one."""
    layout = ak.contents.IndexedArray(
        ak.index.Index64(np.array([0, 1, 2, 3, 6, 7, 8])),
        ak.contents.NumpyArray(np.arange(10)),
    )

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
54+
55+
56+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_list_array(forget_length, recursive):
    """Packing a typetracer ListArray must give the same form as packing the concrete one."""
    values = np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
    content = ak.contents.NumpyArray(values)
    starts = ak.index.Index64(np.array([0, 3, 3, 5, 6]))
    stops = ak.index.Index64(np.array([3, 3, 5, 6, 9]))
    layout = ak.contents.ListArray(starts, stops, content)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
69+
70+
71+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_list_offset_array(forget_length, recursive):
    """Packing a typetracer ListOffsetArray must give the same form as the concrete one."""
    values = np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
    content = ak.contents.NumpyArray(values)
    # The last offset (6) is smaller than the content length (9).
    offsets = ak.index.Index64(np.array([0, 3, 3, 5, 6]))
    layout = ak.contents.ListOffsetArray(offsets, content)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
83+
84+
85+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_unmasked_array(forget_length, recursive):
    """Packing a typetracer UnmaskedArray must give the same form as packing the concrete one."""
    values = np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
    layout = ak.contents.UnmaskedArray(ak.contents.NumpyArray(values))

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
96+
97+
98+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_union_array(forget_length, recursive):
    """Packing a typetracer union layout must give the same form as packing the concrete one."""
    a = ak.contents.NumpyArray(np.arange(4))
    b = ak.contents.NumpyArray(np.arange(4) + 4)
    c = ak.contents.RegularArray(ak.contents.NumpyArray(np.arange(12)), 3)

    tags = ak.index.Index8([1, 1, 2, 2, 0, 0])
    index = ak.index.Index64([0, 1, 0, 1, 0, 1])
    layout = ak.contents.UnionArray.simplified(tags, index, [a, b, c])

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
113+
114+
115+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_record_array(forget_length, recursive):
    """Packing a typetracer RecordArray must give the same form as packing the concrete one."""
    fields = [
        ak.contents.NumpyArray(np.arange(10)),
        ak.contents.NumpyArray(np.arange(10) * 2 + 4),
    ]
    # The third argument (5) is smaller than the 10 elements in each field.
    layout = ak.contents.RecordArray(fields, None, 5)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
125+
126+
127+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_regular_array(forget_length, recursive):
    """Packing a typetracer RegularArray must give the same form as packing the concrete one."""
    # 10 content elements with size 3: the content length is not a multiple of the size.
    layout = ak.contents.RegularArray(ak.contents.NumpyArray(np.arange(10)), 3)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
136+
137+
138+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_bit_masked_array(forget_length, recursive):
    """Packing a typetracer BitMaskedArray must give the same form as packing the concrete one."""
    bits = ak.index.IndexU8(np.array([0b10101010]))
    values = ak.contents.NumpyArray(np.arange(16))
    # A single mask byte, passed together with 8 as the fourth argument;
    # the content holds 16 elements.
    layout = ak.contents.BitMaskedArray(bits, values, False, 8, False)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form
148+
149+
150+
@pytest.mark.parametrize("forget_length", [True, False])
@pytest.mark.parametrize("recursive", [True, False])
def test_byte_masked_array(forget_length, recursive):
    """Packing a typetracer ByteMaskedArray must give the same form as packing the concrete one."""
    byte_mask = ak.index.Index8(np.array([1, 0, 1, 0, 1, 0, 1, 0]))
    values = ak.contents.NumpyArray(np.arange(16))
    layout = ak.contents.ByteMaskedArray(byte_mask, values, False)

    traced_form = layout.to_typetracer(forget_length).to_packed(recursive).form
    concrete_form = layout.to_packed(recursive).form
    assert traced_form == concrete_form

0 commit comments

Comments
 (0)