Skip to content

Commit 460ef90

Browse files
authored
fix: allow axis + other (#982)
* fix: allow axis + other
  Signed-off-by: Henry Schreiner <[email protected]>
* fix: support unordered axes
  Signed-off-by: Henry Schreiner <[email protected]>
* refactor: pull out a helper function
  Signed-off-by: Henry Schreiner <[email protected]>
---------
Signed-off-by: Henry Schreiner <[email protected]>
1 parent b44204e commit 460ef90

File tree

3 files changed

+58
-22
lines changed

3 files changed

+58
-22
lines changed

src/boost_histogram/histogram.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
if "_core" not in str(err):
4545
raise
4646

47-
new_msg = "Did you forget to compile boost-histogram? Use CMake or Setuptools to build, see the readme."
47+
new_msg = "Did you forget to compile boost-histogram? Use CMake or scikit-build-core to build, see the readme."
4848

4949
if sys.version_info >= (3, 11):
5050
err.add_note(new_msg)
@@ -178,6 +178,24 @@ def _expand_ellipsis(indexes: Iterable[Any], rank: int) -> list[Any]:
178178
raise IndexError("an index can only have a single ellipsis ('...')")
179179

180180

181+
def _combine_group_contents(
182+
new_view: np.typing.NDArray[Any],
183+
reduced_view: np.typing.NDArray[Any],
184+
i: int,
185+
j: int,
186+
jj: int,
187+
) -> None:
188+
"""
189+
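Add index ``j`` of ``reduced_view`` along axis ``i`` into index ``jj`` of ``new_view``, in-place (field by field for structured dtypes). This is used to merge bins when rebinning.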
190+
"""
191+
pos = [slice(None)] * (i)
192+
if new_view.dtype.names:
193+
for field in new_view.dtype.names:
194+
new_view[(*pos, jj, ...)][field] += reduced_view[(*pos, j, ...)][field]
195+
else:
196+
new_view[(*pos, jj, ...)] += reduced_view[(*pos, j, ...)]
197+
198+
181199
H = TypeVar("H", bound="Histogram")
182200

183201

@@ -992,18 +1010,18 @@ def __getitem__(self: H, index: IndexingExpr) -> H | float | Accumulator:
9921010

9931011
for new_j, group in enumerate(groups):
9941012
for _ in range(group):
995-
pos = [slice(None)] * (i)
996-
if new_view.dtype.names:
997-
for field in new_view.dtype.names:
998-
new_view[(*pos, new_j + new_j_base, ...)][
999-
field
1000-
] += reduced_view[(*pos, j, ...)][field]
1001-
else:
1002-
new_view[(*pos, new_j + new_j_base, ...)] += (
1003-
reduced_view[(*pos, j, ...)]
1004-
)
1013+
_combine_group_contents(
1014+
new_view, reduced_view, i, j, new_j + new_j_base
1015+
)
10051016
j += 1
10061017

1018+
if (
1019+
old_axis.traits_underflow
1020+
and not axes[i].traits_ordered
1021+
and axes[i].traits_overflow
1022+
):
1023+
_combine_group_contents(new_view, reduced_view, i, 0, -1)
1024+
10071025
reduced = new_reduced
10081026

10091027
# Will be updated below

src/boost_histogram/tag.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,19 +130,19 @@ def __init__(
130130
edges: Sequence[int | float] | None = None,
131131
axis: PlottableAxis | None = None,
132132
) -> None:
133-
if (
134-
sum(i is not None for i in [factor_or_axis, factor, groups, edges, axis])
135-
!= 1
136-
):
133+
if isinstance(factor_or_axis, int):
134+
factor = factor_or_axis
135+
elif factor_or_axis is not None:
136+
axis = factor_or_axis
137+
138+
total_args = sum(i is not None for i in [factor, groups, edges])
139+
if total_args != 1 and axis is None:
137140
raise ValueError("Exactly one argument should be provided")
141+
138142
self.groups = groups
139143
self.edges = edges
140144
self.axis = axis
141145
self.factor = factor
142-
if isinstance(factor_or_axis, int):
143-
self.factor = factor_or_axis
144-
elif factor_or_axis is not None:
145-
self.axis = factor_or_axis
146146

147147
def __repr__(self) -> str:
148148
repr_str = f"{self.__class__.__name__}"
@@ -177,10 +177,10 @@ def group_mapping(self, axis: PlottableAxis) -> Sequence[int]:
177177
return [self.factor] * len(axis)
178178
if self.edges is not None or self.axis is not None:
179179
newedges = None
180-
if self.axis is not None and hasattr(self.axis, "edges"):
181-
newedges = self.axis.edges
182-
elif self.edges is not None:
180+
if self.edges is not None:
183181
newedges = self.edges
182+
elif self.axis is not None and hasattr(self.axis, "edges"):
183+
newedges = self.axis.edges
184184

185185
if newedges is not None and hasattr(axis, "edges"):
186186
assert newedges[0] == axis.edges[0], "Edges must start at first bin"

tests/test_histogram.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,24 @@ def test_rebin_1d_flow():
688688
assert_array_equal(hs.view(flow=True), [2, 2])
689689

690690

691+
def test_rebin_change_axis_int():
692+
h = bh.Histogram(bh.axis.Regular(5, 0, 5))
693+
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
694+
hs = h[bh.rebin(edges=[0, 3, 5.0], axis=bh.axis.Integer(10, 12))]
695+
assert_array_equal(hs.view(), [2, 2])
696+
assert_array_equal(hs.view(flow=True), [1, 2, 2, 1])
697+
assert_array_equal(hs.axes.edges[0], [10, 11, 12])
698+
699+
700+
def test_rebin_change_axis_cat():
701+
h = bh.Histogram(bh.axis.Regular(5, 0, 5))
702+
h.fill([-1, 1.1, 2.2, 3.3, 4.4, 5.5])
703+
hs = h[bh.rebin(groups=[2, 3], axis=bh.axis.StrCategory(["a", "b"]))]
704+
assert_array_equal(hs.view(), [1, 3])
705+
assert_array_equal(hs.view(flow=True), [1, 3, 4])
706+
assert_array_equal(list(hs.axes[0]), ["a", "b"])
707+
708+
691709
def test_shrink_rebin_1d():
692710
h = bh.Histogram(bh.axis.Regular(20, 0, 4))
693711
h.fill(1.1)

0 commit comments

Comments
 (0)