Skip to content

Commit 0a8e283

Browse files
authored
fix: (better) error when filling an int with a float (#876)
* fix: fix int fill with float rounding issue Signed-off-by: Henry Schreiner <[email protected]> * fix: cast to array first Signed-off-by: Henry Schreiner <[email protected]> * fix: disallow float filling for integer axes Signed-off-by: Henry Schreiner <[email protected]> --------- Signed-off-by: Henry Schreiner <[email protected]>
1 parent d3cdd07 commit 0a8e283

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

docs/user-guide/storage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Mean
8484

8585
This storage tracks a "Profile", that is, the mean value of the accumulation instead of the sum.
8686
It stores the count (as a double), the mean, and a term that is used to compute the variance. When
87-
filling, you can add a ``sample=`` term.
87+
filling, you must add a ``sample=`` term.
8888

8989

9090
WeightedMean

include/bh_python/fill.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ inline decltype(auto) special_cast<c_array_t<std::string>>(py::handle x) {
9595
return py::cast<B>(x);
9696
}
9797

98+
// Make sure float arrays don't get cast to integers (-.5 rounds to 0!)
99+
template <>
100+
inline decltype(auto) special_cast<c_array_t<int>>(py::handle x) {
101+
auto np = py::module::import("numpy");
102+
auto dtype = py::cast<py::array>(x).dtype();
103+
if(dtype.equal(np.attr("bool_")) || dtype.equal(np.attr("int8"))
104+
|| dtype.equal(np.attr("int16")) || dtype.equal(np.attr("int32"))
105+
|| dtype.equal(np.attr("int64")))
106+
return py::cast<c_array_t<int>>(x);
107+
throw py::type_error("Only integer arrays supported when targeting integer axes");
108+
}
109+
110+
// Produce a type error for passing float to int
111+
template <>
112+
inline decltype(auto) special_cast<int>(py::handle x) {
113+
try {
114+
return py::cast<int>(x);
115+
} catch(std::runtime_error&) {
116+
throw py::type_error(
117+
"Only integer values supported when targeting integer axes");
118+
}
119+
}
120+
98121
using arg_t = variant::variant<c_array_t<double>,
99122
double,
100123
c_array_t<int>,

tests/test_histogram.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,18 @@ def test_fill_int_1d():
117117
h[-3]
118118

119119

120+
def test_fill_int_with_float_single_1d():
121+
h = bh.Histogram(bh.axis.Integer(-1, 2))
122+
with pytest.raises(TypeError):
123+
h.fill(0.3)
124+
125+
126+
def test_fill_int_with_float_array_1d():
127+
h = bh.Histogram(bh.axis.Integer(-1, 2))
128+
with pytest.raises(TypeError):
129+
h.fill([-0.3, 0.3])
130+
131+
120132
def test_fill_1d(flow):
121133
h = bh.Histogram(bh.axis.Regular(3, -1, 2, underflow=flow, overflow=flow))
122134
with pytest.raises(ValueError):
@@ -935,7 +947,7 @@ def ia(*args):
935947
a.fill(np.empty(2), 1)
936948
with pytest.raises(ValueError):
937949
a.fill(np.empty(2), np.empty(3))
938-
with pytest.raises(ValueError):
950+
with pytest.raises(TypeError):
939951
a.fill("abc")
940952

941953
with pytest.raises(IndexError):
@@ -976,12 +988,9 @@ def ia(*args):
976988
platform.machine() == "ppc64le", reason="ppc64le bug (TBD)", strict=False
977989
)
978990
def test_fill_with_sequence_1():
979-
def fa(*args):
980-
return np.array(args, dtype=float)
981-
982991
a = bh.Histogram(bh.axis.Integer(0, 3), storage=bh.storage.Weight())
983-
v = fa(-1, 0, 1, 2, 3, 4)
984-
w = fa(2, 3, 4, 5, 6, 7)
992+
v = np.array([-1, 0, 1, 2, 3, 4], dtype=int)
993+
w = np.array([2, 3, 4, 5, 6, 7], dtype=float)
985994
a.fill(v, weight=w)
986995
a.fill((0, 1), weight=(2, 3))
987996

tests/test_internal_histogram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_int_cat_hist():
280280
assert_array_equal(h.view(), [1, 1, 1])
281281
assert h.sum() == 3
282282

283-
with pytest.raises(RuntimeError):
283+
with pytest.raises(TypeError):
284284
h.fill(0.5)
285285

286286

0 commit comments

Comments
 (0)