Skip to content

Commit cad3815

Browse files
committed
fix(backport): Mean/WeightedMean accumulator addition fix (#537)
* fix: Mean accumulator math fix * fix: also fix WeighedMean * tests: add summing tests
1 parent c0fc56d commit cad3815

File tree

5 files changed

+124
-11
lines changed

5 files changed

+124
-11
lines changed

include/bh_python/accumulators/mean.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,20 @@ struct mean {
5959
}
6060

6161
mean& operator+=(const mean& rhs) noexcept {
62-
if(count != 0 || rhs.count != 0) {
63-
const auto tmp = value * count + rhs.value * rhs.count;
64-
count += rhs.count;
65-
value = tmp / count;
66-
}
62+
if(rhs.count == 0)
63+
return *this;
64+
65+
const auto mu1 = value;
66+
const auto mu2 = rhs.value;
67+
const auto n1 = count;
68+
const auto n2 = rhs.count;
69+
70+
count += rhs.count;
71+
value = (n1 * mu1 + n2 * mu2) / count;
6772
sum_of_deltas_squared += rhs.sum_of_deltas_squared;
73+
sum_of_deltas_squared
74+
+= n1 * (value - mu1) * (value - mu1) + n2 * (value - mu2) * (value - mu2);
75+
6876
return *this;
6977
}
7078

include/bh_python/accumulators/weighted_mean.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,23 @@ struct weighted_mean {
6262
}
6363

6464
weighted_mean& operator+=(const weighted_mean& rhs) {
65-
if(sum_of_weights != 0 || rhs.sum_of_weights != 0) {
66-
const auto tmp = value * sum_of_weights + rhs.value * rhs.sum_of_weights;
67-
sum_of_weights += rhs.sum_of_weights;
68-
sum_of_weights_squared += rhs.sum_of_weights_squared;
69-
value = tmp / sum_of_weights;
70-
}
65+
if(rhs.sum_of_weights == 0)
66+
return *this;
67+
68+
const auto mu1 = value;
69+
const auto mu2 = rhs.value;
70+
const auto n1 = sum_of_weights;
71+
const auto n2 = rhs.sum_of_weights;
72+
73+
sum_of_weights += rhs.sum_of_weights;
74+
sum_of_weights_squared += rhs.sum_of_weights_squared;
75+
76+
value = (n1 * mu1 + n2 * mu2) / sum_of_weights;
77+
7178
_sum_of_weighted_deltas_squared += rhs._sum_of_weighted_deltas_squared;
79+
_sum_of_weighted_deltas_squared
80+
+= n1 * (value - mu1) * (value - mu1) + n2 * (value - mu2) * (value - mu2);
81+
7282
return *this;
7383
}
7484

include/bh_python/register_accumulator.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ py::class_<A> register_accumulator(py::module acc, Args&&... args) {
5353

5454
.def(py::self *= double())
5555

56+
.def("__add__",
57+
[](const A& self, const A& other) {
58+
A retval(self);
59+
retval += other;
60+
return retval;
61+
})
62+
5663
// The c++ name is replaced with the Python name here
5764
.def("__repr__",
5865
[](py::object self) {

tests/test_accumulators.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
from pytest import approx
23

34
import boost_histogram as bh
45

@@ -87,3 +88,47 @@ def test_mean():
8788
assert a == bh.accumulators.Mean(3, 2, 1)
8889

8990
assert repr(a) == "Mean(count=3, value=2, variance=1)"
91+
92+
93+
def test_sum_mean():
94+
a = bh.accumulators.Mean()
95+
a.fill([1, 2, 3])
96+
97+
b = bh.accumulators.Mean()
98+
b.fill([5, 6])
99+
100+
c = bh.accumulators.Mean()
101+
c.fill([1, 2, 3, 5, 6])
102+
103+
ab = a + b
104+
assert ab.value == approx(c.value)
105+
assert ab.variance == approx(c.variance)
106+
assert ab.count == approx(c.count)
107+
108+
a += b
109+
assert a.value == approx(c.value)
110+
assert a.variance == approx(c.variance)
111+
assert a.count == approx(c.count)
112+
113+
114+
def test_sum_weighed_mean():
115+
a = bh.accumulators.WeightedMean()
116+
a.fill([1, 2, 3], weight=[2, 5, 3])
117+
118+
b = bh.accumulators.WeightedMean()
119+
b.fill([5, 6], weight=[12, 17])
120+
121+
c = bh.accumulators.WeightedMean()
122+
c.fill([1, 2, 3, 5, 6], weight=[2, 5, 3, 12, 17])
123+
124+
ab = a + b
125+
assert ab.value == approx(c.value)
126+
assert ab.variance == approx(c.variance)
127+
assert ab.sum_of_weights == approx(c.sum_of_weights)
128+
assert ab.sum_of_weights_squared == approx(c.sum_of_weights_squared)
129+
130+
a += b
131+
assert a.value == approx(c.value)
132+
assert a.variance == approx(c.variance)
133+
assert a.sum_of_weights == approx(c.sum_of_weights)
134+
assert a.sum_of_weights_squared == approx(c.sum_of_weights_squared)

tests/test_storage.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import pytest
44
from numpy.testing import assert_array_equal
5+
from pytest import approx
56

67
import boost_histogram as bh
78

@@ -216,3 +217,45 @@ def test_modify_weights_by_view():
216217

217218
assert hist.view().value[0] == pytest.approx(1.5)
218219
assert hist.view().value[1] == pytest.approx(2)
220+
221+
222+
# Issue #531
223+
def test_summing_mean_storage():
224+
np.random.seed(42)
225+
values = np.random.normal(loc=1.3, scale=0.1, size=1000)
226+
samples = np.random.normal(loc=1.3, scale=0.1, size=1000)
227+
228+
h1 = bh.Histogram(bh.axis.Regular(20, -1, 3), storage=bh.storage.Mean())
229+
h1.fill(values, sample=samples)
230+
231+
h2 = bh.Histogram(bh.axis.Regular(1, -1, 3), storage=bh.storage.Mean())
232+
h2.fill(values, sample=samples)
233+
234+
s1 = h1.sum()
235+
s2 = h2.sum()
236+
237+
assert s1.value == approx(s2.value)
238+
assert s1.count == approx(s2.count)
239+
assert s1.variance == approx(s2.variance)
240+
241+
242+
# Issue #531
243+
def test_summing_weighted_mean_storage():
244+
np.random.seed(42)
245+
values = np.random.normal(loc=1.3, scale=0.1, size=1000)
246+
samples = np.random.normal(loc=1.3, scale=0.1, size=1000)
247+
weights = np.random.uniform(0.1, 5, size=1000)
248+
249+
h1 = bh.Histogram(bh.axis.Regular(20, -1, 3), storage=bh.storage.WeightedMean())
250+
h1.fill(values, sample=samples, weight=weights)
251+
252+
h2 = bh.Histogram(bh.axis.Regular(1, -1, 3), storage=bh.storage.WeightedMean())
253+
h2.fill(values, sample=samples, weight=weights)
254+
255+
s1 = h1.sum()
256+
s2 = h2.sum()
257+
258+
assert s1.value == approx(s2.value)
259+
assert s1.sum_of_weights == approx(s2.sum_of_weights)
260+
assert s1.sum_of_weights_squared == approx(s2.sum_of_weights_squared)
261+
assert s1.variance == approx(s2.variance)

0 commit comments

Comments
 (0)