Skip to content

Commit 4ebce75

Browse files
authored
fix: support pick_set inside a larger expression (#793)
* tests: adding pick inside a larger expression Signed-off-by: Henry Schreiner <[email protected]> * fix: allow picking and pick_set to coexist Signed-off-by: Henry Schreiner <[email protected]> Signed-off-by: Henry Schreiner <[email protected]>
1 parent 5b0e5a3 commit 4ebce75

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

src/boost_histogram/_internal/hist.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,9 @@ def __getitem__( # noqa: C901
927927
new_reduced.view(flow=True)[...] = reduced.view(flow=True)[tuple_slice]
928928
reduced = new_reduced
929929
integrations = {i - sum(j <= i for j in pick_each) for i in integrations}
930+
pick_set = {
931+
i - sum(j <= i for j in pick_each): v for i, v in pick_set.items()
932+
}
930933
for slice_ in slices:
931934
slice_.iaxis -= sum(j <= slice_.iaxis for j in pick_each)
932935

@@ -945,9 +948,8 @@ def __getitem__( # noqa: C901
945948
selection = copy.copy(pick_set[i])
946949
ax = reduced.axis(i)
947950
if ax.traits_ordered:
948-
raise RuntimeError(
949-
f"Axis {i} is not a categorical axis, cannot pick with list"
950-
)
951+
msg = f"Axis {i} is not a categorical axis, cannot pick with list: {ax}"
952+
raise RuntimeError(msg)
951953

952954
if ax.traits_overflow and ax.size not in pick_set[i]:
953955
selection.append(ax.size)

tests/test_internal_histogram.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,27 @@ def test_str_categories_histogram():
107107
assert hist[bh.loc("c")] == 1
108108

109109

110+
# Issue 715
111+
112+
113+
def test_select_many():
114+
hist = bh.Histogram(
115+
bh.axis.StrCategory(["a", "b"]),
116+
bh.axis.StrCategory(["x", "y", "z"]),
117+
bh.axis.Regular(10, 0, 1),
118+
)
119+
120+
pick_a = hist[bh.loc("a"), ...]
121+
with pytest.warns(UserWarning):
122+
pick_b = pick_a[[bh.loc("x"), bh.loc("y")], ...]
123+
124+
with pytest.warns(UserWarning):
125+
pick = hist[bh.loc("a"), [bh.loc("x"), bh.loc("y")], ...]
126+
127+
assert pick_b.axes[0] == pick.axes[0]
128+
assert len(pick_b.axes) == len(pick.axes)
129+
130+
110131
def test_growing_histogram():
111132
hist = bh.Histogram(
112133
bh.axis.Regular(10, 0, 1, growth=True), storage=bh.storage.Int64()

0 commit comments

Comments
 (0)