Skip to content

Commit 928a7de

Browse files
authored
feat: fix operators on arrays, adds support for it (#417)
* feat: support more forms of operators * feat: shift work to the view as needed * fix: better broadcasting, supports flow if present
1 parent 2aea1f4 commit 928a7de

File tree

3 files changed

+140
-60
lines changed

3 files changed

+140
-60
lines changed

include/bh_python/register_histogram.hpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,7 @@ auto register_histogram(py::module& m, const char* name, const char* desc) {
5858
return a;
5959
})
6060

61-
.def(py::self + py::self)
62-
// .def(py::self + value_type())
6361
.def(py::self += py::self)
64-
// .def(py::self += value_type())
6562

6663
.def("__eq__",
6764
[](const histogram_t& self, const py::object& other) {
@@ -90,22 +87,21 @@ auto register_histogram(py::module& m, const char* name, const char* desc) {
9087

9188
;
9289

93-
// Atomics for example do not support these operations
90+
// Protection against an overzealous warning system
91+
// https://bugs.llvm.org/show_bug.cgi?id=43124
92+
#ifdef __clang__
93+
#pragma GCC diagnostic push
94+
#pragma GCC diagnostic ignored "-Wself-assign-overloaded"
95+
#endif
9496
def_optionally(hist,
95-
bh::detail::has_operator_rmul<histogram_t, double>{},
96-
py::self *= double());
97+
bh::detail::has_operator_rdiv<histogram_t, histogram_t>{},
98+
py::self /= py::self);
9799
def_optionally(hist,
98-
bh::detail::has_operator_rmul<histogram_t, double>{},
99-
py::self * double());
100-
def_optionally(hist,
101-
bh::detail::has_operator_rmul<histogram_t, double>{},
102-
double() * py::self);
103-
def_optionally(hist,
104-
bh::detail::has_operator_rdiv<histogram_t, double>{},
105-
py::self /= double());
106-
def_optionally(hist,
107-
bh::detail::has_operator_rdiv<histogram_t, double>{},
108-
py::self / double());
100+
bh::detail::has_operator_rmul<histogram_t, histogram_t>{},
101+
py::self *= py::self);
102+
#ifdef __clang__
103+
#pragma GCC diagnostic pop
104+
#endif
109105

110106
hist.def(
111107
"to_numpy",

src/boost_histogram/_internal/hist.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@ def _fill_cast(value):
4040
return np.ascontiguousarray(value)
4141

4242

43-
def _hist_or_val(other):
44-
return other._hist if hasattr(other, "_hist") else other
45-
46-
4743
def _arg_shortcut(item):
4844
msg = "Developer shortcut: will be removed in a future version"
4945
if isinstance(item, tuple) and len(item) == 3:
@@ -185,22 +181,18 @@ def __array__(self):
185181
return self.view(False)
186182

187183
def __add__(self, other):
188-
if hasattr(other, "_hist"):
189-
return self._new_hist(self._hist.__add__(other._hist))
190-
else:
191-
retval = self.copy()
192-
retval += other
193-
return retval
184+
result = self.copy()
185+
return result.__iadd__(other)
194186

195187
def __iadd__(self, other):
196188
if isinstance(other, (int, float)) and other == 0:
197189
return self
198-
if hasattr(other, "_hist"):
199-
self._hist.__iadd__(other._hist)
200-
# Addition may change category axes
201-
self.axes = self._generate_axes_()
202-
return self
203-
return NotImplemented
190+
self._compute_inplace_op("__iadd__", other)
191+
192+
# Addition may change the axes if they can grow
193+
self.axes = self._generate_axes_()
194+
195+
return self
204196

205197
def __radd__(self, other):
206198
return self + other
@@ -213,46 +205,55 @@ def __ne__(self, other):
213205

214206
# If these fail, the underlying object throws the correct error
215207
def __mul__(self, other):
216-
return self._new_hist(self._hist.__mul__(other))
208+
result = self.copy()
209+
return result._compute_inplace_op("__imul__", other)
217210

218211
def __rmul__(self, other):
219212
return self * other
220213

221-
def __imul__(self, other):
222-
self._hist.__imul__(_hist_or_val(other))
223-
return self
224-
225214
def __truediv__(self, other):
226-
if isinstance(other, Histogram):
227-
result = self.copy()
228-
result.__itruediv__(other)
229-
return result
230-
else:
231-
return self._new_hist(self._hist.__truediv__(_hist_or_val(other)))
215+
result = self.copy()
216+
return result._compute_inplace_op("__itruediv__", other)
232217

233218
def __div__(self, other):
234-
if isinstance(other, Histogram):
235-
result = self.copy()
236-
result.__idiv__(other)
237-
return result
238-
else:
239-
return self._new_hist(self._hist.__div__(_hist_or_val(other)))
219+
result = self.copy()
220+
return result._compute_inplace_op("__idiv__", other)
240221

241-
def __itruediv__(self, other):
222+
def _compute_inplace_op(self, name, other):
242223
if isinstance(other, Histogram):
243-
view = self.view(flow=True)
244-
view.__itruediv__(other.view(flow=True))
224+
getattr(self._hist, name)(other._hist)
225+
elif isinstance(other, _histograms):
226+
getattr(self._hist, name)(other)
227+
elif hasattr(other, "shape"):
228+
if len(other.shape) != self.ndim:
229+
raise ValueError(
230+
"Number of dimensions {0} must match histogram {1}".format(
231+
len(other.shape), self.ndim
232+
)
233+
)
234+
elif all((a == b or a == 1) for a, b in zip(other.shape, self.shape)):
235+
view = self.view(flow=False)
236+
getattr(view, name)(other)
237+
elif all((a == b or a == 1) for a, b in zip(other.shape, self.axes.extent)):
238+
view = self.view(flow=True)
239+
getattr(view, name)(other)
240+
else:
241+
raise ValueError(
242+
"Wrong shape, expected {0} or {1}".format(self.shape, self.extent)
243+
)
245244
else:
246-
self._hist.__itruediv__(_hist_or_val(other))
245+
view = self.view(flow=False)
246+
getattr(view, name)(other)
247247
return self
248248

249249
def __idiv__(self, other):
250-
if isinstance(other, Histogram):
251-
view = self.view(flow=True)
252-
view.__idiv__(other.view(flow=True))
253-
else:
254-
self._hist.__idiv__(_hist_or_val(other))
255-
return self
250+
return self._compute_inplace_op("__idiv__", other)
251+
252+
def __itruediv__(self, other):
253+
return self._compute_inplace_op("__itruediv__", other)
254+
255+
def __imul__(self, other):
256+
return self._compute_inplace_op("__imul__", other)
256257

257258
# TODO: Marked as too complex by flake8. Should be factored out a bit.
258259
@inject_signature("self, *args, weight=None, sample=None, threads=None")

tests/test_histogram.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,3 +1078,86 @@ def test_shape():
10781078
def test_empty_shape():
10791079
h = bh.Histogram()
10801080
assert h.shape == ()
1081+
1082+
1083+
# issue #416 a
1084+
def test_hist_division():
1085+
edges = [0, 0.25, 0.5, 0.75, 1, 2, 3, 4, 7, 10]
1086+
edges = [-x for x in reversed(edges)] + edges[1:]
1087+
1088+
h = bh.Histogram(bh.axis.Variable(edges))
1089+
h[...] = 1
1090+
h1 = h.copy()
1091+
1092+
dens = h.view().copy()
1093+
dens /= h.axes[0].widths * h.sum()
1094+
1095+
h1 /= h.axes[0].widths * h.sum()
1096+
1097+
assert_array_equal(h1.view(), dens)
1098+
1099+
1100+
# issue #416 b
1101+
# def test_hist_division():
1102+
# edges = [0, .25, .5, .75, 1, 2, 3, 4, 7, 10]
1103+
# edges = [-x for x in reversed(edges)] + edges[1:]
1104+
#
1105+
# h = bh.Histogram(bh.axis.Variable(edges))
1106+
# h[...] = 1
1107+
#
1108+
# dens = h.view().copy() / h.axes[0].widths * h.sum()
1109+
# h1 = h.copy()
1110+
#
1111+
# h1[:] /= h.axes[0].widths * h.sum()
1112+
#
1113+
# assert_allclose(h1.view(), dens)
1114+
1115+
1116+
def test_add_hists():
1117+
edges = [0, 0.25, 0.5, 0.75, 1, 2, 3, 4, 7, 10]
1118+
edges = [-x for x in reversed(edges)] + edges[1:]
1119+
1120+
h = bh.Histogram(bh.axis.Variable(edges))
1121+
h[...] = 1
1122+
1123+
h1 = h.copy()
1124+
h1 += h.view()
1125+
1126+
h2 = h.copy()
1127+
h2 += h1
1128+
1129+
h3 = h.copy()
1130+
h3 += 5
1131+
1132+
assert_array_equal(h, 1)
1133+
assert_array_equal(h1, 2)
1134+
assert_array_equal(h2, 3)
1135+
assert_array_equal(h3, 6)
1136+
1137+
1138+
def test_add_broadcast():
1139+
h = bh.Histogram(bh.axis.Regular(10, 0, 1), bh.axis.Regular(20, 0, 1))
1140+
1141+
h1 = h.copy()
1142+
h2 = h.copy()
1143+
1144+
h1[...] = 1
1145+
assert h1.view().sum() == 10 * 20
1146+
assert h1.view(flow=True).sum() == 10 * 20
1147+
1148+
h2 = h + [[1]]
1149+
assert h2.sum() == 10 * 20
1150+
assert h2.sum(flow=True) == 10 * 20
1151+
1152+
h3 = h + np.ones((10, 20))
1153+
assert h3.sum() == 10 * 20
1154+
assert h3.sum(flow=True) == 10 * 20
1155+
1156+
h4 = h + np.ones((12, 22))
1157+
assert h4.view(flow=True).sum() == 12 * 22
1158+
1159+
h5 = h + np.ones((10, 1))
1160+
assert h5.sum(flow=True) == 10 * 20
1161+
1162+
h5 = h + np.ones((1, 22))
1163+
assert h5.sum(flow=True) == 12 * 22

0 commit comments

Comments
 (0)